我如何消除这个pytorch代码中的for循环



我有这个自定义pytorch模块(如下)。这正是我所需要的;只是速度很慢。我怎么做才能加快速度?我知道这里不应该有for循环;只是不清楚没有它怎么做除法运算。我怎么把x张量广播给除法而不用那个循环?如果有帮助的话,我可以把backweight移到它们自己的图层。

class StepLayer(nn.Module):
def __init__(self):
super(StepLayer, self).__init__()
w = init_weights()
self.front_weights = nn.Parameter(torch.DoubleTensor([w, w]).T, requires_grad=True)
self.back_weights = nn.Parameter(torch.DoubleTensor([w]).T, requires_grad=True)

def forward(self, x):
# x shape is batch by feature
results = []
for batch in x:
b = batch.divide(self.front_weights)
b = torch.some_math_function(b)
b = b.sum(dim=1)
b = torch.some_other_math_function(b)
b = b @ self.back_weights
results.append(b)
stack = torch.vstack(results)
return stack

下面是一个源代码,每个步骤后面都有形状(请阅读代码注释)。

我已经假设了一些东西,如F=100,x=Bx2,front_weights=100x2,back_weights=100,你应该能够轻松地调整它到你的情况。

class StepLayer(nn.Module):
def __init__(self):
super().__init__()
F = 100
# Notice I've added `1` dimension in front_weights
self.front_weights = nn.Parameter(torch.randn(1, F, 2), requires_grad=True)
self.back_weights = nn.Parameter(torch.randn(F), requires_grad=True)
def forward(self, x):
# x.shape == (B, 2)
x = x.unsqueeze(dim=1)  # (B, 1, 2)
x = x / self.front_weights  # (B, F, 2)
# I just took some element-wise math function from PyTorch
x = torch.sin(x)  # (B, F, 2)
x = torch.sum(x, dim=-1)  # (B, F)
x = torch.sin(x)  # (B, F)
return x @ self.back_weights  # (B, )
# results = []
# for batch in x:
#     # batch - (1, 2)
#     b = batch.divide(self.front_weights)  # (F, 2)
#     b = torch.some_math_function(b)  # (F, 2)
#     b = b.sum(dim=1)  # (F, )
#     b = torch.some_other_math_function(b)  # (F, )
#     b = b @ self.back_weights  # (1, )
#     results.append(b)
# stack = torch.vstack(results)  # (B, )
# return stack  # (B,)

layer = StepLayer()
print(layer(torch.randn(64, 2)).shape)

主要技巧是在必要时使用1维度进行广播(特别是除法)和智能权重初始化,因此您不必进行任何转置操作。

其他事情

  • 你可能要重新考虑Double,float(如上所述)更快,特别是在CUDA上,占用一半的内存(神经网络应该补偿精度损失,如果有的话)。
  • 使用half精度和混合训练,如果速度仍然是一个问题(float16dtype而不是float32),但仅在CUDA上,点击这里了解更多关于自动混合精度的信息

相关内容

  • 没有找到相关文章

最新更新