与我交流
与我交流

# 3.2 非线性:分类模型

到目前为止,我们已经看到的两个示例都是回归任务,我们尝试预测数值(例如,下载时间或平均房价)。但是,机器学习中的另一个常见任务是分类。一些分类任务是二分类,目标是“是/否”问题的答案。科技界充满了这类问题。举几个例子:

  • 给定的电子邮件是否是垃圾邮件
  • 给定的信用卡交易是合法的还是欺诈的
  • 给定的一秒钟长的音频样本是否包含特定的口头单词
  • 两个指纹图像是否彼此匹配(即来自同一人的同一根手指) 分类问题的另一种类型是多分类任务,其示例也很多:
  • 新闻文章是否涉及体育,天气,游戏,政治或其他一般主题
  • 图片是否是猫,狗,铲子等
  • 从电子笔获得给定的笔划数据,确定是哪个字符
  • 在使用机器学习玩简单的类似 Atari(R)的视频游戏的场景中,在给定游戏当前状态的情况下,游戏角色应该在四个可能的方向(上,下,左和右)中的哪个方向继续前进。

# 3.2.1 什么是二分类

我们将从二分类的简单情况开始。给定一些数据,我们想要是/否的决定。这里有“网络钓鱼网站数据集” [55]。任务是:给定有关网页及其 URL 的特征集合,预测该网页是否用于网络钓鱼(即伪装成另一个旨在窃取用户敏感信息的网站)。
数据集包含 30 个要素,所有要素均为二进制(分别表示为值-1 和 1)或三元(分别表示为-1、0 和 1)。除了列出像波士顿住房数据集那样的所有单个特征外,我们在此提供一些代表性特征:

  • HAVING_IP_ADDRESS: IP 地址是否代替域名(二进制值:{1,-1})
  • SHORTENING_SERVICE:是否正在使用 URL 缩短服务(二进制值:{1,-1})
  • SSLFINAL_STATE:A:URL 使用 https 并且发行者受信任; B:URL 使用 https 但发行者不受信任; C:不使用 https (分别是:{-1,0,1})。 该数据集包括大约 5500 个训练示例和相等数量的测试示例。在训练集中,大约有 45%的示例是肯定的(即真正的网络钓鱼网页)。阳性样本的百分比在测试集中大致相同。 这是最简单的数据集类型,即数据中的要素已经在一致的范围内,因此无需像我们对 Boston Housing 数据集所做的那样标准化其均值和标准差。如果我们想花更多的时间研究数据,则可以进行成对特征相关性检查,以了解我们是否有多余的信息。 由于我们的数据看起来与我们在 Boston Housing 中使用的数据(归一化后的数据)相似,因此我们的初始模型使用相同的结构。在 tfjs -examples 库的网站钓鱼文件夹下,可以找到针对此问题的示例代码。您可以通过以下方式查看并运行示例:
git clone https://github.com/tensorflow/tfjs-examples.git
cd tfjs-examples/website-phishing
yarn && yarn watch
# 清单 3.4 定义用于钓鱼攻击的二分类模型
const model = tf.sequential();
model.add(
  tf.layers.dense({
    inputShape: [data.numFeatures],
    units: 100,
    activation: 'sigmoid'
  })
);
model.add(tf.layers.dense({ units: 100, activation: 'sigmoid' }));
model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));
model.compile({
  optimizer: 'adam',
  loss: 'binaryCrossentropy',
  metrics: ['accuracy']
});

