根据列索引矩阵(tensorflow/numpy)收集矩阵中的条目



一个小例子来说明我需要什么

我有一个关于在tensorflow中聚集的问题。假设我有一个张量(出于某种原因我很关心):

test1 = tf.round(5*tf.random.uniform(shape=(2,3)))

得到如下输出:

<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 2.],
[4., 5., 0.]], dtype=float32)>

我还有一个下标张量列下标我想在每一行上挑出来

test_ind = tf.constant([[0,1,0,0,1],
[0,1,1,1,0]], dtype=tf.int64)

我想收集这个,以便从第一行(第0行)中,我挑选列0,1,0,0,1中的项,第二行也是如此。

所以这个例子的输出应该是:
<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[1., 1., 1., 1., 1.],
[4., 5., 5., 5., 4.]], dtype=float32)>
<<p>

我尝试/strong>所以我想出了一种方法来做这件事,我写了下面的函数gather_matrix_indices(),它将接受一个值的张量和一个指标的张量,并做我上面指定的事情。

def gather_matrix_indices(input_arr, index_arr):
row, _ = input_arr.shape

li = []

for i in range(row):
li.append(tf.expand_dims(tf.gather(params=input_arr[i], indices=index_arr[i]), axis=0))

return tf.concat(li, axis=0)

我的问题

我只是想知道,是否有一种方法来做到这一点只使用tensorflow或numpy方法?我能想到的唯一解决办法是写我自己的函数,遍历每一行,收集该行所有列的索引。我还没有遇到运行时问题,但我宁愿在可能的情况下使用内置的tensorflow或numpy方法。我试过了。以前也收集过,但我不知道这种特殊情况是否可能与tf的任何组合。Gather和tf.gather_nd。如果有人有什么建议,我将不胜感激。

编辑(08/18/22)

我想在PyTorch中添加一个编辑,调用torch.gather()并在参数中设置dim=1将完全符合我在这个问题中想要的。所以,如果你熟悉这两个库,你真的需要这个功能,torch.gather()可以做到这一点。

您可以使用gather_nd()。要让这个工作看起来有点棘手。让我试着用形状来解释这个。

得到test1 -> [2, 3]test_ind_col_ind -> [2, 5]test_ind_col_ind只有列索引,但是您还需要行索引来使用gather_nd()。为了使用gather_nd()[2,3]张量,我们需要创建一个test_ind -> [2, 5, 2]大小的张量。这个新的test_ind的最内层维度对应于您想要从test1中索引的单个索引。这里我们有inner most dimension = 2格式的(<row index>, <col index>)。换句话说,看看test_ind的形状,

[ 2 , 5 , 2 ]
|     |
V     |
(2,5)   |       <- The size of the final tensor   
V
(2,)     <- The full index to a scalar in your input tensor
import tensorflow as tf
test1 = tf.round(5*tf.random.uniform(shape=(2,3)))
print(test1)
test_ind_col_ind = tf.constant([[0,1,0,0,1],
[0,1,1,1,0]], dtype=tf.int64)[:, :, tf.newaxis]
test_ind_row_ind = tf.repeat(tf.range(2, dtype=tf.int64)[:, tf.newaxis, tf.newaxis], 5, axis=1)
test_ind = tf.concat([test_ind_format, test_ind], axis=-1)
res = tf.gather_nd(indices=test_ind, params=test1)

最新更新