假设我有一个形状为[B,D]
的张量a
,并且我有一个包含形状[B]
索引的列表I
。现在我想使用列表中的索引将张量扩展到具有M > B
[M,D]
形状。请注意,索引属于范围[0,M]
。具体来说,I
是从张量a
到另一个具有更大维度0
值的张量的行映射。此功能与函数tf.gather()
相反。 有人可以提出解决方案吗? 谢谢
tf.scatter_nd
与tf.gather_nd
相反。让我们通过一个往返示例来了解这一点,其中:
- 首先,我们创建一个形状为
((5,4,1,2,3))
的张量,其中除索引[1,2,0,0]
和[3,0,0,1]
处的元素外,所有元素均为零,它们分别使用tf.scatter_nd
[16, 12, 11]
和[3,0,0,1]
。 - 其次,我们做相反的事情,在生成的张量上应用
tf.gather_nd
以获得两个原始向量,即[16, 12, 11]
和[3,0,0,1]
。
updates = [[16, 12, 11],[18, 40, 37]]
indices = [[1,2,0,0], [3,0,0,1]]
shape = (5,4,1,2,3)
# first step - scatter
scat_tensor = tf.scatter_nd(indices=indices, updates=updates, shape=shape)
print(f"verification: expected {updates[0]}, got {scat_tensor[1,2,0,0]}")
# now the reverse step
reconstructed_updates = tf.gather_nd(scat_tensor, indices)
print(f"verification: expected {updates }, got {reconstructed_updates }")