>我有一个形状为(100, 16, 16)
的张量 A 和形状为(100)
的张量 B,其中 100 是批量大小。我想创建一个具有形状(100, 16, 16)
的 A 二进制掩码,其中在每个元素(元素具有形状(1, 16, 16)
)的掩码中,如果元素大于计算的分位数值,则值1
,否则0
。张量 B 中的每个元素按顺序指示 A 中每个元素的百分位值。如果 B 只是一个标量,我可以使用:
flat_A = torch.reshape(A, (100, -1))
quants = torch.quantile(flat_A, B, dim=1)
quants = torch.reshape(quants, (100, 1, 1))
mask = torch.where(A >= quants, 1, 0)
# quants will have shape (100, 1, 1)
问题是:如果 B 是形状(100)
的一维张量,就像我上面说的,我如何计算 A 中每个元素的百分位值?我尝试了以下方法,但结果看起来与我预期的不同:
>>> torch.quantile(flat_A, B, dim=1).shape
torch.Size([100, 100])
>>> torch.quantile(flat_A, B, dim=0).shape
torch.Size([100, 256])
我认为结果的形状应该是(100)
的,所以我可以使用mask = torch.where(A >= quants, 1, 0)
,或者我误解了它?
对于更多上下文,这个问题也是我之前在这里提出的标量 B 值问题的扩展。
这是使用torch.quantile()
函数的一种方式。请注意,为了简单起见,这里我使用的是形状为 (5, 2, 2) 而不是(100, 16, 16)的张量。
import torch
# Generate some data of shape (5, 2, 2)
A = torch.arange(5 * 2 * 2).reshape(5, 2, 2) + 1.0
B = torch.linspace(0, 1, 5) # 5 quantile values for each element in A
Af = A.reshape(A.shape[0], -1) # flattens A to a 2D tensor
quantiles = torch.quantile(Af, B, dim = 1, keepdim = True)
quants = quantiles[torch.arange(A.shape[0]), torch.arange(A.shape[0]), 0]
mask = (A >= quants[:, None, None]).type(torch.uint8)
这里的张量quantiles
的形状torch.Size([5, 5, 1])
,因为它以B
形式存储每个元素的阈值,A
(或Af
中的行)。由于我们有 5 个分位数值,因此我们得到A
中每个元素的 5 个阈值。
例如,quantiles[i, j, 0]
具有A[j]
或Af[j]
的第分位数B[i]
阈值,并且您基本上需要为批量大小或 5 范围内的k
quantiles[k, k, 0]
值。
现在,为了满足需要B
中相应分位数和A
元素阈值的要求,只需从quantiles
索引出对角线元素并填充具有形状torch.Size([5])
的quants
即可。
最后得到mask
,将A
与每个元素的相应阈值进行比较。请注意,这使用与阈值的广播元素比较。mask
具有所需的torch.Size([5, 2, 2])
形状。