如何使用Tensorflow对象检测API继续训练对象检测模型



我使用Tensorflow对象检测API来训练使用迁移学习的对象检测模型。具体来说,我使用模型zoo中的ssd_mobilenet_v1_fpn_coco,并使用提供的示例管道,当然,我已经用到我的训练和评估tfrecords和标签的实际链接替换了占位符。

我能够使用上述管道在我的5000张图像(以及相应的边界框(上成功地训练模型(如果有兴趣的话,我主要在TPU上使用谷歌的ML引擎(。

现在,我准备了大约2000张额外的图像,并希望继续用这些新图像训练我的模型,而不需要从头开始(训练初始模型需要大约6小时的TPU时间(。我该怎么做?

您有两个选项,在这两个选项中,您都需要更改新数据集的train_input_readerinput_path

  1. 在训练配置中指定要微调的检查点时,请指定训练模型的检查点
train_config{
fine_tune_checkpoint: <path_to_your_checkpoint>
fine_tune_checkpoint_type: "detection"
load_all_detection_checkpoint_vars: true
}
  1. 只需使用与上一型号相同的model_dir保持相同的配置(train_input_reader除外(。这样,API将创建一个图,并检查检查点是否已经存在于model_dir中并适合该图。如果是这样,它将恢复并继续训练它

Edit:fine_tune_checkpoint_type之前被错误地设置为true,而它通常应该是"detection"或"classification",在这种特定情况下应该是"detection"。感谢Krish的注意。

我还没有在新的数据集上重新训练对象检测模型,但它看起来像在配置文件中增加训练步骤CCD_ 6的数量并在tfrecord文件中添加图像就足够了。

最新更新