如何通过索引编程修改二维张量的单个值



我有一个2D张量my_tensor大小[50,50]和数据类型int32,我需要在一个特定位置增加值。要更新的位置的索引由2个整数张量给出,它们分别给出轴0和轴1中的位置:

idx_0 is:
tf.Tensor([27], shape=(1,), dtype=int32)
idx_1 is:
tf.Tensor([26], shape=(1,), dtype=int32)

Tensorflow的tensor_scatter_nd_ad似乎是解决方案。如果我手动定义索引,代码就可以工作,但如果我尝试使用idx_0和idx_1,每个实现都会出现一些索引/维度不匹配的错误。

这起作用,增加位置(27,26(:

tf.tensor_scatter_nd_add(reversals_count, [[27, 26]], [1])

但这引发了一个错误:

tf.tensor_scatter_nd_add(reversals_count, [[idx_0, idx_1]], [1])

带有错误消息

{InvalidArgumentError}Outer dimensions of indices and update must match. Indices shape: [1,2,1], updates shape:[1] [Op:TensorScatterAdd]

如何使用idx_0idx_1张量来代替[[27, 26]]?我尝试过的其他语法也没有产生正确的维度:

[[idx_0], [idx_1]]
tf.concat([idx_0, idx_1], axis=0)

为了简单起见,我创建了一个(5x5(单位矩阵,并在索引(0,0(、(1,1(、(2,2(、(3,3((即前四个对角元素(处更新它们。首先将索引定义为张量,然后将值定义为将在各个索引处的值相加的张量,然后使用";tf.tensor_scatter_nd_ad"命令您可以对(50X50(矩阵执行类似操作。谢谢

import tensorflow as tf
indices = tf.constant([[0,0], [1,1], [2,2],[3,3]]) # updating at diagonal index elements, you can see the change
updates = tf.constant([9, 10, 11, 12])# values that will add up at respective indexes
print("original tensor is ")
tensor = tf.ones([5,5], dtype=tf.int32)
print(tensor)
print("updated tensor is ")
updated = tf.tensor_scatter_nd_add(tensor, indices, updates)
print(updated)

相关内容

  • 没有找到相关文章

最新更新