该模型与我们为波士顿住房问题构建的多层网络有很多相似之处。它从两个隐藏层开始,并且两个都使用 sigmoid 激活。最后一个(输出)的单位正好为 1,这意味着模型为每个输入示例输出一个数字。但是,此处的主要区别在于,我们的网络钓鱼检测模型的最后一层具有 sigmoid 激活,而不是像 Boston Housing 模型中的默认线性激活。这意味着我们的模型只能输出 0 到 1 之间的数字,这与 Boston-Housing 模型可能输出任何浮点数不同。 以前,我们已经看到隐藏层的 sigmoid 激活可以帮助增加模型容量。但是,为什么在这个新模型的输出上使用 sigmoid 激活呢?这与我们手头问题的二进制分类性质有关。对于二元分类,我们通常希望模型产生对阳性类别概率的猜测,即模型“认为”给定示例属于阳性类别的可能性。您可能从中学数学中就会学到,概率始终是 0 到 1 之间的数字。通过使模型始终输出估计的概率值,我们可以获得:

  • 对指定分类的支持程度。S 值为 0.5 表示完全不确定,其中任何一种分类都得到同等支持。值 0.6 表示虽然系统预测的是阳性分类,但仅提供了很少的支持。值为 0.99 意味着模型可以确定该示例属于阳性类,依此类推。因此,我们可以轻松而直接地将模型的输出转换为最终答案(例如,仅将输出阈值设置为给定值,例如 0.5)。现在想象一下,如果模型输出的范围可能在很大范围内变化,那么找到这样的阈值将是多么困难。
  • 我们还可以更轻松地得到可微分损失函数,该函数在给出模型的输出和二分类标签后,产生一个数字,该数字可以衡量多少模型未被标记。对于后一点,当该模型使用二进制交叉熵时,我们在详细说明。 但是,问题是如何保证神经网络的输出在[0,1]范围。神经网络的最后一层通常是一个密集层,它的输入执行矩阵乘法(matMul )和偏差加法(biasAdd )运算。matMul 或 biasAdd 对保证结果在[0,1]范围都没有直接相关的内在联系。在 matMul 和 biasAdd 的中添加诸如 sigmoid 的压缩非线性系数才是实现结果在[0,1]范围的通用方法。
    清单 3.4 中提供的代码中的另一种优化器的类型:' adam ',它不同于先前示例中使用的随机梯度下降(' sgd ')的优化器。' adam '和 ' sgd '有何不同?在上一章的第 2.2.2 节中,sgd 优化器始终将通过反向传播获得的梯度乘以固定的数字(即其学习速率乘以-1),来计算更新模型权重。这种方法有一些缺点,当选择较小学习率时缓慢地趋向于损耗最小值,而当损耗(超)表面的形状具有某些特殊属性时,权重空间呈现“锯齿形”路径。' adam '优化器是从起初的训练迭代开始,乘法因子便沿着一定的梯度变化。另外,不同权重参数使用不同的乘法因子。在各种深度学习模型类型上,与 sgd 相比,adam 通常会得到更好的收敛性和对学习率具有选择依赖性,因此,它是优化器的流行选择。TensorFlow.js 库提供了许多其他优化器类型,其中一些也很流行(例如 rmsprop )。信息框 3.1 中的表格对它们进行了简要概述。

信息框 3.1

TensorFlow.js 支持的优化器 下表总结了 TensorFlow.js 中最常用的优化器类型的 API,并为每个优化器提供了简单直观的说明。

名称 API(字符串) API 描述
随机梯度下降(SGD) 'sgd' tf.train.sgd 最简单的优化程序,始终将学习率用作梯度的乘数
动量 'momentum' tf.train .momentum 以某种方式累积过去的梯度,以使权重参数的更新在相同方向上排列更多时会更快,而在方向上变化很大时会变得更慢。
MSProp 'rmsprop' tf.train .rmsprop 通过跟踪每个权重梯度的均方根(RMS)值的最新历史记录,针对模型的不同权重参数对缩放因子进行不同缩放。
AdaDelta 'adadelta' tf.train .adadelta 类似于 RMSProp 的方式缩放每个权重参数的学习率
adam ‘adam’ tf.train .adam 可以理解为 AdaDelta 的自适应学习率方法和动量方法的结合
AdaMax 'adamax' tf.train .adamax 与 ADAM 类似,使用略有不同的算法来跟踪梯度的大小

