仅掩码RCNN 1类



我希望只使用一个类,person(以及BG,背景),用于掩码RCNN对象检测。我使用这个链接:https://github.com/matterport/Mask_RCNN来运行掩码rcnn。是否有特定的方法来完成这一点(编辑特定的文件,创建一个额外的python文件,或者只是通过过滤class_names数组中的选择)?任何方向或解决方案将高度赞赏。谢谢你

我已经为绵羊训练了相同的repo。你必须做两件事:

  1. 将训练和推理类编号更改为1 + 1 (bg和person):

    class SheepsConfig(Config):
    NAME = "sheeps"
    NUM_CLASSES = 1 + 1 # background + sheep
    config = SheepsConfig()  # Don't forget to use this config while creating your model
    config.display()
    
  2. 你需要创建数据集来训练。你可以这样使用coco:

    import coco
    from pycocotools.coco import COCO
    ct = COCO("/YourPathToCocoDataset/annotations/instances_train2014.json")
    ct.getCatIds(['sheep']) 
    # Sheep class' id is 20. You should run for person and use that id
    COCO_DIR = "/YourPathToCocoDataset/"
    # This path has train2014, annotations and val2014 files in it
    # Training dataset
    dataset_train = coco.CocoDataset()
    dataset_train.load_coco(COCO_DIR, "train", class_ids=[20])
    dataset_train.prepare()
    # Validation dataset
    dataset_val = coco.CocoDataset()
    dataset_val.load_coco(COCO_DIR, "val", class_ids=[20])
    dataset_val.prepare()
    

然后简单地创建你的模型:

# Create model in training mode
model = modellib.MaskRCNN(mode="training", config=config, model_dir=MODEL_DIR)
model.load_weights(COCO_MODEL_PATH, by_name=True, exclude=["mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"])
# This COCO_MODEL_PATH is the path to the mask_rcnn_coco.h5 file in this repo

然后你可以用下面的代码来训练它:

model.train(dataset_train, dataset_val,
learning_rate=config.LEARNING_RATE, 
epochs=100, 
layers='heads')#You can also use 'all' to train all network.

不要忘记使用tensorflow 1。x和keras 2.1.0:)我可以用这些版本进行训练。

我试着按照@dnl_anoj的建议"只显示个人类别的结果"。我从预测结果中删除了除person类之外的所有类。您可以在https://github.com/matterport/Mask_RCNN.

中的predictor.py文件中的run_on_opencv_image()函数中使用以下代码。
predictions = self.coco_demo.compute_prediction(image)
top_predictions = self.coco_demo.select_top_predictions(predictions)
masks = top_predictions.get_field("mask")
boxes = top_predictions.bbox
label_indexs = top_predictions.get_field("labels").numpy()
x = np.where(label_indexs != 1) # get indexes of labels which are not person
#remove items which are not person class
masks = np.delete(masks,x, axis=0)
boxes = np.delete(boxes,x, axis=0)
label_indexs = np.delete(label_indexs,x)
labels = self.convert_label_index_to_string(label_indexs)

你链接的github的作者制作了一个气球的例子,它写得很好,只包含一个类(气球),你应该遵循这个教程:https://engineering.matterport.com/splash-of-color-instance-segmentation-with-mask-r-cnn-and-tensorflow-7c761e238b46

相关内容

  • 没有找到相关文章

最新更新