我在python上使用tensorflow我有一个形状为 [?, 5, 37] 的数据张量和一个形状为 [?, 5] 的 idx 张量
我想从数据中提取元素并获得形状 [?, 5] 的输出,例如:
output[i][j] = data[i][j][idx[i, j]] for all i in range(?) and j in range(5)
看起来 tf.gather_nd() 函数最接近我的需求,但我看不到如何使用它......
谢谢!
编辑:我设法用gather_nd做到了,如下所示,但是有更好的选择吗?(好像有点粗暴)
nRows = tf.shape(length_label)[0] ==> ?
nCols = tf.constant(MAX_LENGTH_INPUT + 1, dtype=tf.int32) ==> 5
m1 = tf.reshape(tf.tile(tf.range(nCols), [nRows]),
shape=[nRows, nCols])
m2 = tf.transpose(tf.reshape(tf.tile(tf.range(nRows), [nCols]),
shape=[nCols, nRows]))
indices = tf.pack([m2, m1, idx], axis=-1)
# indices should be of shape [?, 5, 3] with indices[i,j]==[i,j,idx[i,j]]
output = tf.gather_nd(data, indices=indices)
我设法用gather_nd
做到
nRows = tf.shape(length_label)[0] # ==> ?
nCols = tf.constant(MAX_LENGTH_INPUT + 1, dtype=tf.int32) # ==> 5
m1 = tf.reshape(tf.tile(tf.range(nCols), [nRows]),
shape=[nRows, nCols])
m2 = tf.transpose(tf.reshape(tf.tile(tf.range(nRows), [nCols]),
shape=[nCols, nRows]))
indices = tf.pack([m2, m1, idx], axis=-1)
# indices should be of shape [?, 5, 3] with indices[i,j]==[i,j,idx[i,j]]
output = tf.gather_nd(data, indices=indices)