"model.parameters()"包括什么?



在 Pytorch 中,

什么将被注册到model.parameters()中。

就目前而言,我所知道的如下:

1.  Conv layer: weight  bias
2.  BN layers: weight(gamma)  bias(beta)
3.  nn.Parameter() 
such as:   self.alpha = nn.Parameter(torch.rand(10))  defined in the model.

我的问题是model.parameters()中是否注册了一些其他参数?

PS.model.parameters()最常见的情况是在优化器中, 例如 PyTorch resnet 示例

optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)

提前谢谢你。

就像你在那里写的那样,model.parameters()存储模型的权重和偏差(如果设置为 true(。 它作为参数提供给优化器,以用一行代码optimizer.step()更新模型的权重和偏差值,然后在下次浏览数据集时使用。

相关内容

最新更新