Torch.cat 记忆爆炸



我试图在 ResNet50 上用这个模块替换 Conv2d。

class SubtractedConv(nn.Module):
def __init__(self, input_ch, output_ch, kernels, stride=1):
super().__init__()
self.point_wise = nn.Conv2d(input_ch, output_ch//2, 1, bias=False, stride=stride)
self.depth_wise = nn.Conv2d(output_ch // 2, output_ch // 2, kernels, groups=output_ch // 2, bias=False, padding=kernels // 2)
self.low_pass = nn.Conv2d(output_ch // 2, output_ch // 2, kernels, bias=False, padding=kernels // 2)
def forward(self, x):
p = self.point_wise(x)
d = self.depth_wise(p)
d -= p 
l = self.low_pass(p)
x = torch.cat((d, l), 1)
return x

预期的输出应该与正常的 Conv2d 具有相同的通道,但我在 torch.cat() 中耗尽了内存中的 cuda。 我想知道为什么?以及如何处理这个问题?

我使用 Pytorch 0.4.0 并在特斯拉 P100 上运行,图像大小为 224*224,批量大小为 16。

事实上,这样的东西在 Keras 上有效,它的参数较少(ResNet50 为 16M,而普通 Conv2D 为 25M)。

def subtractedconv(input_tensor, kernel_size, filters, stride=1):
p = kl.Conv2D(filters//2, (1, 1), use_bias=False, strides=stride, padding='same')(input_tensor)
d = DepthwiseConv2D(kernel_size, use_bias=False, padding='same')(p)
d = kl.subtract([d, p])
l = kl.Conv2D(filters//2, kernel_size, use_bias=False, padding='same')(p)
x = kl.Concatenate(axis=-1)([d, l])
return x

PyTorch的问题很可能是创建的中间张量,而不是torch.cat本身。为了通过nn.Conv2d反向传播,你需要将此操作的输入保存在内存中。当您浏览这些层时,内存消耗会增加(保留所有中间结果)。现在在你的forward代码中,你有三个

p = self.point_wise(x) # x is kept
d = self.depth_wise(p) # p is kept
d -= p # here we do not need to keep d, because of derivative formula for subtraction
l = self.low_pass(p)
x = torch.cat((d, l), 1) # but since this goes into further processing, we will need to keep d anyway

请注意,即使您的操作在计算上可能很高效(例如具有小内核),它们仍然需要相同数量的内存来保存输入特征图 - 换句话说,您为每个nn.Conv2d支付大量的固定成本,而不管其自身的复杂性如何。因此,很明显,如果将一个nn.Conv2d替换为三个,则可以预期内存消耗大约增加三倍。

不过,有一个解决方法可以解决您的情况。由于您的整个操作是线性的(您只执行卷积,这是线性的,减法是线性的,串联是线性的),因此您可以将整个计算归结为单个卷积,并带有精心准备的内核。如果我们将卷积视为线性运算符,并用P表示point_wise运算,用D表示depth_wise,用L表示low_pass运算,我们得到你的前向计算concatenate(Dx - Px, LPx),可以简化为[concatenate(D-P, LP)]x。因此,您可以根据三组权重(对应于point_wisedepth_wiselow_pass)预先计算内核,然后调用nn.functional.conv2d一次。但是,实现这种预计算很困难,因为它需要对参数张量的形状进行相当复杂的转换,以保留操作的精确语义(例如,从 1x1 核P中减去空间核D)。我尝试在 10 分钟内得到这个并失败了,但如果这非常重要,请考虑在 PyTorch 论坛上询问或在评论中告诉我。

至于为什么Keras会处理它,我不确定,但一个强烈的猜测是,这要归功于TensorFlow。与PyTorch不同,TensorFlow(主要)使用静态计算图,可以提前分析和优化。我希望 TensorFlow 能够识别三个线性运算符的序列并将它们组合成一个,或者至少部分利用它们的线性来优化内存使用。

最新更新