如果我有一个张量,它有多个最大值,我如何获得所有最大值的索引。我试过torch.argmax(张量(,但它只给了我第一个索引。
>>> a_list = [3,23,53,32,53]
>>> a_tensor = torch.Tensor(a_list)
>>> a_tensor
tensor([ 3., 23., 53., 32., 53.])
>>> torch.max(a_tensor)
tensor(53.)
>>> torch.argmax(a_tensor)
tensor(2)
我有以下功能来做这件事,但我想知道是否有更有效的方法:
def max_tensor_indices(tensor_t,max_value):
tensor_list=tensor_t[0]
indices_list=[]
for i in range(len(tensor_list)):
if tensor_list[i]==max_value:
indices_list.append(i)
return indices_list
找到最大值,然后找到具有该值的所有元素。
(x == torch.max(x)).nonzero()
注意:nonzero
也可以与as_tuple=True
一起调用,这可能会有所帮助。