张量流收集、连接然后填充操作?



我在TensorFlow 2(python)中有一个2D张量。如何根据行索引的参差不齐数组挑选和连接行,然后用零填充较短的行,以便所有行的长度都相同?

这是我拥有的示例

data = tf.constant([
[300, 301, 302],
[100, 101, 102],
[200, 201, 202],
[120, 121, 122],
[210, 211, 212],
[410, 411, 412],
[110, 111, 112],
[400, 401, 402],
], dtype=tf.float32)
row_ids = [ [ 1, 6, 3 ], [ 2, 4 ], [ 0 ], [ 7, 5] ]

这就是我想要得到的

desired_result = tf.constant([
[ 100, 101, 102, 110, 111, 112, 120, 121, 122],
[ 200, 201, 202, 210, 211, 212,   0,   0,   0],
[ 300, 301, 302,   0,   0,   0,   0,   0,   0],
[ 400, 401, 402, 410, 411, 412,   0,   0,   0]
], 
dtype=tf.float32
)

我试图找到一种方法与tf.RaggedTensor.from_value_rowids()tf.gather_nd()tf.concat(),但没有任何成功。

我确实需要通过此操作进行反向传播,因此,我需要坚持使用 TensorFlow 2 操作。

任何建议将不胜感激!谢谢!

IIUC,您实际上可以更简单地解决此任务:

import tensorflow as tf
data = tf.constant([
[300, 301, 302],
[100, 101, 102],
[200, 201, 202],
[120, 121, 122],
[210, 211, 212],
[410, 411, 412],
[110, 111, 112],
[400, 401, 402],
], dtype=tf.float32)
row_ids = tf.ragged.constant([ [ 1, 6, 3 ], [ 2, 4 ], [ 0 ], [ 7, 5] ])
t = tf.gather(data, row_ids).to_tensor()
t = tf.reshape(t, [tf.shape(t)[0], tf.reduce_prod(tf.shape(t)[1:])])
<tf.Tensor: shape=(4, 9), dtype=float32, numpy=
array([[100., 101., 102., 110., 111., 112., 120., 121., 122.],
[200., 201., 202., 210., 211., 212.,   0.,   0.,   0.],
[300., 301., 302.,   0.,   0.,   0.,   0.,   0.,   0.],
[400., 401., 402., 410., 411., 412.,   0.,   0.,   0.]],
dtype=float32)>

我想我已经找到了一个对我和希望对其他人有用的解决方案。

这个想法是:

  1. 向原始数据添加"填充行">
  2. 使用填充行号扩展较短的索引数组
  3. 使用 tf.gather_nd() 挑选行
  4. 重塑结果以连接内部尺寸

这是代码:

# Add pad row
pad_row = tf.zeros(shape=[1, 3], dtype=tf.float32)
data_with_pad_row = tf.concat([data, pad_row], axis=0)
pad_row_no = data_with_pad_row.shape[0] - 1
# Extend indices
max_row_per_row = max([ len(rows_ids) for rows_ids in row_ids ])
new_row_ids = [ rows_ids + [ pad_row_no]*(max_row_per_row-len(rows_ids)) for rows_ids in row_ids ]
new_row_ids = [ [ [ row_id ] for row_id in rows_ids ] for rows_ids in new_row_ids ]
# Gather and reshape
g3d = tf.gather_nd(indices=new_row_ids, params=data_with_pad_row)
result = tf.reshape(g3d, [g3d.shape[0], g3d.shape[1]*g3d.shape[2]])

这将获得所需的结果,并允许通过操作进行反向传播。

最新更新