PyTorch如何在多个维度上进行聚集

  • 本文关键字:聚集 PyTorch python pytorch
  • 更新时间 :
  • 英文 :


我正试图找到一种方法来做到这一点没有for循环。

假设我有一个多维张量t0:

bs = 4
seq = 10
v = 16
t0 = torch.rand((bs, seq, v))

形状为:torch.Size([4, 10, 16])

我有另一个张量labels,它是seq维中5个随机索引的一批:

labels = torch.randint(0, seq, size=[bs, sample])

形状是torch.Size([4, 5])。用于索引t0seq维度。

我想做的是在批处理维度上循环使用labels张量进行收集。我的蛮力解决方案是:

t1 = torch.empty((bs, sample, v))
for b in range(bs):
for idx0, idx1 in enumerate(labels[b]):
t1[b, idx0, :] = t0[b, idx1, :]

得到形状为:torch.Size([4, 5, 16])的张量t1

在pytorch中是否有更习惯的方法来做这件事?

您可以在这里使用花哨的索引来选择张量的所需部分。

本质上,如果您事先生成传递访问模式的索引数组,您可以直接使用它们来提取张量的某些切片。每个维度的索引数组的形状应该与你想要提取的输出张量或切片的形状相同。

i = torch.arange(bs).reshape(bs, 1, 1) # shape = [bs, 1,      1]
j = labels.reshape(bs, sample, 1)      # shape = [bs, sample, 1]
k = torch.arange(v)                    # shape = [v, ]
# Get result as
t1 = t0[i, j, k]

注意上面三个张量的形状。广播在张量的前面附加了额外的维度,从而本质上将k形状重塑为[1, 1, v]形状,这使得它们所有3个都兼容于元素操作。

在一起广播(i, j, k)之后将产生3个[bs, sample, v]形状的数组,这些数组将(元素上)索引您的原始张量以产生形状为[bs, sample, v]的输出张量t1

你可以这样做:

t1 = t0[[[b] for b in range(bs)], labels]

t1 = torch.stack([t0[b, labels[b]] for b in range(bs)])

最新更新