如何自己实现tf.argmax?



我想使用一个函数,该函数将张量作为输入,并返回跨张量轴的最大值的索引。我知道有一个函数 tf.argmax(( 可以执行完全相同的操作,但是我如何自己实现它(在实现某些自定义函数的情况下,这可能是必要的(?

现在让我们假设该函数仅将 1D 张量作为输入。因此,该函数需要具有以下签名:

argmax(
input, #input is a 1D tensor
name=None
)

我尝试以这种方式实现它:

def argmax(input, name=None):
maxValue=0
maxIndex=0
for i in range(input.get_shape()[0]):
if input[i]>maxValue:
maxValue=input[i]
maxIndex=i
return maxIndex

但是,这不起作用,因为在构建阶段,这些值尚未初始化,因此我无法像在上面的代码中那样比较两个值。那么,有没有办法写出自定义函数,如tf.argmax,tf.equal等?

嗯,一个简单的方法是这样的:

idx = tf.where(tf.equal(input, tf.reduce_max(input)))[0, 0]

例:

import tensorflow as tf
with tf.Session() as sess:
input = tf.constant([1, 3, 4, 2, 1, 2])
idx = tf.where(tf.equal(input, tf.reduce_max(input)))[0, 0]
print(sess.run(idx))

输出:

2

最新更新