PyTorch中的图像特征提取



为了理解这个代码片段,我遇到了很多困难。

import torch
import torch.nn as nn
import torchvision.models as models
def ResNet152(out_features = 10):
return getattr(models, "resnet152")(pretrained=False, num_classes = out_features)
def VGG(out_features = 10):
return getattr(models, "vgg19")(pretrained=False, num_classes = out_features)

在该代码段中,通过ResNet152和Vgg19模型提取输入图像的特征。但我有一个问题,是从这些模型的哪一部分提取特征,是最后一个池化层还是分类层之前的层,还是其他什么。

注意,getattr(models, 'resnet152')等价于models.resent152

因此,下面的代码将返回模型本身。

getattr(models, "resnet152")(pretrained=False, num_classes = out_features)
# is same as
models.resnet152(pretrained=False, num_classes = out_features)

现在,如果你只需打印模型的结构,最后一层是完全连接的层,所以这就是你在这里得到的功能。

print(ResNet152())
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
...
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=2048, out_features=10, bias=True)
)

VGG()的情况也是如此。

最新更新