我可以用下面的语句创建一个 FasterRCNN 对象:
model = fasterrcnn_resnet50_fpn(...)
我想继承 FasterRCNN,例如:
class MyDetector(FasterRCNN):
    """Sketch of the desired subclass of ``FasterRCNN``; the body is left open."""

    ...
但希望其中的超类实例由 fasterrcnn_resnet50_fpn() 工厂函数创建(即用工厂返回的实例覆盖超类部分)。
我试过使用 __new__,如:
class MyDetector(FasterRCNN):
    """Attempted subclass that obtains its instance from the stock factory.

    NOTE(review): this is the attempt being asked about and it does not behave
    as intended. ``__new__`` returns the object built by
    ``fasterrcnn_resnet50_fpn``, which is not an instance of ``MyDetector``.
    Python only calls ``__init__`` when ``__new__`` returns an instance of
    ``cls``, so ``__init__`` below is never run and ``some_func`` is not
    available on the returned object.
    """

    def __new__(cls):
        # Returns the factory's object as-is; ``cls`` is not used.
        return fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)

    def __init__(self):
        # Intended to swap the box predictor for a 2-class head, reusing the
        # input width of the existing classification layer.
        num_features_in = self.roi_heads.box_predictor.cls_score.in_features
        self.roi_heads.box_predictor = FastRCNNPredictor(num_features_in, num_classes=2)

    def some_func(self):
        """Placeholder for a custom method added by the subclass."""
        pass
以便我可以向子类添加自定义方法,等等。正确的做法是什么?
我认为你最好自己编写一个工厂函数。
导入库
from typing import Optional, Any
import torch
from torch import nn
import torchvision
from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights, FasterRCNN
from torchvision.models._utils import _ovewrite_value_param
from torchvision.models.detection.backbone_utils import (
_validate_trainable_layers,
_resnet_fpn_extractor,
)
from torchvision.models.detection._utils import overwrite_eps
from torchvision.ops import misc as misc_nn_ops
类MyDetector
class MyDetector(FasterRCNN):
    """``FasterRCNN`` subclass that serves as a home for custom methods.

    Construction is delegated unchanged to ``FasterRCNN``; pretrained weights
    are not loaded here.
    """

    def __init__(self, backbone, num_classes=None, **kwargs):
        """Forward all arguments to ``FasterRCNN.__init__``.

        Args:
            backbone: Feature-extraction module passed through as ``backbone``.
            num_classes: Passed through as ``num_classes``; may be ``None``.
            **kwargs: Any further keyword arguments for ``FasterRCNN``.
        """
        super().__init__(backbone=backbone, num_classes=num_classes, **kwargs)

    def some_func(self):
        """Placeholder for a custom method added by the subclass."""
        pass
MyDetector工厂函数
# https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py#L459
def mydetector_resnet50_fpn(
    *,
    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> MyDetector:
    """Build a ``MyDetector`` on a ResNet-50 FPN backbone.

    Follows the same steps as the torchvision factory linked above this
    function, but instantiates ``MyDetector`` instead of ``FasterRCNN``.

    Args:
        weights: Detector weights to load, or ``None`` for no detector
            checkpoint.
        progress: Forwarded to the weight loaders as ``progress``.
        num_classes: Number of classes. When ``weights`` is given it is
            reconciled with the number of categories in the weights' metadata;
            when both are ``None`` it falls back to 91.
        weights_backbone: Backbone weights; ignored when ``weights`` is given.
        trainable_backbone_layers: Forwarded to ``_validate_trainable_layers``
            and then to ``_resnet_fpn_extractor``.
        **kwargs: Extra keyword arguments forwarded to ``MyDetector``.

    Returns:
        The constructed ``MyDetector``, with ``weights`` loaded when provided.
    """
    weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # The detector checkpoint is loaded into the whole model below, so
        # separate backbone weights are not requested.
        weights_backbone = None
        # NOTE(review): helper semantics live in torchvision; presumably it
        # rejects a caller-supplied value that conflicts with the metadata.
        num_classes = _ovewrite_value_param(
            "num_classes", num_classes, len(weights.meta["categories"])
        )
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(
        is_trained, trainable_backbone_layers, 5, 3
    )
    # Frozen batch-norm whenever any pretrained weights are involved.
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(
        weights=weights_backbone, progress=progress, norm_layer=norm_layer
    )
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = MyDetector(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))
        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
            # Same post-load adjustment the referenced torchvision factory
            # applies for the COCO_V1 checkpoint.
            overwrite_eps(model, 0.0)

    return model
