如何在 TF 中使用 (n-k)-dims 张量索引 n-dims 张量?



我有一个形状为[7, 7, 2, 4]tensor A和一个形状为[7, 7]tensor B

Tensor Btensor Aargmax,其价值是0,1

我想从 A 和 B 获得形状[7, 7, 4][7, 7, 1, 4]tensor C

规则是(i, j)元素tensor Btensor A的第 2 维的索引。

我怎样才能快速完成?我试图通过A[B]获得 C,但它不起作用。谁能帮我?谢谢。

好的,我被用来解决这个问题tf.gather_nd:

tensor_C = tf.gather_nd(tensor_A, tf.expand_dims(tf.argmax(tensor_B, 2), 2), batch_dims=3)

最新更新