我有一个形状为(n, x, y)
的Pytorch张量t
,我想应用一个掩码,以便对于所有y > x + k
(其中k
是常数(,t[n, x, y] = -inf
。
我相信我可以用高级索引来做到这一点,但不知道怎么做。
如果不是,一个简单的方法是构造一个类似于循环的掩码(速度较慢,但只做一次并缓存(,然后是t += mask
,因为-inf + z == -inf
适用于所有z
。
有更好的方法吗?
请注意,条件y ≥ x
对应于上三角形,而y > x
是严格的上三角形。因此CCD_ 11是偏移等于CCD_ 12的上三角形部分。
您可以使用torch.triu
构造一个三角形掩码,它实际上允许名为diagonal
的移位参数,指的是对角线的位置。指定所需的值,此处为-torch.inf
,使用此掩码可以获得所需的结果。
总的来说,它可以归结为:
>>> m = torch.ones_like(t, dtype=bool).triu(1+k)
>>> t[m] = -torch.inf
或者,使用torch.where
:也可以使用一个衬垫
>>> torch.where(torch.ones_like(t).bool().triu(1+k), -torch.inf, t)
由于掩码对所有批处理元素都是相等的,因此您可以创建单个2D掩码并在其第二和第三轴上屏蔽t
:
>>> m = torch.ones_like(t[0], dtype=bool).triu(1+k)
>>> t[:,m] = -torch.inf