如何正确使用交叉熵损失与Softmax进行分类



我想使用Pytorch训练一个多类分类器。

下面的官方Pytorch文档展示了如何在类型为nn.Linear(84, 10)的最后一层之后使用nn.CrossEntropyLoss()

不过,我记得Softmax就是这么做的。

这让我很困惑。


  1. 如何训练;标准";分类网络的最佳方式是什么
  2. 如果网络有最后一个线性层,如何推断每个类别的概率
  3. 如果网络有一个最终的softmax层,如何训练网络(哪个损失,如何)

我在Pytorch论坛上找到了这个线程,它可能回答了所有这些问题,但我无法将其编译成可工作且可读的Pytorch代码。


我假设的答案:

就像医生说的
  • 线性层的输出的指数,这些输出实际上是logits(对数概率)
  • 我不明白
  • 我认为理解softmax和交叉熵很重要,至少从实用的角度来看是这样。一旦你掌握了这两个概念,那么就应该清楚它们可能是什么";正确地";用于ML.的上下文

    交叉熵H(p,q)

    交叉熵是一个比较两种概率分布的函数。从实践的角度来看,可能不值得探讨交叉熵的形式动机,尽管如果你感兴趣,我推荐Cover和Thomas的《信息理论的元素》作为入门文本。这个概念很早就被引入了(我相信第二章)。这是我在研究生院使用的介绍文本,我认为它做得很好(当然我也有一个很棒的老师)。

    需要注意的关键是,交叉熵是一个函数,它以两个概率分布q和p作为输入,并在q和p相等时返回一个最小值。q表示估计分布,p表示真实分布。

    在ML分类的上下文中,我们知道训练数据的实际标签,因此真实/目标分布p对于真实标签的概率为1,而在其他地方的概率为0,即p是一个热向量。

    另一方面,估计的分布(模型的输出)q通常包含一些不确定性,因此q中任何类别的概率都在0到1之间。通过训练系统使交叉熵最小化,我们告诉系统,我们希望它尝试使估计的分布尽可能接近真实分布。因此,您的模型认为最有可能的类是与q的最高值相对应的类。

    Softmax

    同样,有一些复杂的统计方法来解释softmax,我们不会在这里讨论。从实用的角度来看,关键是softmax是一个函数,它以一个无界值列表作为输入,并输出一个有效的概率质量函数,保持相对有序。重要的是要强调关于相对排序的第二点。这意味着softmax的输入中的最大元素对应于softmax的输出中的最大元件。

    考虑一个被训练为最小化交叉熵的softmax激活模型。在这种情况下,在softmax之前,模型的目标是为正确的标签产生可能的最高值,为不正确的标签生成可能的最低值。

    PyTorch中的交叉熵损失

    PyTorch中交叉熵损失的定义是softmax和交叉熵的结合。特别是

    交叉熵损失(x,y):=H(one_hot(y),softmax(x))

    请注意,one_hot是一个获取索引y并将其扩展为one-hot向量的函数

    等价地,您可以将CrossEntropyLoss公式化为LogSoftmax和负对数似然损失(即PyTorch中的NLLLoss)的组合

    LogSoftmax(x):=ln(softmax(x))

    交叉熵损失(x,y):=NLLLoss(LogSoftmax(x),y)

    由于softmax中的幂运算;技巧";这使得直接使用CrossEntropyLoss比分阶段计算更稳定(更准确,不太可能得到NaN)。

    结论

    根据以上讨论,您的问题的答案是

    1.如何训练";标准";分类网络的最佳方式是什么

    就像医生说的。

    2.如果网络有一个最终的线性层,如何推断每个类别的概率

    将softmax应用于网络的输出,以推断每个类别的概率。如果目标只是找到相对排序或最高概率类,那么只需将argsort或argmax直接应用于输出(因为softmax保持相对排序)。

    3.如果网络有一个最终的softmax层,如何训练网络(哪个损失,如何)

    通常,出于上述稳定性原因,您不希望训练输出softmaxed输出的网络。

    也就是说,如果出于某种原因绝对需要,您可以获取输出日志并将其提供给NLLLoss

    criterion = nn.NLLLoss()
    ...
    x = model(data)    # assuming the output of the model is softmax activated
    loss = criterion(torch.log(x), y)
    

    这在数学上等同于将CrossEntropyLoss与使用softmax激活的模型一起使用。

    criterion = nn.CrossEntropyLoss()
    ...
    x = model(data)    # assuming the output of the model is NOT softmax activated
    loss = criterion(x, y)
    

    最新更新