我正在尝试使用tf.gather_nd进行转换
'R = tf.eye(3, batch_shape=[4])'
自:
array([[[1., 0., 0.],
[0., 0., 1.],
[0., 1., 0.]],
[[0., 0., 1.],
[0., 1., 0.],
[1., 0., 0.]],
[[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.]],
[[1., 0., 0.],
[0., 0., 1.],
[0., 1., 0.]]], dtype=float32)'
使用索引:
ind = array([[0, 2, 1],
[2, 1, 0],
[1, 2, 0],
[0, 2, 1]], dtype=int32)
我发现是否可以将索引矩阵转换为以下内容:
ind_c = np.array([[[0, 0], [0, 2], [0, 1]],
[[1, 2], [1, 1], [1, 0]],
[[2, 1], [2, 2], [2, 0]],
[[3, 0], [3, 2], [3, 1]]])
gather_nd会完成这项工作。 所以我的问题是:
- 有没有比将索引 IND 转换为 ind_c 更好的方法
- 如果这是我如何使用 TensorFlow 将 IND 转换为 ind_c 的唯一方法?(我现在手动完成此操作(
谢谢
您可以尝试以下操作:
ind = tf.constant([[0, 2, 1],[2, 1, 0],[1, 2, 0],[0, 2, 1]], dtype=tf.int32)
# Creates the row indices matrix
row = tf.tile(tf.expand_dims(tf.range(tf.shape(ind)[0]), 1), [1, tf.shape(ind)[1]])
# Concat to the ind to form the index matrix
ind_c = tf.concat([tf.expand_dims(row,-1), tf.expand_dims(ind, -1)], axis=2)