明显的问题是,鉴于正在研究的机器学习问题和模型,应该使用哪个优化程序。这在深度学习领域尚无共识(这就是 TensorFlow.js 提供了上表列出的所有优化器的原因!)实际上,您应该从流行的优化器开始,比如 adam 和 rmsprop 。如果有足够的时间和计算资源,您也可以将优化器视为超参数,并通过超参数调整来找到能够提供最佳训练结果的选择(请参阅第 3.1.2 节)。

# 3.2.2 评估二元分类器:精度,召回率,准确性和 ROC 曲线

在二分类问题中,我们得到两个值之一比如 0 或 1,是或否等。从更抽象的意义上讲,我们将讨论“正值”和“ 负值” 。当我们的模型进行猜测时,它是对还是错,对于输入示例的实际标签和模型的输出,我们有四种可能的情况,如下表 3.1 所示。

# 表 3.1 四种类型的分类会带来的二进制分类问题

 

模型预测

真相

正确肯定 (TP)

错误否定 (FN)

错误肯定 (FP)

正确否定 (TN)

模型预测正确答案的是正确肯定(TP)和正确否定(TN)。错误肯定(FP)和错误否定(FN)是模型弄错的地方。如果我们用计数填充四个单元,则会得到一个混淆矩阵,有关我们的网络钓鱼检测问题的假想矩阵,请参见下文。 ###### 表3.2二分类问题的结果分布

 

模型预测

真相

4

2

1

93

从网络钓鱼示例的假设结果中,我们看到我们正确识别了4个网络钓鱼网页,错误识别2个网页,并发出了一个错误警报。现在,让我们看一下表达性能的不同通用指标。 准确性是最简单的指标。它量化了正确分类的示例百分比:

