Pytorch BatchNorm2d 运行时错误:running_mean 应包含 64 个元素,而不是 0



我正在使用八度卷积,并设置了一个 BatchNorm2d 改编,这对我来说是

RuntimeError: running_mean should contain 64 elements not 0

我已经设置了一些调试打印来检查我的张量尺寸出了什么问题,但找不到它。 这是我的班级:

class _BatchNorm2d(nn.Module):
def __init__(self, num_features, alpha_in=0, alpha_out=0, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True):
super(_BatchNorm2d, self).__init__()
hf_ch = int(num_features * (1 - alpha_out))
lf_ch = num_features - hf_ch
self.bnh = nn.BatchNorm2d(hf_ch)
self.bnl = nn.BatchNorm2d(lf_ch)
def forward(self, x):
if isinstance(x, tuple):
hf, lf = x
print("IN ON BN: ",lf.shape if lf is not None else None) #DEBUGGING PRINT
print(self.bnl)  #DEBUGGING PRINT
hf = self.bnh(hf) if type(hf) == torch.Tensor else hf
lf = self.bnh(lf) if type(lf) == torch.Tensor else lf #THIS IS THE LINE ACCUSING THE ERROR
print("ENDED BN")
return hf, lf
else:
return self.bnh(x)

这是打印错误:

IN ON BN:  torch.Size([32, 64, 3, 3])
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

在我看来,该功能应该已经起作用,因为 x 有 64 个通道,而 bn 期望 64 个通道。

编辑: 可能还需要提到错误只发生在 alpha 值 1 上。但是,我不明白,因为卷仍然相同。

已解决。这是低频BN通话的错字。

hf = self.bnh(hf) if type(hf) == torch.Tensor else hf
lf = self.bnh(lf) if type(lf) == torch.Tensor else lf

应该是

hf = self.bnh(hf) if type(hf) == torch.Tensor else hf
lf = self.bnl(lf) if type(lf) == torch.Tensor else lf

最新更新