我只想在cifar10数据集上微调ResNet18。所以我只想把最后一个线性层从1000改为10。我试着使用children
函数来获得之前的层
ResModel = resnet18(weights=ResNet18_Weights)
model = nn.Sequential(
*list(ResModel.children())[:-1],
nn.Linear(512,10)
)
所以它引发了错误RuntimeError: mat1 and mat2 shapes cannot be multiplied (32768x1 and 512x10)
然后我尝试了这种方式ResModel.fc=nn.Linear(512,10)
,它运行良好。那为什么呢?
将所有层堆叠到单个nn.Sequential
和仅覆盖最后一层之间的区别在于forward
函数:
您的ResModel
是类型torchvision.models.ResNet
,而您的model
是简单的nn.Sequential
。ResNet
的forward
过程在最后一个线性层之前有一个额外的flatten
操作——您的nn.Sequential
model
中没有此操作。