函数 tf.gather() 的反函数



假设我有一个形状为[B,D]的张量a,并且我有一个包含形状[B]索引的列表I。现在我想使用列表中的索引将张量扩展到具有M > B[M,D]形状。请注意,索引属于范围[0,M]。具体来说,I是从张量a到另一个具有更大维度0值的张量的行映射。此功能与函数tf.gather()相反。 有人可以提出解决方案吗? 谢谢

tf.scatter_ndtf.gather_nd相反。让我们通过一个往返示例来了解这一点,其中:

  • 首先,我们创建一个形状为((5,4,1,2,3))的张量,其中除索引[1,2,0,0][3,0,0,1]处的元素外,所有元素均为零,它们分别使用tf.scatter_nd[16, 12, 11][3,0,0,1]
  • 其次,我们做相反的事情,在生成的张量上应用tf.gather_nd以获得两个原始向量,即[16, 12, 11][3,0,0,1]
updates = [[16, 12, 11],[18, 40, 37]]
indices = [[1,2,0,0], [3,0,0,1]]
shape = (5,4,1,2,3)
# first step - scatter
scat_tensor = tf.scatter_nd(indices=indices, updates=updates, shape=shape)
print(f"verification: expected {updates[0]}, got {scat_tensor[1,2,0,0]}")
# now the reverse step
reconstructed_updates = tf.gather_nd(scat_tensor, indices)
print(f"verification: expected {updates }, got {reconstructed_updates }")

最新更新