如何找到torch张量的多个最大指数



如果我有一个张量,它有多个最大值,我如何获得所有最大值的索引。我试过torch.argmax(张量(,但它只给了我第一个索引。

>>> a_list = [3,23,53,32,53]
>>> a_tensor = torch.Tensor(a_list)
>>> a_tensor
tensor([ 3., 23., 53., 32., 53.])
>>> torch.max(a_tensor)
tensor(53.)
>>> torch.argmax(a_tensor)
tensor(2)

我有以下功能来做这件事,但我想知道是否有更有效的方法:

def max_tensor_indices(tensor_t,max_value):
tensor_list=tensor_t[0]
indices_list=[]
for i in range(len(tensor_list)):
if tensor_list[i]==max_value:
indices_list.append(i)
return indices_list

找到最大值,然后找到具有该值的所有元素。

(x == torch.max(x)).nonzero()

注意:nonzero也可以与as_tuple=True一起调用,这可能会有所帮助。

最新更新