Accuracy =(#TP + #TN)/ #examples =(#TP + #TN)/(#TP + #TN + #FP + #FN)

在我们的特定示例中,

Accuracy =(4 + 93)/ 100 = 97%

准确性是一个易于传达和理解的概念。但是,这可能会引起误解-通常在二分类任务中,我们没有正例和负例的均等分布。通常,我们遇到的正面示例要比负面示例少得多(例如,大多数链接不是网络钓鱼,大多数部分都没有缺陷,等等)。如果每 100 个链接中只有 5 个是网络钓鱼,则我们的网络全部预测错误,便也获得 95%的准确性!对于我们的系统而言,准确性似乎是非常糟糕的衡量标准。高准确度听起来总是不错,但常常会引起误解。用作损失函数将是一件不太明智的选择。
下一对度量标准试图捕获准确性缺失的细微之处- 精度和召回率。在随后的讨论中,我们通常还会考虑一些问题,其中肯定(标记为突出显示的链接)表示需要采取进一步的措施,而否定(需手动核实的)表示现状。这些指标侧重于预测可能出现的不同类型的“错误”。
Precision 是模型做出的正确积极预测占全部积极预测的比率。

Precision = #TP /(#TP + #FP)

根据表中的数字,我们可以计算出:

Precision = 4 /(4 + 1)= 80%

您可以使模型在发出肯定的预测时非常保守,例如,仅将具有非常高的 sigmoid 输出(例如> 0.95,而不是默认的> 0.5)标记为正的输入示例。这通常会导致精度提高,但是这样做可能会导致模型错过许多实际的正样本(即,将其标记为负)。最后的成本由经常伴随并补充精度的指标(即召回率)捕获。

recall 是被模型归类为正的实际正例的比率:

recall= #TP /(#TP + #FN)

通过示例数据,我们得到以下结果:

Recall = 4 /(4 + 2)= 66.7%。

在样本集中的所有阳性样本中,模型找到了多少?接受以降低某些指标得到较高的错误警报率通常会是一个有意识的决定。要计算此指标,您只需将所有示例声明为肯定值,以降低精度的代价为召回打分 100%。

正如我们所看到的,构建一个在准确性,召回性或精确度方面得分很高的系统相当容易。在实际的二进制分类问题中,通常很难同时获得良好的精度和召回率。(如果这样做很容易,则您会遇到一个简单的问题,可能不需要使用机器学习。)精度和召回率是在根本不确定什么是正确答案的棘手地方调整模型。您会看到更多细微和综合的指标,例如 X%召回率的 Precision。在下面的图 3.5 中,我们看到在经过 400 个训练周期后,当该模型的概率输出阈值设为 0.5 时,我们的网络钓鱼检测模型能够达到 96.8%的精确度和 92.9%的召回率。

# 图 3.5 针对网络钓鱼网页检测的模型训练得到的示例结果。请注意底部的各种指标:准确性,召回率和误报率(FPR)。曲线下面积(AUC)在 3.2.3 节中讨论。
figure3.5

正如我们在上面简要提到的那样,sigmoid 输出选择预测的阈值不必精确地为 0.5。实际上,根据情况,最好将其设置为大于 0.5(但小于 1)的值或小于 0.5(但大于 0)的值。降低阈值可使模型在将输入标记为正数时更加自由,这会导致较高的召回率,但可能会降低精度。另一方面,提高阈值会使模型在将输入标记为正数时更加谨慎,这通常会导致较高的精度,但可能会降低召回率。因此可以看出,在精确度和召回率之间要进行权衡,而这种权衡很难用到目前为止我们所讨论的任何一种指标来量化。幸运的是,二元分类研究为我们提供了更好的方法来量化和可视化这种折衷关系。我们下面讨论的 ROC 曲线就是用来评估的常用工具。

# 3.2.3 ROC 曲线:二元分类中权衡评估

ROC 曲线用于二分类或检测某些类型事件。它的全称是 Receiver Operating Characteristics,这是早期术语。如今,您几乎再也看不到扩展名称。下面的图 3.6 是我们应用程序的示例 ROC 曲线。

# 图 3.6 在网络钓鱼检测模型训练期间绘制的一组示例 ROC。每条曲线代表不同的训练周期。曲线显示随着训练的进行,二分类模型的质量逐渐提高。
figure3.6 如您在图3.6的轴标签中可能已经注意到的,ROC曲线并不是通过绘制精度和召回率指标而成的。相反,它基于两个略有不同的指标。ROC曲线的横轴是假阳性率(FPR),其定义为

FPR = #FP /(#FP + #TN)

ROC 曲线的垂直轴是真阳性率(TRP),其定义为

TPR = #TP /(#TP + #FN)=recall

真实阳性率(TPR)与召回的定义完全相同;对于同一指标,它只是一个不同的名称。但是,误报率(FPR)是新事物。其分母是示例实际类别为负的所有情况的计数;它的分子是所有假阳性的计数。换句话说,FPR 是被错误分类为正占实际否定示例的比率,这是通常被称为“错误警报”的概率。表 3.3 总结了二进制分类问题中最常见的指标。

# 表 3.3 二进制分类问题的常用指标。
指标名称 定义 如何在 ROC 或精度/召回曲线中使用
准确性 (#TP + #TN)/(#TP + #TN +#FP + #FN) ROC 不使用
精确度 #TP /(#TP + #FP) precision/recall 曲线的垂直轴
召回率/灵敏度/真实率(TPR) #TP /(#TP + #FN) ROC 曲线的垂直轴(例如,见图 3.6); precision/recall 曲线的水平轴。
误报率(FPR) #FP /(#FP + #TN) ROC 曲线的水平轴(例如,见图 3.6)
曲线下面积(AUC) 通过 ROC 曲线下的数值积分计算。有关示例,请参见下面的清单 3.6。 (ROC 不使用,而是由 ROC 计算得出的。)

图 3.6 中的 ROC 曲线是在从 1 个周期(“ epoch 001” )到 400 个训练周期(“ epoch 400”)七个不同的训练时期绘制的曲线。根据模型对测试数据(而不是训练数据)的预测来创建每条曲线。下面的代码清单 3.5 详细说明了如何通过 Model.fit ()API 的 onEpochBegin 回调完成此操作。这种方法使您可以在训练调用期间对模型执行分析和可视化,而无需编写 for 循环或使用多个 Model.fit ()调用。

# 代码 3.5 使用 onEpochBegin 回调在模型训练中的 ROC 曲线。
await model.fit(trainData.data, trainData.target, {
  batchSize,
  epochs,
  validationSplit: 0.2,
  callbacks: {
    onEpochBegin: async epoch => {
      if ((epoch + 1) % 100 === 0 || epoch === 0 || epoch === 2 || epoch === 4) {
        // ← Draw ROC every a few epochs.
        const probs = model.predict(testData.data);
        drawROC(testData.target, probs, epoch);
      }
    },
    onEpochEnd: async (epoch, logs) => {
      await ui.updateStatus(`Epoch ${epoch + 1} of ${epochs} completed.`);
      trainLogs.push(logs);
      ui.plotLosses(trainLogs);
      ui.plotAccuracies(trainLogs);
    }
  }
});

drawROC ()函数的主体包含如何制作 ROC 的详细信息(请参见清单 3.6)。它的作用是:

  1. 改变神经网络的 sigmoid 输出(概率)的阈值以获得不同的分类结果集
  2. 对于每个分类结果,将其与实际标签结合使用以计算 TPR 和 FPR。
  3. 将 TPR 与 FPR 相对应以形成 ROC 曲线 如上图 3.6 所示,在训练的最开始(训练周期为 001),由于模型的权重是随机初始化的,因此 ROC 曲线非常接近将点(0,0)与点( 1 1)。这就是随机猜测的样子。随着训练的进行,ROC 曲线向左上角推高,FPR 接近于 0,TPR 接近 1。如果专注于 FPR 值为 0.1 时,随着训练的进行,我们看到相应的 TPR 值单调增加。简而言之,这意味着随着训练的进行,如果我们错误警报(例如 FPR)在固定值上,我们可以实现越来越高的召回率(例如 TPR)。

“理想的” ROC 是向左上角弯曲最大的曲线,使其变成 Γ [56]形状。在那种情况下,您可以获得 100%TPR 和 0%FPR,这是二分类器的“圣杯”。但是,在实际问题中,我们只能改进模型以将 ROC 曲线推到更靠近左上角的位置,但是永远无法实现左上角的理论理想曲线。 被 ROC 曲线和 x 轴包围的单位面积,称为曲线下面积(AUC),由清单 3.6 中的代码计算得出。在考虑到误报与误报之间的权衡,此指标优于精度,召回率和准确性。用于随机猜测(即对角线)的 ROC 的 AUC 为 0.5,而 Γ 型理想 ROC 的 AUC 为 1.0。经过训练,我们的网络钓鱼检测模型的 AUC 达到 0.981。

# 代码 3.6 用于计算和渲染 ROC 曲线的代码,它还计算曲线下的面积(AUC)。
function drawROC(targets, probs, epoch) {
   return tf.tidy(() => {
     const thresholds = [
       0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45,
       0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85,
       0.9, 0.92, 0.94, 0.96, 0.98, 1.0
     ];  #A:
     const tprs = [];  // True positive rates.
     const fprs = [];  #False positive rates.
     let area = 0;
    for (let i = 0; i < thresholds.length; ++i) {       const threshold = thresholds[i];
      const threshPredictions = utils.binarize(probs, threshold).as1D();  #B:
       const fpr = falsePositiveRate(targets, threshPredictions).arraySync();  #C:
       const tpr = tf.metrics.recall(targets, threshPredictions).arraySync();
       fprs.push(fpr);
       tprs.push(tpr);

       if (i > 0) {  #D:
         area += (tprs[i] + tprs[i - 1]) * (fprs[i - 1] - fprs[i]) / 2;
       }
     }
     ui.plotROC(fprs, tprs, epoch);
     return area;
   });
 }

除了可视化二分类器的特征之外,ROC 还可以帮助我们选择实际情况下的概率阈值。例如,假设我们是一家将网络钓鱼检测程序作为服务开发的商业公司。我们是否要: a)使阈值相对较低,因为缺少真正的网络钓鱼网站将使我们在赔偿责任或合同损失方面付出很多代价
b)使阈值相对较高,因为我们更反对将其归类为“ 骗人”网址 而被正常网站屏蔽的用户提出的投诉。
每个阈值对应于 ROC 曲线上的一个点。当我们将阈值从 0 逐渐增加到 1 时,我们从图的右上角(FPR 和 TPR 均为 1)移动到图的左下角(FPR 和 TPR 均为 0)。在实际的工程问题中,在 ROC 曲线上选择哪个点的决定始终是基于权衡这种相对的实际成本,并且它可能因不同的客户和业务发展的不同阶段而异。 除 ROC 曲线外,另一种常用的二分类可视化是 precision-recall 曲线(有时称为 P / R 曲线,在表 3.3 中作了简要介绍)。与 ROC 曲线不同,精确-召回会根据召回率绘制精度。由于精度/召回曲线在概念上与 ROC 曲线相似,因此在此不再赘述。 上面清单 3.6 中值得一提的是 tf.tidy ()的使用。此函数可确保正确处理在匿名函数中作为参数传递给它的张量,以便它们不会继续占用 WebGL 内存。在浏览器中,TensorFlow.js 无法管理用户创建的张量的内存,这主要是由于 JavaScript 中缺少对象处理以及针对 TensorFlow.js 张量的 WebGL 缺乏垃圾收集。如果未正确清理此类中间张量,则将发生 WebGL 内存泄漏。如果这样的内存泄漏继续进行足够长的时间,他们将最终导致 WebGL 内存不足。附录 A4 中的 A4.3 节包含 TensorFlow.js 中有关内存管理的详细教程。同一附录的 A4.4 节中也有关于此主题的练习。

