我正在尝试删除efficientnet pytorch实现中的顶层。然而,如果我像作者在github评论中建议的那样,简单地用我自己的完全连接层替换最后的_fc
层,我担心即使在这个层之后仍然有swish
激活,而不是像我预期的那样什么都没有。当我打印模型时,最后几行如下:
(_bn1): BatchNorm2d(1280, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
(_avg_pooling): AdaptiveAvgPool2d(output_size=1)
(_dropout): Dropout(p=0.2, inplace=False)
(_fc): Sequential(
(0): Linear(in_features=1280, out_features=512, bias=True)
(1): ReLU()
(2): Dropout(p=0.25, inplace=False)
(3): Linear(in_features=512, out_features=128, bias=True)
(4): ReLU()
(5): Dropout(p=0.25, inplace=False)
(6): Linear(in_features=128, out_features=1, bias=True)
)
(_swish): MemoryEfficientSwish()
)
)
其中_fc
是我替换的模块。
我希望做的是:
base_model = EfficientNet.from_pretrained('efficientnet-b3')
model = nn.Sequential(*list(base_model.children()[:-3]))
在我看来CCD_ 4从嵌套结构中使模型变平。然而,现在我似乎无法像使用伪输入一样使用模型,x=torch.randn(1,3,255,255)
我得到错误:TypeError: forward() takes 1 positional argument but 2 were given
。
需要注意的是,model[:2](x)
有效,但model[:3](x)
无效。CCD_ 9似乎是移动块。
这是一个带有上述代码的colab笔记本。
这是对print(net)
实际功能的常见误解。
事实上,在_fc
之后有一个_swish
模块,这仅仅意味着前者在后者之后注册。你可以在代码中检查:
class EfficientNet(nn.Module):
def __init__(self, blocks_args=None, global_params=None):
# [...]
# Final linear layer
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
self._dropout = nn.Dropout(self._global_params.dropout_rate)
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
self._swish = MemoryEfficientSwish()
定义它们的顺序就是打印它们的顺序。当涉及到具体执行的内容时,您必须检查forward
:
def forward(self, inputs):
# Convolution layers
x = self.extract_features(inputs)
# Pooling and final linear layer
x = self._avg_pooling(x)
x = x.flatten(start_dim=1)
x = self._dropout(x)
x = self._fc(x)
return x
并且,正如您所看到的,在self._fc(x)
之后没有任何内容,这意味着不会应用Swish
。