我使用torch.rot90在pytorch中随机旋转3D图像,但这会以相同的方式旋转批次中的所有图像。我想找到一种可微分的方法,在不同的轴上随机旋转每个图像。这是将每个图像旋转到相同方向的代码:
#x = next batch
k = torch.randint(0, 4, (1,)).item()
dims = [0,0]
dims[0] = dims[1] = torch.randint(2, 5, (1,))
while dims[0] == dims[1]:#this makes sure the two axes aren't the same
dims[1] = torch.randint(2, 5, (1,))
x = torch.rot90(x, k, dims)
# x is now a batch of 3D images that have all been rotated in the same random orientation
您可以将批次中的数据随机拆分为3个子集,并分别应用每个维度的旋转。
让我详细介绍iacob的答案。首先,让我复习一下rot90函数的参数。除了输入张量之外,它还期望k和dims,其中k是要进行的旋转次数,dims是一个列表或元组,包含如何旋转张量的两个维度。例如,如果张量是4D,dims可以是[0,3]或(1,2(或[2,3]等。它们必须是有效的轴,并且应该包含两个数字。你真的不需要为这个参数或k创建张量。需要注意的是,根据给定的调光,输出形状可能会发生巨大变化:
x = torch.rand(15, 3, 4,6)
y1 = torch.rot90(x[0:5], 1, [1,3])
y2 = torch.rot90(x[5:10], 1, [1,2])
y3 = torch.rot90(x[10:15], 1, [2,3])
print(y1.shape) # torch.Size([5, 6, 4, 3])
print(y2.shape) # torch.Size([5, 4, 3, 6])
print(y3.shape) # torch.Size([5, 3, 6, 4])
与iacob的答案类似,这里我们对输入的切片应用3种不同的旋转。请注意,由于在不同维度上旋转的性质,输出维度是不同的。除非你有一个真正特定的输入大小,否则你无法将这些结果真正合并到一个张量中,例如Batch x 10 x 10,其中在1、2、3轴的组合上旋转将始终返回相同的维度。然而,您可以将这些不同大小的输出分别用作不同模块、层等的输入。
我个人想不出可以使用随机轴旋转的用例。如果你能详细说明你为什么要这么做,我可以尝试给出一些更好的解决方案。