# Download locations of the ResNet checkpoints, keyed by architecture name.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
# Build a ResNet-18 for CIFAR and warm-start it from the downloaded checkpoint.
# NOTE(review): the original lines had lost their indentation and the leading
# ``if`` keyword; the whole setup is assumed to belong to the CIFAR branch --
# confirm against the upstream script.
if args.dataset in ('cifar100', 'cifar10'):
    args.stride = [2, 2]
    resnet = resnet18(args, pretrained=False, num_classes=args.num_classes)

    # State dict loaded from the 'resnet18' checkpoint URL.
    initial_weight = model_zoo.load_url(model_urls['resnet18'])

    local_model = resnet
    # State dict of the freshly constructed model (its current parameters).
    initial_weight_1 = local_model.state_dict()

    # For every entry whose key starts with 'fc.', 'conv1' or 'bn1', replace
    # the downloaded tensor with the model's own current tensor, so that
    # load_state_dict() below leaves those entries as they were initialised
    # and takes every other entry from the checkpoint.
    # NOTE(review): presumably these layers differ in shape from the
    # checkpoint (custom stride / num_classes) -- verify in resnet18().
    # Only values of existing keys are reassigned, so iterating the dict
    # while updating it is safe.
    for key in initial_weight:
        if key.startswith(('fc.', 'conv1', 'bn1')):
            initial_weight[key] = initial_weight_1[key]

    local_model.load_state_dict(initial_weight)
我不明白这一句:initial_weight[key] = initial_weight_1[key]
你能告诉我为什么我们需要这样做吗?
感谢
函数torch.utils.model_zoo.load_url
将从给定的URL加载序列化的torch对象。在这种特殊情况下,使用的URL承载ResNet18网络的模型权重字典。
因此,initial_weight 是包含预训练 ResNet18 权重的字典,而 initial_weight_1 是由 resnet18 初始化的、当前内存中的模型 resnet 的权重字典(即它的 state_dict)。
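随后的循环会遍历预训练权重字典 initial_weight 的所有键;当某个键满足 key[0:3] == 'fc.' or key[0:5]=='conv1' or key[0:3]=='bn1' 这一条件时,就用当前模型 resnet 中对应的权重(initial_weight_1[key])覆盖从该URL加载的预训练权重。
这样,最后调用 load_state_dict 时,这些层保持当前模型自己的初始化值,其余各层则使用从该URL加载的预训练权重。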