# 3.2.4 二元杂交熵:二元分类 的损失函数

到目前为止,我们已经谈过量化分类的不同方面,如准确度,精确度和召回(表 3.3)几个不同的指标。但是我们还没有讨论过一项重要的指标,该指标是可微的并且可以生成支持模型梯度下降训练的梯度。我们在上面的清单 3.4 中简要看到的 binaryCrossentropy ,尚未解释:

model.compile({
  optimizer: 'adam',
  loss: 'binaryCrossentropy',
  metrics: ['accuracy']
});

首先,有人可能会问:为什么我们不能简单地将准确性,准确性,召回率甚至 AUC 用作损失函数?毕竟,这些指标是可以理解的。同样,在我们之前看到的回归问题中,我们使用均方误差(一个相当容易理解的指标)作为直接训练的损失函数。答案是这些二分类指标都不能产生我们训练所需的梯度。以精度度量为例:要了解为什么它不是梯度友好的,计算精度需要确定模型预测中哪些是正的,哪些是负的(见表 3.3 中的第一行)。为此,必须应用阈值函数,该函数将模型的 sigmoid 输出转换为二进制预测。这是问题的症结在于:尽管阈值函数(或更专业的术语是“阶跃函数” )几乎在任何地方都是微分的(“几乎”是因为在“跳跃点”微分为 0.5),但导数始终是正好为零(见图 3.7)!如果我们尝试通过此阈值函数进行反向传播,会发生什么情况?您的梯度最终将变为全零,因为在某些时候,上游梯度值需要与阶跃函数的这些全零导数相乘。简而言之,如果选择精度(或准确度,召回率,AUC 等)作为损失,则基本阶跃函数的平坦部分将使训练过程无法知道在权重空间中的移动位置以减少权重损失值。

