Fine tuning resnet18 for cifar10



我只想在cifar10数据集上微调ResNet18。所以我只想把最后一个线性层从1000改为10。我试着使用children函数来获得之前的层

ResModel = resnet18(weights=ResNet18_Weights)
model = nn.Sequential(
*list(ResModel.children())[:-1],
nn.Linear(512,10)
)

所以它引发了错误RuntimeError: mat1 and mat2 shapes cannot be multiplied (32768x1 and 512x10)然后我尝试了这种方式ResModel.fc=nn.Linear(512,10),它运行良好。那为什么呢?

将所有层堆叠到单个nn.Sequential和仅覆盖最后一层之间的区别在于forward函数:
您的ResModel是类型torchvision.models.ResNet,而您的model是简单的nn.SequentialResNetforward过程在最后一个线性层之前有一个额外的flatten操作——您的nn.Sequentialmodel中没有此操作。

最新更新