损失函数tf.nn.weighted_cross_entropy_with_logits出现故障



我正在尝试用二进制目标训练u-net网络。通常的二进制交叉熵损失表现不佳,因为标签非常不平衡(0像素比1像素多得多(。所以我想更多地惩罚假阴性。但是tensorflow并没有现成的加权二进制交叉熵。由于我不想从头开始写损失,我尝试使用tf.nn.weighted_cross_entropy_with_logits。为了能够容易地将损失反馈给model.compile函数,我正在编写这个包装器:

def loss_wrapper(y,x):
x = tf.cast(x,'float32')
loss = tf.nn.weighted_cross_entropy_with_logits(y,x,pos_weight=10)
return loss

然而,无论将x铸造为浮动,我仍然会得到错误:

TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type int32 of argument 'x'.

当调用tf loss时。有人能解释一下发生了什么吗?

如果x代表您的预测。它可能已经具有类型float32。我认为你需要铸造y,这可能是你的标签。因此:

loss = tf.nn.weighted_cross_entropy_with_logits(tf.cast(y, dtype=tf.float32),x,pos_weight=10)

最新更新