# 图 3.7 用于转换二分类模型的概率输出的阶跃函数几乎在任何地方都是微分的。但是,每个微分点处的梯度(导数)恰好为零。
figure3.7

因此,将精度用作损失函数无法使我们计算有用的梯度,从而使我们无法获得有意义的模型权重更新。对其他的度量标准,包括精度,召回率,误报率和 AUC,都有同样的限制。尽管这些指标对于人类理解二分类器的行为很有用,但它们对于这些模型的训练过程毫无用处。 我们用于二分类任务的损失函数是二进制交叉熵,它对应于网络钓鱼检测模型代码中的“ binaryCrossentropy ”配置(清单 3.4 和 3.5)。在算法上,我们可以使用以下代码定义二进制交叉熵:

# 代码 3.7 二进制交叉熵损失函数的代码[57]。
function binaryCrossentropy(truthLabel, prob):
   if truthLabel is 1:
    return -log(prob)
   else:
    return -log(1 - prob)

在上面的代码中,trueLabel 是一个数字,数值是 0 或 1,指示了输入示例实际上具有负(0)还是正(1)标签。prob 是模型预测的示例属于阳性类别的概率。请注意,与 trueLabel 不同,prob 应该是可以为 0 到 1 之间任何值的实数。log 是自然对数,以 e(2.718)为底。binaryCrossentropy 函数的主体包含 if-else 逻辑分支,该逻辑分支根据 trueLabel 是 0 还是 1 来执行不同的计算。图 3.8 在同一图中绘制了两种情况。 查看图 3.8 中的曲线时,较低的值更好,因为这是损失函数。关于损失函数要注意的重要事项是:

  • 如果 trueLabel 为 1,则 prob 值接近 1.0 会导致较低的损失函数值。这是有道理的,因为当示例实际为正时,我们希望模型输出的概率尽可能接近 1.0。反之亦然,如果 TruthLabel 为 0,则当概率值接近于 0 时,损失值就会降低。这也是有道理的,在这种情况下,我们希望模型输出的概率尽可能接近于 0。
  • 与图 3.7 中所示的二进制阈值函数不同,这些曲线在每个点处都具有非零的斜率,从而导致非零的梯度。这就是为什么它适用于基于反向传播的模型训练。
