从张量流中的张量中提取值



如果有两个张量矩阵

a = [[1 2 3 4][5 6 7 8]]
b = [[0 1][1 2]],

我们怎么能得到这个:

c = [[1 2][6 7]]

即从第一行提取列0和1、从第二行提取列1和2。

这里有一种方法:

import tensorflow as tf
a = tf.constant([[1, 2, 3, 4],
[5, 6, 7, 8]])
b = tf.constant([[0, 1],
[1, 2]])
row = tf.range(tf.shape(a)[0])
row = tf.tile(row[:, tf.newaxis], (1, tf.shape(b)[1]))
idx = tf.stack([row, b], axis=-1)
c = tf.gather_nd(a, idx)
with tf.Session() as sess:
print(sess.run(c))

输出:

[[1 2]
[6 7]]

最新更新