在Pytorch Argmax发生冲突的情况下,索引选择



我一直在尝试学习张量操作,这使我陷入了循环。
让我们说我有一个张量:

    t = torch.tensor([
        [1,0,0,2],
        [0,3,3,0],
        [4,0,0,5]
    ], dtype  = torch.float32)

现在这是一个排名2张量,我们可以为每个等级/维度应用Argmax。让我们说我们将其应用于DIM = 1

t.max(dim = 1)
(tensor([2., 3., 5.]), tensor([3, 2, 3]))

现在,我们可以看到结果与沿dim = 1的张量为2,3,而最大元素为5。但是3.有两个完全相似的值。
如何解决?是任意选择的吗?是否有选择L-R的订单,更高的索引值?
我很感谢您对如何解决的任何见解!

这是我自己偶然发现了几次的好问题。最简单的答案是,无法保证torch.argmax(或torch.max(x, dim=k),在指定DIM时也返回索引)将始终返回相同的索引。相反,它可能会随机返回任何有效的索引。正如官方论坛讨论的那个线程一样,这被认为是理想的行为。(我知道有一段时间我读过的另一个线程使它更加明确,但我找不到它)。

说过,由于这种行为对我的用例无法接受,我编写了以下功能,这些功能将找到左右索引(请注意condition是您传递的功能对象):

def __consistent_args(input, condition, indices):
    assert len(input.shape) == 2, 'only works for batch x dim tensors along the dim axis'
    mask = condition(input).float() * indices.unsqueeze(0).expand_as(input)
    return torch.argmax(mask, dim=1)

def consistent_find_leftmost(input, condition):
    indices = torch.arange(input.size(1), 0, -1, dtype=torch.float, device=input.device)
    return __consistent_args(input, condition, indices)

def consistent_find_rightmost(input, condition):
    indices = torch.arange(0, input.size(1), 1, dtype=torch.float, device=input.device)
    return __consistent_args(input, condition, indices)
# one example:
consistent_find_leftmost(torch.arange(10).unsqueeze(0), lambda x: x>5)                                                                                                                                     
# will return: 
# tensor([6])

希望他们会有所帮助!(哦,如果您有更好的实现,请告诉我)

相关内容

最新更新