Tensorflow:在损失操作中按批次的某个条件标签进行过滤



我的网络中有两个标签批次。第一批是噪声标签,第二批是经过验证的标签。并非所有嘈杂标签都有经过验证的标签,但批次具有相同的大小。损失必须仅在经过验证的标签上计算。有没有办法过滤批次以便在损失中仅使用经过验证的标签?

这是我的损失定义:

def loss_clean_v1(label_output, label_verified):
loss_clean_value = tf.reduce_sum(tf.abs(tf.subtract(label_output, label_verified)))
# debug
loss_clean_value = tf.Print(loss_clean_value, [loss_clean_value], message="Loss label cleaning: ")
return loss_clean_value

我找到了解决问题的方法!

label是从清洗到一批嘈杂标签得到的一批输出标签。label_verified是一批经过验证的标签;并非每个嘈杂的标签都有经过验证的标签,所以我使用所有 -1np.full((1, 6012), -1)的假标签。

在我的损失中,必须仅使用经过验证的标签。 我需要过滤label_verified批次。

为此,我将tf.where()与布尔向量一起使用,大小为批处理中标签的数量,作为条件,如[False, True]

问题是 Tensorflow 只提供元素比较运算符,不会生成我需要的布尔向量。

为了得到这个向量,我改进了元素条件结果:condition_refined = tf.cast(tf.reduce_sum(tf.cast(condition, tf.int32), 1), tf.bool)

过滤的批处理结果为label_filtered = tf.where(condition_refined, label_verified, label)

如果条件False,我使用标签输出,因为我的损失减去了已验证标签和输出标签。所以输出标签 - 输出标签对损失的贡献为 0。

这是我的代码:

def loss_clean(label, label_verified):
# in np.full, 2 is my batch size 
condition = tf.not_equal(label_verified, np.full((2, 6012), -1))
# debug
condition = tf.Print(condition, [condition], message="Condition: ")
condition_refined = tf.cast(tf.reduce_sum(tf.cast(condition, tf.int32), 1), tf.bool)
# debug
condition_refined = tf.Print(condition_refined, [condition_refined], message="Condition_refined: ")
label_filtered = tf.where(condition_refined, label_verified, label)
# debug
label_filtered = tf.Print(label_filtered, [label_filtered], message="Label_filtered: ")
loss_clean_value = tf.reduce_sum(tf.abs(tf.subtract(label, label_filtered)))
# debug
loss_clean_value = tf.Print(loss_clean_value, [loss_clean_value], message="Loss label cleaning: ")
return loss_clean_value