交叉熵损失:揭示其在机器学习中的作用
损失函数是一种数学方法,用于通过量化实际值和预测值之间的差异来衡量机器学习模型的准确性。差异越小,机器学习模型越好。因此,损失函数作为评估模型性能和指导改进的指标。
交叉熵损失,或对数损失,通常用于训练分类模型。这是一个易于实现且经过优化的损失函数,需要将标签编码为数值以准确计算损失。本文将讨论交叉熵损失的功能和实现,其应用、挑战和提高性能的技巧。
理解交叉熵损失
交叉熵损失比较预测值的概率分布与实际值。惩罚对于交叉熵损失函数至关重要,这是最小化机器学习模型损失的核心。
两个主要的分类问题:
- 二元分类:只有两个标签的输出变量的分类任务。
- 多类分类:有两个以上标签的分类任务。
假设对于二元分类,类别标签被转换为0或1的值。在多类场景中,假设标签是独热编码的。例如,对于一个有三个类别的分类问题,如果一个数据点属于第一个类别,它的标签将被表示为[1, 0, 0]。同样,如果它属于第二个类别,标签将是[0, 1, 0]。
二元分类的交叉熵公式是:
19.1.PNG
其中: y = 实际标签 p = 正类的预测概率 多类分类的交叉熵公式是:
19.2.PNG
其中: Σ = 所有N个类别的总和 y_i = 第i类的真标签 p_i = 第i类的预测概率
交叉熵使用对数来计算惩罚。如果预测显著偏离实际值,该函数会对预测施加更大的惩罚。当偏差较小时,惩罚也会减少。
例如,考虑一个预测候选人性别的模型。如果它在实际值为1时为男性分配了0.9的概率,这个预测的惩罚将是0.105,计算为—ln(0.9) = 0.105。相反,如果模型为男性分配了0.1的概率,惩罚增加到2.30。
计算交叉熵损失
计算交叉熵损失函数通常包括以下步骤:
- 推理时间:训练后,模型以每个类别的概率形式进行预测。
- 每个数据点的损失:交叉熵损失函数将预测概率与真实标签进行比较。损失值告诉我们模型对那个特定数据点的猜测有多“错”。
- 总体损失:但我们不仅仅关心一个数据点!为了了解模型的性能,我们计算所有数据点的平均交叉熵损失,通常在验证或测试数据集中。这是模型的总体损失。最小化损失:优化算法通过根据惩罚调整模型参数来最小化总体损失。
Python中的交叉熵损失
在您的Python环境中运行以下代码片段以计算交叉熵损失。Y_true指的是以数值编码的实际类别标签,y_pred是预测值: python
!pip install scikit-learn
import log_loss from sklearn
from sklearn.metrics import log_loss
通过向log_loss函数提供实际和预测值来计算损失
loss = log_loss(y_true, y_pred)
机器学习中的应用
交叉熵损失函数用于各种分类模型,包括:
逻辑回归
逻辑回归通常用于二元分类问题。它预测每个类别的概率在0和1之间。交叉熵应用于预测概率以衡量模型的性能。 例如,在以下代码片段中,我们使用numPy创建一个虚拟的y_true和y_pred,并计算其交叉熵:
python
from sklearn.metrics import log_loss
import numpy as np
示例数据
y_true = np.array([0, 1]) 真实类别标签
y_pred = np.array([[0.9, 0.1], 数据点0的预测概率
[0.5, 0.5]]) 数据点1的预测概率
loss = log_loss(y_true, y_pred)
print(f"交叉熵损失(使用log_loss): {loss}")
深度神经网络
想象一下,你正在训练一个深度神经网络(DNN)来识别手写数字(0-9)。以下是交叉熵损失与反向传播的工作原理:
预测时间:训练后,DNN接收一个数字图像,并预测它是每个数字的概率(10个概率,每个数字一个)。与我们刚刚讨论的简单分类模型不同,这里的概率是平滑给出的。将其视为一个光谱——数字的概率越高,模型就越有信心图像代表该数字。
每张图片的损失:交叉熵损失函数将这些预测概率与实际数字标签(例如,9的概率为0.9,其他所有数字的概率为0.1)进行比较。这告诉我们模型对那张图片的猜测有多“错”。
- 反向传播:一张图片并不能说明全部情况。这里变得强大:反向传播采用这种连续损失,并使用一种称为可微性的特有属性来计算DNN中的每个微小调整(权重和偏差)对那个错误贡献了多少。
- 从错误中学习:通过理解这些贡献,优化算法可以调整DNN中的每个单独权重,以最小化所有训练图像的总体损失。
- 这种方法教导深度神经网络根据分配的惩罚调整其权重,以生成可靠的预测。
挑战和技巧
当使用交叉熵损失作为损失函数时,必须考虑几个缺点。采取以下讨论的安全措施将确保有效的损失计算。
对异常值敏感 交叉熵损失受到异常值的严重影响,这可能导致模型过拟合。由于交叉熵损失比较概率,极端值可能会扭曲计算。这导致模型忽略潜在数据,优先拟合异常值。
处理异常值敏感性的技巧
在数据清洗的数据预处理阶段移除异常值有助于避免过拟合。然而,要小心不要移除合法的数据点。可以使用对异常值不那么敏感的鲁棒损失函数,如Huber损失,作为替代,以防止过拟合。
类别不平衡 当一个类别的数据点数量显著多于另一个类别时,就会发生类别不平衡。交叉熵损失发现学习少数类别很困难,因为多数类别拥有更多的数据点。正确预测多数类别是最小化总体损失的最简单(懒惰)方式。
处理类别不平衡的技巧
对少数类别进行过采样或对多数类别进行欠采样可以创建一个平衡的数据集。像L1和L2正则化这样的正则化技术可能有助于防止模型过拟合。此外,调整模型超参数,如学习率或类别权重,可能会引导模型更多地关注特定类别。
结论
交叉熵损失是一个重要且易于使用的分类模型损失函数。它指导优化算法调整模型权重,实现更好的性能。二元交叉熵和多类交叉熵是两种类型的交叉熵函数。两者都使用对数来分配对不正确模型预测的惩罚。
交叉熵是许多科技巨头在各种应用中使用的有力工具。掌握任何技能的关键在于实践。使用像TensorFlow Playground这样的工具来可视化其对神经网络的影响。探索更多资源和动手实验,以更深入地理解交叉熵损失。
15.2.JPEG
技术干货
艾瑞巴蒂看过来!OSSChat 上线:融合 CVP,试用通道已开放
有了 OSSChat,你就可以通过对话的方式直接与一个开源社区的所有知识直接交流,大幅提升开源社区信息流通效率。
2023-4-6技术干货
我决定给 ChatGPT 做个缓存层 >>> Hello GPTCache
我们从自己的开源项目 Milvus 和一顿没有任何目的午饭中分别获得了灵感,做出了 OSSChat、GPTCache。在这个过程中,我们也在不断接受「从 0 到 1」的考验。作为茫茫 AI 领域开发者和探索者中的一员,我很愿意与诸位分享这背后的故事、逻辑和设计思考,希望大家能避坑避雷、有所收获。
2023-4-14技术干货
当一个程序员决定穿上粉裤子
如何找到和你时尚风格相似的明星?AI + Milvus=?
2023-8-23