克拉斯中心损失



我想在Keras中实现[http://ydwen.github.io/papers/WenECCV16.pdf]中解释的中心损失

我开始创建一个有2个输出的网络,如:

inputs = Input(shape=(100,100,3))
...
fc = Dense(100)(#previousLayer#)
softmax = Softmax(fc)
model = Model(input, output=[softmax, fc])
model.compile(optimizer='sgd', 
              loss=['categorical_crossentropy', 'center_loss'],
              metrics=['accuracy'], loss_weights=[1., 0.2])

首先,这样做是一个好方法吗?

其次,我不知道如何在keras中实现center_loss。Center_loss看起来像均方误差,但不是将值与固定标签进行比较,它将值与每次迭代时更新的数据进行比较。

谢谢你的帮助

对于我来说,您可以按照以下步骤实现此层:

  1. 写一个自定义图层ComputeCenter

    • 接受两个输入:i)真值标记y_true(不是单热编码,而是整数)和ii)预测隶属度y_pred

    • 包含一个大小为num_classes x num_feats数组的查找表W作为可训练的权重(参考BatchNormalization Layer), W[j]是第j类特征的移动平均的占位符。

    • 按文中规定计算中心损耗。

    • 输出结果距离数组D
  2. 要计算中心损耗,需要

    • 我)。根据y_true[k]=j使用y_pred[k]更新W[j]
    • ii)。检索y_true[k]=j
    • y_pred[k]的中心特征c_true[k]=W[j]
    • iii)计算y_predc_true之间的距离
    • 这里c_true[k] = W[j], k为样本索引,j为y_pred[k]的ground truth label。
  3. 使用model.add_loss()计算该损失。注意,不要在model.compile( loss = ... )中添加这个损失。

最后,如果需要,您可以在中心-loss中添加一些损耗系数。

相关内容

  • 没有找到相关文章

最新更新