作用于自定义损失函数中的条件



我在tensorflow中定义了一个模型,作为损失函数,我编写了自己的自定义损失函数。通常,我在正常的MSE函数之外加上一个惩罚项:

return K.mean(K.square(y_true - y_pred)) + penalty

,其中这个惩罚是由一些输入元素和一些目标之间的乘积和给出的(我不报告代码,因为它与问题无关)。

现在,一切都很完美,但我需要向数据集添加一种新的示例,其中一些特征和输出为零值。

因此,在批处理中(在训练期间)两种类型的例子,即有零值和没有零值,我需要一种条件语句应用于整个批处理,这样我就可以根据条件的评估以两种不同的方式计算penalty

可以做到吗?

您可以使用tf.where:

例如:

  • 如果元素为零,返回自身
  • 如果元素不为零,返回自身+ 1
>>> batch_1 = tf.constant([0.,1.,0.,0.5,3.,-2.,0.])
>>> tf.where(batch_1 == 0, batch_1, batch_1+1)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0. , 2. , 0. , 1.5, 4. , -1. , 0. ], dtype=float32)>

它也适用于多维张量,例如:

  • 如果批处理的第二个元素为零,则对第一个元素求和
  • 否则,对整批
  • 求和
>>> a = tf.constant([[1,0,3],[3,4,9],,[8,0,1]],tf.float32)
>>> tf.where(a[:,1]==0,tf.reduce_sum(a[:,:1],axis=1), tf.reduce_sum(a,axis=1))
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 1., 16.,  8.], dtype=float32)>

最新更新