使用tf时形状错误不兼容.Map_fn将python函数应用于张量



在构建一些代码来训练tensorflow深度模型时,我使用tensorflow tf。Map_fn和tf.py_function作为包装器应用scipy python函数作为损失函数映射形状为[batch_size,num_classes]的2个概率向量p和q的批中的每2行。当在这批向量(p,q)上使用KL_divergence时,这种计算可以很好地训练,并且不存在形状不兼容问题:

tf.reduce_sum(p*(tf.log(p + 1e-16) - tf.log(q + 1e-16)), axis=1) #KL divergence

然而,当我尝试使用Wasserstein距离或scipy中的energy_distance函数时,我得到了处理不兼容形状[]和[5000]的错误。5000是这里的类的数量(p和q的形状[batch_size, 5000])

import tensorflow as tf
def compute_kld(p_logit, q_logit, divergence_type):
p = tf.nn.softmax(p_logit)
q = tf.nn.softmax(q_logit)
if divergence_type == "KL_divergence":
return tf.reduce_sum(p*(tf.log(p + 1e-16) - tf.log(q + 1e-16)), axis=1)
elif divergence_type == "Wasserstein_distance":
def wasserstein_distance(x,y):
import scipy
from scipy import stats
return stats.wasserstein_distance(x,y)
@tf.function
def func(p,q):
return tf.map_fn(lambda x: tf.py_function(func=wasserstein_distance, inp=[x[0], x[1]], Tout=tf.float32), (p, q), dtype=(tf.float32)) #, parallel_iterations=10)
return func(p, q)
elif divergence_type == "energy_distance": # The Cramer Distancedef energy_distance(x,y):
def energy_distance(x,y):
import scipy
from scipy import stats
return stats.energy_distance(x,y)
@tf.function
def func(p,q):
return tf.map_fn(lambda x: tf.py_function(func=energy_distance, inp=[x[0], x[1]], Tout=tf.float32), (p, q), dtype=(tf.float32)) #, parallel_iterations=10)
return func(p, q)

这是用一批5个和3个类测试损失函数的代码,它们都单独工作得很好:

import tensorflow as tf
p = tf.constant([[1, 2, 3], [1, 2, 3], [14, 50, 61], [71, 83, 79], [110,171,12]])
q = tf.constant([[1, 2, 3], [1.2, 2.3, 3.2], [4.2, 5.3, 6.4], [7.5, 8.6, 9.4], [11.2,10.1,13]])
p = tf.reshape(p, [-1,3])
q = tf.reshape(q, [-1,3])
p = tf.cast(p, tf.float32)
q = tf.cast(q, tf.float32)
with tf.Session() as sess:
divergence_type = "KL_divergence"
res = compute_kld(p, q, divergence_type = divergence_type)

divergence_type = "Wasserstein_distance"
res2 = compute_kld(p, q, divergence_type = divergence_type)

divergence_type = "energy_distance"
res3 = compute_kld(p, q, divergence_type = divergence_type)
print("############################## p")   
print(sess.run(tf.print(p)))
print("##")
print(sess.run(tf.print(tf.shape(p))))
print("############################## KL_divergence")   
print(sess.run(tf.print(res)))
print("##")
print(sess.run(tf.print(tf.shape(res))))
print("############################## Wasserstein_distance")   
print(sess.run(tf.print(res2)))
print("##")
print(sess.run(tf.print(tf.shape(res2))))
print("############################## energy_distance")   
print(sess.run(tf.print(res3)))
print("##")
print(sess.run(tf.print(tf.shape(res3))))

输出:

############################## p
[[1 2 3]
[1 2 3]
[14 50 61]
[71 83 79]
[110 171 12]]
None
##
[5 3]
None
############################## KL_divergence
[0 0.000939823687 0.367009342 1.1647588 3.09911442]
None
##
[5]
None
############################## Wasserstein_distance
[0 0.0126344115 0.204870835 0.237718046 0.120362818]
None
##
[5]
None
############################## energy_distance
[0 0.0917765796 0.41313991 0.438246906 0.316672504]
None
##
[5]
None

然而,当使用wasserstein距离或能量距离在我的训练代码,我得到不兼容的形状误差:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Tried to set a tensor with incompatible shape at a list index. Item element shape: [] list shape: [5000]
[[{{node gradients/TensorArrayV2Read/TensorListGetItem_grad/TensorListSetItem}}]]

我想知道如果dtype为tf。map_fn或tf.py_function我使用是错误的,或者如果我必须指定/强加形状的地方?

这里是整个代码的链接,我试图用方法"compute_kld": https://github.com/shenyuanyuan/IMSAT/blob/master/imsat_cluster.py中的Wasserstein距离替换KL-divergence

提前感谢您的帮助!

== UPDATE ==

我检查了所有提供的批次,p和q的形状似乎是正确的

shape(p)
(?, 5000)
shape(q)
(?, 5000)

但是,func返回对象的类型是。因此,我尝试用:

来重塑它
return tf.reshape(func(p, q), [p.shape[0]])

然而,这似乎并没有改变什么,因为错误仍然是相同的。在提供第一批之后,代码在开始处理第二批之前崩溃。

在没有看到你的训练代码的情况下,我能帮助你的是获取文档并尝试揭示一些信息。

map_fn通过对每个在轴0上未堆叠的元素应用fn来转换元素。

如果elements是张量的元组(或嵌套结构),则这些张量必须具有相同的外维大小(num_elems);fn用于从元素转换相应切片的每个元组(或结构)。例如,如果elems是一个元组(t1, t2, t3),则使用fn变换每个切片元组(t1[i], t2[i], t3[i])(其中0 <= i <num_elems)。>

energy_distance计算两个1D分布之间的能量距离。

wasserstein_distance计算两个1D分布之间的第一个Wasserstein距离。

首先,您应该确保仅将2Dp_logitq_logit传递给compute_kld

最新更新