如何使用tf.gather_nd在张量流中切片张量?



我正在寻找numpy中以下代码的张量流等效项。aidx_2都给出了。目标是构建b

# A float Tensor obtained somehow
a = np.arange(3*5).reshape(3,5)                    
# An int Tensor obtained somehow
idx_2 = np.array([[1,2,3,4],[0,2,3,4],[0,2,3,4]])  
# An int Tensor, constructed for indexing
idx_1 = np.arange(a.shape[0]).reshape(-1,1)        
# The goal
b = a[idx_1, idx_2]
print(b)
>>> [[ 1  2  3  4]
[ 5  7  8  9]
[10 12 13 14]]

我尝试直接索引张量并使用tf.gather_nd但我不断收到错误,所以我决定在这里问如何做到这一点。我到处寻找人们使用tf.gather_nd(因此标题)来解决类似问题的答案,但要应用此功能,我必须以某种方式重塑索引,以便它们可用于切片第一维。我该怎么做?请帮忙。

Tensorflow在NumPy中非常简单和Pythonic的事情时可能非常丑陋。以下是我如何使用tf.gather_nd在TensorFlow中重现您的问题。不过,可能有一个更好的方法可以做到这一点。

import tensorflow as tf
import numpy as np
with tf.Session() as sess:
# Define 'a'
a = tf.reshape(tf.range(15),(3,5))
# Define both index tensors 
idx_1 = tf.reshape(tf.range(a.get_shape().as_list()[0]),(-1,1)).eval()
idx_2 = tf.constant([[1,2,3,4],[0,2,3,4],[0,2,3,4]]).eval()
# get indices for use with gather_nd
gather_idx = tf.constant([(x[0],y) for (i,x) in enumerate(idx_1) for y in idx_2[i]])
# extract elements and reshape to desired dimensions
b = tf.gather_nd(a, gather_idx)
b = tf.reshape(b,(idx_1.shape[0], idx_2.shape[1]))
print(sess.run(b))
[[ 1  2  3  4]
[ 5  7  8  9]
[10 12 13 14]]