TensorFlow 中的 index() 函数



有没有办法在张量流中获取张量中值的索引?

例如,我有一个独热矩阵,我想获取 1 的坐标。

| 0 0 0 |
| 0 1 0 |  => (1,1)
| 0 0 0 |

最好通过一个函数来完成,例如 tensor.index(binary_function)

您可以使用tf.where .

例如

import tensorflow as tf
x = tf.constant([[0, 0, 0],
                 [0, 1, 0],
                 [0, 3, 0]])
with tf.Session() as sess:
    coordinates = tf.where(tf.greater(x, 0))
    print(coordinates.eval()) # [[1 1], [2 1]]
    print(tf.gather_nd(x, coordinates).eval()) # [1, 3]

最新更新