# 图 3.8 二进制交叉熵损失函数。分别绘制了两种情况(TruthLabel = 1 和 TruthLabel = 0),反映了清单 3.7 中的 if-else 逻辑分支。
figure3.8

您可能会问一个问题是“为什么不重复我们对回归模型所做的工作?只是假装 0-1 值是回归目标,并使用均方误差(MSE)作为损失函数”?毕竟,MSE 是可微的,如果我们计算 TruthLabel 和概率之间的 MSE,就像 binaryCrossentropy 一样,将得出非零导数。答案是这与 MSE 在边界处的“收益递减”有关。在表 3.4 中,我们列出了当 TruthLabel 为 1 时许多概率值的 binaryCrossentropy 和 MSE 损失值。随着 prob 接近 1(所需值),MSE 的下降速度与 binaryCrossentropy 越来越慢。当 prob 接近 1(例如 0.9)时,binaryCrossentropy 模型产生更高(即接近 1)的 prob 值。同样,当 TruthLabel 为 0 时,MSE 在生成将模型的概率输出推向 0 的梯度也不如 binaryCrossentropy 。 因此,这表明了二分类问题与回归问题不同的另一个方面:对于二进制分类问题,损失(binaryCrossentropy )和指标(准确性,精度等)不同,而对于回归而言其相同(例如 meanSquaredError )。我们将在下一节中看到,多类别分类问题也涉及不同的损失函数和度量。

# 表 3.4 比较二进制交叉熵和均方误差(MSE)的值,以获取假设的二进制分类结果。
truthLabel prob Binary cross entropy MSE
1 0.1 2.302 0.81
1 0.5 0.693 0.25
1 0.9 0.100 0.01
1 0.99 0.010 0.0001
1 0.999 0.001 0.000001
1 1 0 0
上次更新: 11/15/2020, 1:05:56 PM