二元交叉熵计算中的 pos_weight



当我们处理不平衡的训练数据(负样本较多,正样本较少)时,通常会使用pos_weight参数。pos_weight的期望是,当positive sample得到错误的标签时,模型将比negative sample得到更高的损失。当我使用binary_cross_entropy_with_logits函数时,我发现:

bce = torch.nn.functional.binary_cross_entropy_with_logits
pos_weight = torch.FloatTensor([5])
preds_pos_wrong =  torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
loss_pos_wrong = bce(preds_pos_wrong, label_pos, pos_weight=pos_weight)
preds_neg_wrong =  torch.FloatTensor([1.5, 0.5])
label_neg = torch.FloatTensor([0, 1])
loss_neg_wrong = bce(preds_neg_wrong, label_neg, pos_weight=pos_weight)

然而:

>>> loss_pos_wrong
tensor(2.0359)
>>> loss_neg_wrong
tensor(2.0359)

错误的正样本和负样本造成的损失是相同的,那么pos_weight在不平衡数据损失计算中是如何工作的呢?

;两个损失是相同的,因为你在计算相同的数量:两个输入是相同的,两个批元素和标签只是交换。


为什么你得到同样的损失?

我想你对F.binary_cross_entropy_with_logits的用法感到困惑(你可以找到nn.BCEWithLogitsLoss的更详细的文档页面)。在您的情况下,您的输入形状(又名模型的输出)是一维的,这意味着您只有一个logitx而不是两个) .

在你的例子中,你有

preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])

这意味着您的批大小为2,并且由于默认情况下该函数平均批元素的损失,因此您最终会得到BCE(preds_pos_wrong, label_pos)BCE(preds_neg_wrong, label_neg)的相同结果。你的批处理的两个元素刚刚交换了。

您可以很容易地验证这一点,不用使用reduction='none'选项平均批处理元素的损失:

>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos, 
pos_weight=pos_weight, reduction='none')
tensor([2.3704, 1.7014])
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos, 
pos_weight=pos_weight, reduction='none')
tensor([1.7014, 2.3704])

查看F.binary_cross_entropy_with_logits:

也就是说二元交叉熵的公式是:

bce = -[y*log(sigmoid(x)) + (1-y)*log(1- sigmoid(x))]

其中y(分别sigmoid(x)为与该logit相关的正类,1 - y为与logit相关的正类。1 - sigmoid(x))为负类。

文档可以更精确地说明pos_weight的权重方案(不要与weight混淆,后者是不同logits输出的权重)。正如你所说的,pos_weight的想法是权衡积极的项,而不是整个项

bce = -[w_p*y*log(sigmoid(x)) + (1-y)*log(1- sigmoid(x))]

其中w_p为正项的权值,用以补偿正负样本不平衡。在实践中,这应该是w_p = #negative/#positive

因此:

>>> w_p = torch.FloatTensor([5])
>>> preds = torch.FloatTensor([0.5, 1.5])
>>> label = torch.FloatTensor([1, 0])

内置损耗函数,

>>> F.binary_cross_entropy_with_logits(preds, label, pos_weight=w_p, reduction='none')
tensor([2.3704, 1.7014])

与人工计算比较:

>>> z = torch.sigmoid(preds)
>>> -(w_p*label*torch.log(z) + (1-label)*torch.log(1-z))
tensor([2.3704, 1.7014])

相关内容

  • 没有找到相关文章

最新更新