覆盖Python中的超类实例



我可以使用下面的工厂函数创建一个FasterRCNN对象:
# `...` stands for the factory's arguments (e.g. weights=...); illustrative only.
model = fasterrcnn_resnet50_fpn(...)

我想从它继承,例如:

class MyDetector(FasterRCNN):
   # Desired: a FasterRCNN subclass that can carry extra methods; body elided.
   ...

但要用fasterrcnn_resnet50_fpn()工厂返回的实例来覆盖(替换)超类实例。我尝试过使用__new__,例如:

class MyDetector(FasterRCNN):
    # NOTE(review): this attempt does not work. `__new__` returns the plain
    # FasterRCNN built by the factory, which is not an instance of `cls`; per
    # Python's data model, `__init__` is then never called. So the predictor
    # swap in `__init__` never runs, and `MyDetector()` yields a FasterRCNN
    # that has no `some_func`.
    def __new__(cls):
        # Returns a FasterRCNN instance, not a MyDetector instance.
        return fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
    def __init__(self):
        # Intended: replace the box predictor with a 2-class head that keeps
        # the existing input feature size (skipped; see the note above).
        num_features_in = self.roi_heads.box_predictor.cls_score.in_features
        self.roi_heads.box_predictor = FastRCNNPredictor(num_features_in, num_classes=2)
    def some_func(self):
        pass

这样我就可以向子类添加自定义方法等。正确的做法是什么?

我认为最好的做法是编写你自己的工厂函数。

导入库

from typing import Optional, Any
import torch
from torch import nn
import torchvision
from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights, FasterRCNN
from torchvision.models._utils import _ovewrite_value_param
from torchvision.models.detection.backbone_utils import (
    _validate_trainable_layers,
    _resnet_fpn_extractor,
)
from torchvision.models.detection._utils import overwrite_eps
from torchvision.ops import misc as misc_nn_ops
MyDetector 类

class MyDetector(FasterRCNN):
    """FasterRCNN subclass that exists so custom methods can be added.

    Construction is delegated unchanged to ``FasterRCNN.__init__``; pretrained
    instances are meant to be built through ``mydetector_resnet50_fpn``.
    """

    def __init__(self, backbone, num_classes=None, **kwargs):
        # No extra state: hand every argument straight to FasterRCNN.
        super().__init__(backbone=backbone, num_classes=num_classes, **kwargs)

    def some_func(self):
        """Placeholder for a custom method."""
        return None

MyDetector工厂函数

