在tensorflow中,我试图将2D数组的每一行中的最大值替换为1,所有其他数字都为零,就像这样;
x = tf.constant(
[[0, 4],
[2,3],
[6,7],
[9,2]])
output should be like this;
[[0, 1],
[0,1],
[0,1],
[1,0]])
到目前为止,我已经能够通过使用以下代码将最大值变为1,所有其他值变为零。
import tensorflow as tf
x = tf.constant(
[[0, 4],
[2,3],
[6,7],
[9,2]])
top_values, top_indices = tf.nn.top_k(tf.reshape(x, (-1,)), 1)
output = tf.cast(tf.greater_equal(x, top_values), tf.float64)
Output: ([[0. 0.]
[0. 0.]
[0. 0.]
[1. 0.]])
但是我如何在每个数组中分别获得最大值=1。有人能帮我解决这个问题吗?
你可以用矢量化的方式来代替对x
的重塑:
top_vals, _ = tf.nn.top_k(x, 1)
output = tf.cast(tf.greater_equal(x, top_vals), tf.float64)
output
输出:
<tf.Tensor: shape=(4, 2), dtype=float64, numpy=
array([[0., 1.],
[0., 1.],
[0., 1.],
[1., 0.]])>