tensor_scatter_nd_update值错误:形状必须相等,但为0和1



我一直能够使用tf.tensor_scatter_nd_update写入张量,而不会遇到任何问题,但我无法理解为什么它不适用于某些特定的张量。

举一个简单的例子,比如说我想基于布尔掩码mask=[[1 0 1]]设置input=[[0 0 0]]update=[[1 2 3]]中的某些值。我会简单地做:

input=tf.tensor_scatter_nd_update(input,tf.where(mask),update)

期望运算的结果是CCD_ 5。

相反,我得到了

ValueError: Dimensions [2,2) of input[shape=[1,3]] = [] must match dimensions [1,2) of updates[shape=[1,3]] = [3]: Shapes must be equal rank, but are 0 and 1 for ... with input shapes: [1,3], [?,2], [1,3].

我真的不知道出了什么问题;我总是能够毫无问题地使用该函数,即使在更复杂的情况下也是如此。

我想明白了。

问题的一部分确实是tf.where()返回了一个2-D张量,但这是因为我使用它来生成updates向量:

input=input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.where(something_else))

解决方案是通过以下方式去除多余的尺寸:

input=input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.squeeze(tf.where(something_else)))

相关内容

  • 没有找到相关文章

最新更新