我一直能够使用tf.tensor_scatter_nd_update
写入张量,而不会遇到任何问题,但我无法理解为什么它不适用于某些特定的张量。
举一个简单的例子,比如说我想基于布尔掩码mask=[[1 0 1]]
设置input=[[0 0 0]]
到update=[[1 2 3]]
中的某些值。我会简单地做:
input=tf.tensor_scatter_nd_update(input,tf.where(mask),update)
期望运算的结果是CCD_ 5。
相反,我得到了
ValueError: Dimensions [2,2) of input[shape=[1,3]] = [] must match dimensions [1,2) of updates[shape=[1,3]] = [3]: Shapes must be equal rank, but are 0 and 1 for ... with input shapes: [1,3], [?,2], [1,3].
我真的不知道出了什么问题;我总是能够毫无问题地使用该函数,即使在更复杂的情况下也是如此。
我想明白了。
问题的一部分确实是tf.where()
返回了一个2-D张量,但这是因为我使用它来生成updates
向量:
input=input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.where(something_else))
解决方案是通过以下方式去除多余的尺寸:
input=input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.squeeze(tf.where(something_else)))