这个numpy数组运算的等价tensorflow运算是什么



例如,y_pred是一个numpy数组,我想要这个操作。

result = []
for i in y_pred:
i = np.where(i == i.max(), 1, 0)
result.append(i)

y_pred中每行的最大数目将变为1,其余的将变为0

如果y_pred是张量,我如何实现这个运算?

不确定是否有一个矢量化函数可以同时完成所有这些任务。

假设这是您的数组。

a = tf.Variable([[10,11,12,13],
[13,14,16,15],
[18,16,17,15],
[0,4,3,2]], name='a')

CCD_ 5给出了每一行中最大值的索引。现在我们从onesacopy = tf.Variable(tf.zeros((4,4),tf.int32), name='acopy')开始,它与原始数组的形状相同。

之后,我们只需在找到最大值的位置替换零后,将行缝合在一起。

import tensorflow as tf
a = tf.Variable([[10,11,12,13],
[13,14,16,15],
[18,16,17,15],
[0,4,3,2]], name='a')
acopy = tf.Variable(tf.zeros((4,4),tf.int32), name='acopy')
top = tf.nn.top_k(a,1)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
values,indices = sess.run(top)
shapeofa = a.eval(sess).shape
print(indices)
sess.run(tf.global_variables_initializer())
for i in range(0, shapeofa[0] ) :
oldrow = tf.gather(acopy, i)
index = tf.squeeze(indices[i])
b = sess.run( tf.equal(index, 0) )
if not b  :
o = oldrow[0:index]
newrow = tf.concat([o, tf.constant([1])], axis=0)
c = sess.run(tf.equal(index, (tf.size(oldrow) - 1)))
if not c:
newrow = tf.concat([newrow, oldrow[(index+1):(tf.size(oldrow))]], axis=0)
else :
o1 = oldrow[1:tf.size(oldrow)]
newrow = tf.concat([tf.constant([1]), o1], axis=0)
acopy = tf.scatter_update(acopy, i, newrow)
print(sess.run(acopy))

输出是这样的。

[[0 0 1]

[0 0 1 0]

[1 0 0 0]

[0 1 0]]

您可以自己测试。

最新更新