# https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py#L459
def mydetector_resnet50_fpn(
    *,
    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> MyDetector:
    """Build a ``MyDetector`` the same way torchvision builds ``fasterrcnn_resnet50_fpn``.

    The steps follow the upstream factory linked above. The only difference is
    that the resulting model is a ``MyDetector`` rather than a plain ``FasterRCNN``.

    Args:
        weights: Full-detector checkpoint. When given, ``weights_backbone`` is
            ignored and ``num_classes`` is taken from the checkpoint's
            ``meta["categories"]`` (torchvision's ``_ovewrite_value_param``
            presumably rejects a conflicting explicit value).
        progress: Forwarded to ``resnet50`` and to ``get_state_dict``.
        num_classes: Number of output classes; 91 when neither it nor
            ``weights`` is given.
        weights_backbone: Backbone-only ResNet-50 weights.
        trainable_backbone_layers: Number of trainable (not frozen) backbone
            layers, validated by torchvision's ``_validate_trainable_layers``
            with max 5 and default 3.
        **kwargs: Forwarded unchanged to ``MyDetector``.

    Returns:
        The constructed ``MyDetector``, with ``weights`` loaded when given.
    """
    # Normalize both weight arguments up front; the backbone weights are
    # verified even when they are discarded below, exactly as upstream does.
    detector_weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
    backbone_weights = ResNet50_Weights.verify(weights_backbone)

    if detector_weights is None:
        if num_classes is None:
            num_classes = 91
    else:
        # The full checkpoint already covers the backbone, and the head size
        # must agree with the checkpoint's category list.
        backbone_weights = None
        num_classes = _ovewrite_value_param(
            "num_classes", num_classes, len(detector_weights.meta["categories"])
        )

    pretrained = detector_weights is not None or backbone_weights is not None
    trainable_backbone_layers = _validate_trainable_layers(
        pretrained, trainable_backbone_layers, 5, 3
    )
    # Pretrained weights get FrozenBatchNorm2d; a from-scratch model gets BatchNorm2d.
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if pretrained else nn.BatchNorm2d

    resnet = resnet50(weights=backbone_weights, progress=progress, norm_layer=norm_layer)
    model = MyDetector(
        _resnet_fpn_extractor(resnet, trainable_backbone_layers),
        num_classes=num_classes,
        **kwargs,
    )

    if detector_weights is None:
        return model

    model.load_state_dict(detector_weights.get_state_dict(progress=progress))
    if detector_weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
        # Same eps override the upstream factory applies for this checkpoint.
        overwrite_eps(model, 0.0)
    return model

用于检查两个模型权重是否一致的实用函数
# https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6
def compare_models(model_1, model_2):
    """Compare two models' ``state_dict()`` entries position by position.

    Prints the key of every entry whose tensor differs. If nothing differs,
    prints "Models match perfectly! :)".

    Args:
        model_1, model_2: Objects exposing ``state_dict()`` (e.g. ``torch.nn.Module``).

    Returns:
        True if every entry is ``torch.equal``, False otherwise.

    Raises:
        ValueError: If the state dicts have different numbers of entries, or if
            entries at the same position have different keys and different values.
    """
    state_1 = model_1.state_dict()
    state_2 = model_2.state_dict()
    if len(state_1) != len(state_2):
        # zip() would stop at the shorter dict and could report a false match.
        raise ValueError(
            f"State dicts differ in size: {len(state_1)} vs {len(state_2)} entries"
        )

    models_differ = 0
    for (key_1, value_1), (key_2, value_2) in zip(state_1.items(), state_2.items()):
        if torch.equal(value_1, value_2):
            continue
        if key_1 != key_2:
            # Differing values under differing keys: the models are not aligned.
            raise ValueError(f"Key mismatch at the same position: {key_1!r} vs {key_2!r}")
        models_differ += 1
        print("Mismatch found at", key_1)

    if models_differ == 0:
        print("Models match perfectly! :)")
    return models_differ == 0

测试
# Build the stock torchvision detector and the subclass-based one with the same
# pretrained weights, then compare their state_dict() entries one by one.
# NOTE(review): only state_dict() contents are compared; plain module
# attributes (presumably e.g. FrozenBatchNorm2d eps) are not checked.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT
)
my_model = mydetector_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
compare_models(model, my_model)

输出
Models match perfectly! :)

我也尝试过编写硬编码版本。但如你所知,自定义FPN的设置有些复杂。

from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights, FasterRCNN
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor
from torchvision.ops import misc as misc_nn_ops

class MyDetector(FasterRCNN):
    """FasterRCNN-ResNet50-FPN subclass hard-wired to the default COCO weights.

    Intended to build the same network as
    ``fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)``,
    but as an instance of this subclass so custom methods can be added.

    Args:
        **kwarg: Extra keyword arguments forwarded to ``FasterRCNN.__init__``.
            ``backbone`` and ``num_classes`` are set here and must not be
            passed again.
    """

    def __init__(self, **kwarg):
        # Local import keeps this hard-coded snippet self-contained.
        from torchvision.models.detection._utils import overwrite_eps

        weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
        # The detector checkpoint loaded below overwrites every backbone
        # parameter, so no separate ImageNet ResNet-50 weights are fetched
        # (the factory likewise drops weights_backbone when full detector
        # weights are given). FrozenBatchNorm2d matches the factory's choice
        # for pretrained weights.
        backbone = resnet50(weights=None, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
        backbone = _resnet_fpn_extractor(backbone, trainable_layers=3)
        # default of num_classes is 91
        # this num_classes is used for setting FastRCNNPredictor
        # https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py#L257
        num_classes = len(weights.meta["categories"])
        super().__init__(backbone=backbone, num_classes=num_classes, **kwarg)
        self.load_state_dict(weights.get_state_dict(progress=True))
        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
            # The torchvision factory resets FrozenBatchNorm2d eps to 0.0 for
            # this checkpoint. eps is not stored in the state dict, so loading
            # alone leaves the default and the outputs drift from the stock model.
            overwrite_eps(self, 0.0)

    def some_func(self):
        pass


# Instantiation presumably downloads (or reads from cache) the DEFAULT checkpoint.
m = MyDetector()

相关内容

  • 没有找到相关文章

最新更新