我有一个形状为[7, 7, 2, 4]
的tensor A
和一个形状为[7, 7]
的tensor B
。
Tensor B
是tensor A
的argmax
,其价值是0,1
。
我想从 A 和 B 获得形状[7, 7, 4]
或[7, 7, 1, 4]
的tensor C
。
规则是(i, j)元素tensor B
是tensor A
的第 2 维的索引。
我怎样才能快速完成?我试图通过A[B]
获得 C,但它不起作用。谁能帮我?谢谢。
好的,我被用来解决这个问题tf.gather_nd:
tensor_C = tf.gather_nd(tensor_A, tf.expand_dims(tf.argmax(tensor_B, 2), 2), batch_dims=3)