我有第一个张量,大小为torch.Size([12, 64, 8, 8, 3])
,(8,8,3(是图像大小,64是补丁,12是批量大小的
还有另一个大小为torch.Size([12, 10])
的张量,它为批次中的每个项目选择10个补丁(从总共64个补丁中选择10个(。因此它存储索引。如何使用它来查询具有列表理解的第一个张量?
目标是查询每个批次的第一个补丁(从选定的10个补丁中(。当在b
上迭代时,我们得到所选补丁索引的列表。按索引0从中选择第一个。由于它们是张量,因此将类型转换为int
,以便索引到图像的张量中,并检索每个批次的相应补丁。
a = torch.rand(12, 64, 8, 8, 3) # generating 12 batches, with 64 patches,each of size 8x8x3
b = torch.randint(64, (12, 10)) # choosing 10 patches (within the 64), for each of the 12 batches
first_tensors = [a[batch, int(patches[0])] for batch, patches in zip(range(12), b)]
为了清楚起见,下面的列表理解将给出每批第一个补丁的索引。
[[batch, int(patches[0])] for batch, patches in zip(range(12), b)]
[[0, 40],
[1, 27],
[2, 17],
[3, 62],
[4, 9],
[5, 51],
[6, 32],
[7, 38],
[8, 63],
[9, 10],
[10, 2],
[11, 6]]
用上面列表中的每对索引对图像的张量a
进行索引将给出相应的补丁。
您可以使用index_select
:
c = [torch.index_select(i, dim=0, index=j) for i, j in zip(a,b)]
a
和b
分别是你的张量和指数。
之后可以在零维中stack
它。