Keras后端开关结合tf.没有按预期工作的地方



我有一个自定义损失函数,其中我想将值从基于一个热的编码更改为一定范围内的值,以计算IOU。

这段代码的一部分是看我在张量中哪里有一个1,其他地方有0。这里我用tf。Where返回位置。我有一个形状为[batch_size,S1,S2,12]的向量,我只关心最后一个维度,这就是为什么我取[…],2]

现在它经常发生,我的预测是全零,因为我的背景事件没有任何值,而且我的网络会预测一个全零向量。这表示tf。Where将返回一个空张量。这就是为什么我想用K.switch来检查张量是否为空,因为如果是空的,我想要返回0。

现在的问题是k.s switch期望then else选项的形状具有相同的形状,但我需要我的输出具有shape [batch_size,S1,S2,1]。我试过不同的东西,但我不能得到这个工作。我需要得到形状[batch_size,S1,S2,1]的零,或者我需要wher_box1具有浮点数[batch_size,S1,S2,1]

现在实现的方式是,当where_box1_temp为空时,K.switch返回一个0的空向量,这不是我想要的。当我使用tf.zero ([batch_size,S1,S2,1])时,它会抱怨当where_box1_temp为空时条件的形状不同....

where_box1_temp = tf.where(y_pred[...,C+1:C+13])[...,2]
where_box1 = K.switch(tf.equal(tf.size(where_box1_temp),0) , 
tf.zeros_like(where_box1_temp) , where_box1_temp)

所以我找到了一个变通办法,也许这对别人有帮助:

where_box1_temp = tf.where(y_pred[...,C+1:C+13],[1,2,3,4,5,6,7,8,9,10,11,12],0)
where_box1 = tf.reshape(K.sum(where_box1_temp,axis=3),[batch_size,5,5])

这允许我有一个我想要的形状的张量,其中所有背景/零预测值都是0,而不必使用k.switch,也不会遇到任何空维度或类似的问题。

最新更新