我正试图找到一种方法来做到这一点没有for循环。
假设我有一个多维张量t0
:
bs = 4
seq = 10
v = 16
t0 = torch.rand((bs, seq, v))
形状为:torch.Size([4, 10, 16])
我有另一个张量labels
,它是seq
维中5个随机索引的一批:
labels = torch.randint(0, seq, size=[bs, sample])
形状是torch.Size([4, 5])
。用于索引t0
的seq
维度。
我想做的是在批处理维度上循环使用labels
张量进行收集。我的蛮力解决方案是:
t1 = torch.empty((bs, sample, v))
for b in range(bs):
for idx0, idx1 in enumerate(labels[b]):
t1[b, idx0, :] = t0[b, idx1, :]
得到形状为:torch.Size([4, 5, 16])
的张量t1
在pytorch中是否有更习惯的方法来做这件事?
您可以在这里使用花哨的索引来选择张量的所需部分。
本质上,如果您事先生成传递访问模式的索引数组,您可以直接使用它们来提取张量的某些切片。每个维度的索引数组的形状应该与你想要提取的输出张量或切片的形状相同。
i = torch.arange(bs).reshape(bs, 1, 1) # shape = [bs, 1, 1]
j = labels.reshape(bs, sample, 1) # shape = [bs, sample, 1]
k = torch.arange(v) # shape = [v, ]
# Get result as
t1 = t0[i, j, k]
注意上面三个张量的形状。广播在张量的前面附加了额外的维度,从而本质上将k
形状重塑为[1, 1, v]
形状,这使得它们所有3个都兼容于元素操作。
在一起广播(i, j, k)
之后将产生3个[bs, sample, v]
形状的数组,这些数组将(元素上)索引您的原始张量以产生形状为[bs, sample, v]
的输出张量t1
。
你可以这样做:
t1 = t0[[[b] for b in range(bs)], labels]
或
t1 = torch.stack([t0[b, labels[b]] for b in range(bs)])