Pytorch 网络参数计算



有人可以告诉我网络参数(10(是如何计算的吗?提前谢谢。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Net()
print(net)
print(len(list(net.parameters())))

输出:

Net(
  (conv1): Conv2d (1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d (6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120)
  (fc2): Linear(in_features=120, out_features=84)
  (fc3): Linear(in_features=84, out_features=10)
)
10

最好Zack

PyTorch 中的大多数层模块(例如 Linear、Conv2d 等(将参数分组到特定类别中,例如权重和偏差。网络中的五个层实例中的每一个都有一个"权重"和一个"偏差"参数。这就是打印"10"的原因。

当然,所有这些"权重"和"偏差"字段都包含许多参数。例如,您的第一个完全连接的图层self.fc1包含16 * 5 * 5 * 120 = 48000参数。因此,len(params)不会告诉您网络中的参数数量 - 它只为您提供网络中参数"分组"的总数。

由于 Bill 已经回答了为什么打印"10",我只是分享一个代码片段,您可以使用它来找出与网络中每一层相关的参数数量。

def count_parameters(model):
    total_param = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            num_param = numpy.prod(param.size())
            if param.dim() > 1:
                print(name, ':', 'x'.join(str(x) for x in list(param.size())), '=', num_param)
            else:
                print(name, ':', num_param)
            total_param += num_param
    return total_param

按如下方式使用上述函数。

print('number of trainable parameters =', count_parameters(net))

输出:

conv1.weight : 6x1x5x5 = 150
conv1.bias : 6
conv2.weight : 16x6x5x5 = 2400
conv2.bias : 16
fc1.weight : 120x400 = 48000
fc1.bias : 120
fc2.weight : 84x120 = 10080
fc2.bias : 84
fc3.weight : 10x84 = 840
fc3.bias : 10
number of trainable parameters = 61706

最新更新