用于检查的工具函数
# https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6
def compare_models(model_1, model_2):
    """Compare two models entry by entry via their ``state_dict()``.

    Entries are paired by position. Each differing pair is reported on
    stdout; a success message is printed when nothing differs.

    Args:
        model_1: Object exposing ``state_dict()``.
        model_2: Object exposing ``state_dict()``.

    Returns:
        ``True`` when no difference was found, ``False`` otherwise.

    Raises:
        ValueError: If a differing pair of entries also has different names,
            i.e. the two state dicts are not aligned.
    """
    state_1 = model_1.state_dict()
    state_2 = model_2.state_dict()

    models_differ = 0
    for (name_1, tensor_1), (name_2, tensor_2) in zip(state_1.items(), state_2.items()):
        if torch.equal(tensor_1, tensor_2):
            continue
        models_differ += 1
        if name_1 != name_2:
            raise ValueError(
                f"state_dict entries are misaligned: {name_1!r} vs {name_2!r}"
            )
        print("Mismtach found at", name_1)

    # zip() stops at the shorter state dict; surplus entries on either side are
    # differences too and must not be reported as a perfect match.
    surplus = abs(len(state_1) - len(state_2))
    if surplus:
        models_differ += surplus
        print(f"Entry count differs: {len(state_1)} vs {len(state_2)}")

    if models_differ == 0:
        print("Models match perfectly! :)")
    return models_differ == 0
# Test: a model from torchvision's stock factory and one from the custom
# factory, both given the same weights, are compared entry by entry.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT
)
my_model = mydetector_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
compare_models(model, my_model)
输出:
Models match perfectly! :)
我也试过做硬编码版本。但如你所知,自定义FPN的设置有些复杂。
from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights, FasterRCNN
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor
from torchvision.ops import misc as misc_nn_ops
class MyDetector(FasterRCNN):
    """``FasterRCNN`` subclass with a hard-coded ResNet-50 FPN setup.

    The backbone, class count and detector weights are all fixed inside
    ``__init__``, so the class can be instantiated without arguments.
    """

    def __init__(self, **kwargs):
        """Build the detector and load the default detection weights.

        Args:
            **kwargs: Extra keyword arguments forwarded to ``FasterRCNN``.
        """
        # Local import keeps this snippet self-contained.
        from torchvision.models.detection._utils import overwrite_eps

        weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
        # No separate backbone weights: load_state_dict() below loads the
        # detector checkpoint into the whole model, so ImageNet weights
        # requested here would only be fetched and then replaced.
        backbone = resnet50(
            weights=None,
            norm_layer=misc_nn_ops.FrozenBatchNorm2d,
        )
        backbone = _resnet_fpn_extractor(backbone, trainable_layers=3)
        # default of num_classes is 91
        # this num_classes is used for setting FastRCNNPredictor
        # https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py#L257
        num_classes = len(weights.meta["categories"])
        super().__init__(backbone=backbone, num_classes=num_classes, **kwargs)
        self.load_state_dict(weights.get_state_dict(progress=True))
        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
            # Same post-load adjustment the torchvision factory applies for
            # the COCO_V1 checkpoint.
            overwrite_eps(self, 0.0)

    def some_func(self):
        """Placeholder for a custom method added by the subclass."""
        pass
# Instantiate the hard-coded detector; no arguments are required because the
# backbone, class count and weights are all fixed inside __init__.
m = MyDetector()