![](https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/1615/b65d45933899445cb7536b4feca0cbb3/c6fa2694-91b3-4f3e-ba53-1a05393ddf64.png?x-oss-process=image/resize,w_100,m_lfit)
![](https://cdn1.deepmd.net/bohrium/web/static/images/level-v2-3.png?x-oss-process=image/resize,w_50,m_lfit)
蒙亦有道:用共形预测为机器学习赋予置信度
©️ Copyright 2023 @ Authors
作者:
Weiliang Luo 罗伟梁
日期:2023-11-09
共享协议:本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。
快速开始:点击上方的 开始连接 按钮,选择 bohrium-notebook:2023-04-07 镜像 和任意配置机型即可开始。
写在前面
这是一篇科普向 Notebook,其文字材料大量参考了文章 A Gentle Introduction to Conformal Prediction and Distribution-Free Uncertainty Quantification 及其作者 Stephen Bates 在 MIT 2023 年秋季学期讲授的 6.7900 Machine Learning 课程。
本文所需知识背景:
- 机器学习基础
- 概率论与数理统计
从二分类任务的概率诠释谈起
从线性 SVM 到最为复杂的神经网络,无论架构如何,很多二分类模型都会将最后一层的输出 通过 Logistic 函数得到最终结果 ,并在 时分类为 1,在 时分类为 0。直觉告诉我们,同样是分类为 1, 和 的“含金量”是不同的:如果这一数据点的真实分类 ,那么二分类任务常用的交叉熵损失函数 的值也在 时更小。似乎前者表明模型更加确信这一分类结果,而后者则是一个字面意思上“模棱两可”的输出。因此,我们可以将 解读为模型预测这一数据点分类为 1 的概率。
确实常常作为主动学习(active learning)中不确定性采样(uncertainty sampling)标准:在训练集中初训模型后,我们要从外部选择一批新的数据点进行标注并加入训练集继续训练。此时模型在外部数据上预测的 是否接近 0.5 可以作为其是否加入训练集的准绳——我们要挑选那些模型预测具有最小置信度(Least Confidence)的点来加强模型的认识。
在多分类任务中,Logistic 函数扩展为 Softmax 函数。如果一共有 个类别,则最后一层输出 对应的最终结果为 ,并将其分类为最大的 对应的类别 。由于 的良好性质,我们依然可以将 解读为模型所预测的数据点属于类别 的概率。在主动学习中,模型在外部数据上预测的最大和第二大的 之差是否接近 0 也可以作为是否加入训练集的准绳——我们要挑选那些模型预测区分度低的点来加强模型的认识。这称之为边缘(Margin)采样,可见最小置信度采样就是边缘采样在二分类问题中的特殊情况。
这些解读和应用看起来直观而美好,在二分类任务的理想情况下,我们希望对于任何数据点的预测结果 和真实标签 满足 ,但事实上仅通过常规的模型训练,以上概率恒等式没有任何理论保证。为了衡量我们预想的 的概率意义和实际的概率的差距,我们可以将区间 分割为互不相交的区间 ,并在 时将数据点 投入集合 中,然后计算校准误差(calibration error) 如果数据点足够多,其中的每一项都反映了 中从数据点估计的经验概率 和从模型中解读的预测概率 之间的平均差异。
借助这一指标,我们可以尝试通过调整模型输出来解决这一问题:我们将数据集的一部分划分为校准集(calibration set)。模型训练完成后,在推理所用的 Logistic 函数中加入超参数 ,使其变为 称为温度参数,较高的温度参数将使得 更接近 0 或 1,使模型更“自信”。于是,我们可以调节 使校准集上的校准误差尽可能地小,从而使模型预测的 更接近真实的概率意义。
借助可靠的概率输出,我们不仅能更有效地开展主动学习中的不确定度采样,还能够给出预测结果的置信度;当单一结果的置信度不高时,高置信度结果的集合也能为决策提供参考。然而,即使我们有充分大的校准集,也不能保证找到合适的 使校准误差小于我们所期待的值。换言之,即使通过校准和温度参数的手段,逻辑回归或 softmax 输出的依然是“伪概率”。
我们可以使用最简单的逻辑回归算法和经典的乳腺癌二分类数据集作为示例:首先将数据集分割为训练集和校准集
定义带正则化的逻辑回归的损失函数,其中 是模型权重
训练模型
校准集准确率:0.9578947368421052 训练集准确率:0.926056338028169
以下函数把 分为 num
等份来计算校准误差,并引入温度参数
作出不同温度参数下,模型预测的概率和校准集上估计的经验概率的折线图及其校准误差
![](https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/1615/74c37068741c4676bf69e03ddfb8b04d/8i9GEYTCFUBaZyxZVhRkZw.png)
可见适当调节 可以使得预测概率和经验概率相接近,但很难实现完美吻合。
通过共形预测得到给定置信度的结果范围
共形预测(conformal prediction)能够将上述这种并不严格但具有启发性的不确定度估计转化为统计上严格的不确定度估计,并适用于任何已经训练好的机器学习模型。在 AI4S 领域,它已经应用于毒性预测 (Zhang et. al., 2021) 和生物分子设计 (Fannjiang et. al., 2022) 等任务上。其基本步骤如下:
- 将数据分为训练集、校准集和测试集。在训练集上训练好一个模型,它根据数据特征 给出标签 的预测值 。
- 根据所选择的模型的意义,定义一个打分函数 :这个函数的值越大,表明模型对特征 的预测结果 和真实标签 的一致性越差。
- 对于给定的置信度 ,在整个校准集上计算 的 分位数 。
- 对于测试集中的新数据点 ,共形预测结果是一个集合 。
这个集合就是共形预测给出的结果范围,能够证明:
如果校准集和测试集中的所有数据 独立同分布,那么只要按照以上步骤 3 和 4 计算 并输出 ,就有 且在相当宽松的条件下,还有
这说明真实标签以近似为 的置信度落在集合 中。
我们会发现上述数学结论几乎完全不依赖于 的定义,千万不要被这样美好的结果所迷惑!事实上,共形预测方法保证了输出的结果范围具有“置信区间”的统计意义,但这个集合对决策是否有帮助仍要依赖于我们对 的定义—— 包含了我们对模型和任务的理解,是所得的结果范围提供有效信息的保证。我们来看一个没有任何信息量的共形预测的极端例子:定义 为一个随机数:
将之前的校准集重新分割为校准集和测试集
在不同的置信度下计算共形预测结果
将测试集上估计的经验概率 对置信度 作图
![](https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/1615/74c37068741c4676bf69e03ddfb8b04d/iDnMK_IBFp0iCPFN1WdU9w.png)
可见吻合良好。但实际上这一共形预测的输出结果只是按照给定的置信度随机猜测集合 ,与模型输出结果没有任何关系。我们可以统计一下 90% 置信度条件下输出结果中 的比例
0.8055555555555556
如此之高——这样的结果没有任何价值,因为只要输出足够多的 就能保证 大概率落入 。它确保了置信度,却未能对决策提供任何帮助。因此,我们仍需设计好的打分函数 ,在借助共形预测的统计严格性的同时,给出有意义的结果。
设计有意义的打分函数
为了利用我们训练的逻辑回归模型提供的信息,可以这样设计打分函数:设模型对于输入特征 的预测为 ,则 。这一打分函数非常直观地反映了 和 的偏差,我们来做以下实验:
![](https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/1615/74c37068741c4676bf69e03ddfb8b04d/gryvxyiI7fGBFRWeK2dSqg.png)
可见测试集上经验概率和置信度的吻合依然良好。接下来观察高置信度 90% 时共形预测结果有多少包含 。
0.0
远小于前面的随机打分函数!这说明共形预测确实从模型输出中受益:它能够更加精准地推断高置信度集合,而不是通过大量猜测全集 来提高对真实标签的覆盖率。
后记
关于共形预测,更多的代码案例可以参考文章 A Gentle Introduction to Conformal Prediction and Distribution-Free Uncertainty Quantification 的配套 Github 仓库中的示例 Notebook。
![](https://cdn1.deepmd.net/static/img/d7d9741bda38a158-957c-4877-942f-4bf6f81fcc63.png?x-oss-process=image/resize,w_100,m_lfit)
![](https://cdn1.deepmd.net/bohrium/web/static/images/level-v2-1.png?x-oss-process=image/resize,w_50,m_lfit)
![](https://cdn1.deepmd.net/static/img/d7d9741bda38a158-957c-4877-942f-4bf6f81fcc63.png?x-oss-process=image/resize,w_100,m_lfit)
![](https://cdn1.deepmd.net/bohrium/web/static/images/level-v2-1.png?x-oss-process=image/resize,w_50,m_lfit)
![](https://cdn1.deepmd.net/static/img/d7d9741bda38a158-957c-4877-942f-4bf6f81fcc63.png?x-oss-process=image/resize,w_100,m_lfit)
![](https://cdn1.deepmd.net/bohrium/web/static/images/level-v2-2.png?x-oss-process=image/resize,w_50,m_lfit)
![](https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/19477/184f7dca7b5a440e81813b881a8faffd/9c283d69-ca7f-45df-8db3-f8705cf7db37.png?x-oss-process=image/resize,w_100,m_lfit)
![](https://cdn1.deepmd.net/bohrium/web/static/images/level-v2-1.png?x-oss-process=image/resize,w_50,m_lfit)