为什么tf.argmax()返回错误的索引



我正在尝试使用tf.argmax()函数获取logits的最大索引。我的代码如下所示:

import tensorflow as tf
import numpy as np

logits = tf.random_uniform([1,3,3,21], maxval=255, dtype=tf.float32, seed=0)
logits = logits / tf.norm(logits)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
logits_eval = sess.run(logits)
logits_argmax_np = np.argmax(logits_eval, axis=-1)
ypredT = tf.argmax(logits, axis=-1)
logits_argmax_tf = ypredT.eval()

我可以使用np.argmax()获得正确的索引,但我不知道为什么tf.argmax()返回错误的索引
非常感谢。

编辑:我正在使用tensorflow 1.13

这里的问题是,每次调用sess.run时,都会从下到上单独执行会话。由于生成了随机数,因此在每次运行中不会产生相同的结果,因此每次运行的argmax不同。但他们也在做同样的事情。

要看到这一点,您可以从同一会话执行中获得两个argmax,使用方括号从同一个sess.run:获得ypredT张量和logits张量

# tensorflow graph
logits = tf.random_uniform([1,3,3,21], maxval=255, dtype=tf.float32, seed=0) 
logits = logits / tf.norm(logits)
ypredT = tf.argmax(logits, axis=-1) # tensorflow argmax
# run session 
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
logits_eval, logits_argmax_tf = sess.run([logits, ypredT])
# after session has closed
logits_argmax_np = np.argmax(logits_eval, axis=-1) # get numpy argmax
print(logits_argmax_np)
print(logits_argmax_tf)

输出:

[[[14  0  4]
[ 0  8  0]
[10 12  3]]]
[[[14  0  4]
[ 0  8  0]
[10 12  3]]]

最新更新