遮罩窗口外部的尺寸标注

  • 本文关键字:窗口 外部 pytorch
  • 更新时间 :
  • 英文 :


我有一个形状为(n, x, y)的Pytorch张量t,我想应用一个掩码,以便对于所有y > x + k(其中k是常数(,t[n, x, y] = -inf

我相信我可以用高级索引来做到这一点,但不知道怎么做。

如果不是,一个简单的方法是构造一个类似于循环的掩码(速度较慢,但只做一次并缓存(,然后是t += mask,因为-inf + z == -inf适用于所有z

有更好的方法吗?

请注意,条件y ≥ x对应于上三角形,而y > x是严格的上三角形。因此CCD_ 11是偏移等于CCD_ 12的上三角形部分。

您可以使用torch.triu构造一个三角形掩码,它实际上允许名为diagonal移位参数,指的是对角线的位置。指定所需的值,此处为-torch.inf,使用此掩码可以获得所需的结果。

总的来说,它可以归结为:

>>> m = torch.ones_like(t, dtype=bool).triu(1+k)
>>> t[m] = -torch.inf

或者,使用torch.where:也可以使用一个衬垫

>>> torch.where(torch.ones_like(t).bool().triu(1+k), -torch.inf, t)

由于掩码对所有批处理元素都是相等的,因此您可以创建单个2D掩码并在其第二和第三轴上屏蔽t

>>> m = torch.ones_like(t[0], dtype=bool).triu(1+k)
>>> t[:,m] = -torch.inf

最新更新