一个小例子来说明我需要什么
我有一个关于在tensorflow中聚集的问题。假设我有一个张量(出于某种原因我很关心):
test1 = tf.round(5*tf.random.uniform(shape=(2,3)))
得到如下输出:
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 2.],
[4., 5., 0.]], dtype=float32)>
我还有一个下标张量列下标我想在每一行上挑出来
test_ind = tf.constant([[0,1,0,0,1],
[0,1,1,1,0]], dtype=tf.int64)
我想收集这个,以便从第一行(第0行)中,我挑选列0,1,0,0,1中的项,第二行也是如此。
所以这个例子的输出应该是:<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[1., 1., 1., 1., 1.],
[4., 5., 5., 5., 4.]], dtype=float32)>
<<p>我尝试/strong>所以我想出了一种方法来做这件事,我写了下面的函数gather_matrix_indices(),它将接受一个值的张量和一个指标的张量,并做我上面指定的事情。
def gather_matrix_indices(input_arr, index_arr):
row, _ = input_arr.shape
li = []
for i in range(row):
li.append(tf.expand_dims(tf.gather(params=input_arr[i], indices=index_arr[i]), axis=0))
return tf.concat(li, axis=0)
我的问题
我只是想知道,是否有一种方法来做到这一点只使用tensorflow或numpy方法?我能想到的唯一解决办法是写我自己的函数,遍历每一行,收集该行所有列的索引。我还没有遇到运行时问题,但我宁愿在可能的情况下使用内置的tensorflow或numpy方法。我试过了。以前也收集过,但我不知道这种特殊情况是否可能与tf的任何组合。Gather和tf.gather_nd。如果有人有什么建议,我将不胜感激。
编辑(08/18/22)
我想在PyTorch中添加一个编辑,调用torch.gather()
并在参数中设置dim=1
将完全符合我在这个问题中想要的。所以,如果你熟悉这两个库,你真的需要这个功能,torch.gather()
可以做到这一点。
您可以使用gather_nd()
。要让这个工作看起来有点棘手。让我试着用形状来解释这个。
得到test1 -> [2, 3]
和test_ind_col_ind -> [2, 5]
。test_ind_col_ind
只有列索引,但是您还需要行索引来使用gather_nd()
。为了使用gather_nd()
和[2,3]
张量,我们需要创建一个test_ind -> [2, 5, 2]
大小的张量。这个新的test_ind
的最内层维度对应于您想要从test1
中索引的单个索引。这里我们有inner most dimension = 2
格式的(<row index>, <col index>)
。换句话说,看看test_ind
的形状,
[ 2 , 5 , 2 ]
| |
V |
(2,5) | <- The size of the final tensor
V
(2,) <- The full index to a scalar in your input tensor
import tensorflow as tf
test1 = tf.round(5*tf.random.uniform(shape=(2,3)))
print(test1)
test_ind_col_ind = tf.constant([[0,1,0,0,1],
[0,1,1,1,0]], dtype=tf.int64)[:, :, tf.newaxis]
test_ind_row_ind = tf.repeat(tf.range(2, dtype=tf.int64)[:, tf.newaxis, tf.newaxis], 5, axis=1)
test_ind = tf.concat([test_ind_format, test_ind], axis=-1)
res = tf.gather_nd(indices=test_ind, params=test1)