I am running the TensorFlow Object Detection API following this guide: https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/training.html#configuring-a-training-job, with slightly modified code for creating the record files, on the following system:
System information:
- OS platform and distribution: Ubuntu 20.04.1 LTS
- Python version: 3.8
- TensorFlow version: 2.4.0
- CUDA/cuDNN version: 11.0 / 8.0.5
- GPU model and memory: GeForce RTX 3090, 24268 MiB
I want to use the boxes-only CenterNet MobileNetV2 FPN 512x512 model from the TensorFlow 2 Detection Model Zoo (https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md).
I set up the training job as described in the guide and then ran
python model_main_tf2.py --model_dir=models/my_centernet_mn_fpn --pipeline_config_path=models/my_centernet_mn_fpn/pipeline.config
and when I do, I get the following error:
Traceback (most recent call last):
  File "model_main_tf2.py", line 115, in <module>
    tf.compat.v1.app.run()
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "model_main_tf2.py", line 106, in main
    model_lib_v2.train_loop(
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/object_detection/model_lib_v2.py", line 636, in train_loop
    loss = _dist_train_step(train_input_iter)
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 888, in _call
    return self._stateless_fn(*args, **kwds)
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2942, in __call__
    return graph_function._call_flat(
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
    outputs = execute.execute(
  File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: indices[0] = 0 is not in [0, 0)
     [[{{node GatherV2_7}}]]
     [[MultiDeviceIteratorGetNextFromShard]]
     [[RemoteCall]]
     [[IteratorGetNext]]
  (1) Invalid argument: indices[0] = 0 is not in [0, 0)
     [[{{node GatherV2_7}}]]
     [[MultiDeviceIteratorGetNextFromShard]]
     [[RemoteCall]]
     [[IteratorGetNext]]
     [[ToAbsoluteCoordinates_42/Assert/AssertGuard/branch_executed/_386/_1231]]
0 successful operations.
0 derived errors ignored. [Op:__inference__dist_train_step_54439]
Function call stack:
_dist_train_step -> _dist_train_step
Googling this error turns up answers saying that it comes from how the TFRecord files were created, and that you need to add include_masks when creating them. However, I do not get this error when running other CenterNet models from the model zoo, so it seems strange that this would be the cause (for reference, a sketch of what my detection-only records contain is shown below).
Any idea what it could be?
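The records from the tutorial's script (which I only slightly modified) contain just image, box and class features, with no keypoints or masks. Below is a minimal sketch of such a detection-only tf.train.Example; the build_example helper is only illustrative, and the feature keys are the standard ones expected by the API's TF Example decoder, not taken verbatim from my script:

import tensorflow as tf

def build_example(encoded_jpg, height, width, boxes, class_texts, class_ids):
    """Box-only example: boxes are (ymin, xmin, ymax, xmax) in normalized [0, 1],
    class_texts is a list of byte strings, class_ids a list of ints matching the label map."""
    ymins, xmins, ymaxs, xmaxs = zip(*boxes)
    feature = {
        # Image payload.
        'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded_jpg])),
        'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'jpeg'])),
        'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
        'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
        # Box and class annotations; note there are no image/object/keypoint/* features.
        'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=ymins)),
        'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=xmins)),
        'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=ymaxs)),
        'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=xmaxs)),
        'image/object/class/text': tf.train.Feature(bytes_list=tf.train.BytesList(value=class_texts)),
        'image/object/class/label': tf.train.Feature(int64_list=tf.train.Int64List(value=class_ids)),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))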
Thanks to Alexandra: removing the keypoint-related entries from the config file solved it. The sample pipeline.config shipped with the boxes-only checkpoint still carries the COCO keypoint settings (the keypoint_flip_permutation inside random_horizontal_flip and num_keypoints: 17 in the input readers); since my records contain no keypoint annotations, the input pipeline apparently tries to gather keypoints from an empty tensor, which produces the "indices[0] = 0 is not in [0, 0)" error.
Original config (the parts marked "// delete" are the ones to remove):
model {
  center_net {
    num_classes: 90
    feature_extractor {
      type: "mobilenet_v2_fpn_sep_conv"
    }
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 512
        max_dimension: 512
        pad_to_max_dimension: true
      }
    }
    use_depthwise: true
    object_detection_task {
      task_loss_weight: 1.0
      offset_loss_weight: 1.0
      scale_loss_weight: 0.1
      localization_loss {
        l1_localization_loss {
        }
      }
    }
    object_center_params {
      object_center_loss_weight: 1.0
      classification_loss {
        penalty_reduced_logistic_focal_loss {
          alpha: 2.0
          beta: 4.0
        }
      }
      min_box_overlap_iou: 0.7
      max_box_predictions: 20
    }
  }
}
train_config {
  batch_size: 512
  data_augmentation_options { // delete
    random_horizontal_flip {
      keypoint_flip_permutation: 0
      keypoint_flip_permutation: 2
      keypoint_flip_permutation: 1
      keypoint_flip_permutation: 4
      keypoint_flip_permutation: 3
      keypoint_flip_permutation: 6
      keypoint_flip_permutation: 5
      keypoint_flip_permutation: 8
      keypoint_flip_permutation: 7
      keypoint_flip_permutation: 10
      keypoint_flip_permutation: 9
      keypoint_flip_permutation: 12
      keypoint_flip_permutation: 11
      keypoint_flip_permutation: 14
      keypoint_flip_permutation: 13
      keypoint_flip_permutation: 16
      keypoint_flip_permutation: 15
    }
  } // delete
  data_augmentation_options {
    random_patch_gaussian {
    }
  }
  data_augmentation_options {
    random_crop_image {
      min_aspect_ratio: 0.5
      max_aspect_ratio: 1.7
      random_coef: 0.25
    }
  }
  data_augmentation_options {
    random_adjust_hue {
    }
  }
  data_augmentation_options {
    random_adjust_contrast {
    }
  }
  data_augmentation_options {
    random_adjust_saturation {
    }
  }
  data_augmentation_options {
    random_adjust_brightness {
    }
  }
  data_augmentation_options {
    random_absolute_pad_image {
      max_height_padding: 200
      max_width_padding: 200
      pad_color: 0.0
      pad_color: 0.0
      pad_color: 0.0
    }
  }
  optimizer {
    adam_optimizer {
      learning_rate {
        cosine_decay_learning_rate {
          learning_rate_base: 5e-3
          total_steps: 300000
          warmup_learning_rate: 1e-4
          warmup_steps: 5000
        }
      }
    }
    use_moving_average: false
  }
  num_steps: 300000
  max_number_of_boxes: 100
  unpad_groundtruth_tensors: false
  fine_tune_checkpoint_type: ""
}
train_input_reader {
  label_map_path: "PATH_TO_BE_CONFIGURED/label_map.txt"
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/train2017-?????-of-00256.tfrecord"
  }
  filenames_shuffle_buffer_size: 256
  num_keypoints: 17 // delete
}
eval_config {
  num_visualizations: 10
  metrics_set: "coco_detection_metrics"
  use_moving_averages: false
  min_score_threshold: 0.20000000298023224
  max_num_boxes_to_visualize: 20
  batch_size: 1
}
eval_input_reader {
  label_map_path: "PATH_TO_BE_CONFIGURED/label_map.txt"
  shuffle: false
  num_epochs: 1
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/val2017-?????-of-00032.tfrecord"
  }
  num_keypoints: 17 // delete
}
And below is the modified config that works:
model {
  center_net {
    num_classes: 4
    feature_extractor {
      type: "mobilenet_v2_fpn_sep_conv"
    }
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 512
        max_dimension: 512
        pad_to_max_dimension: true
      }
    }
    use_depthwise: true
    object_detection_task {
      task_loss_weight: 1.0
      offset_loss_weight: 1.0
      scale_loss_weight: 0.1
      localization_loss {
        l1_localization_loss {
        }
      }
    }
    object_center_params {
      object_center_loss_weight: 1.0
      classification_loss {
        penalty_reduced_logistic_focal_loss {
          alpha: 2.0
          beta: 4.0
        }
      }
      min_box_overlap_iou: 0.7
      max_box_predictions: 20
    }
  }
}
train_config {
  batch_size: 8
  data_augmentation_options {
    random_patch_gaussian {
    }
  }
  data_augmentation_options {
    random_crop_image {
      min_aspect_ratio: 0.5
      max_aspect_ratio: 1.7
      random_coef: 0.25
    }
  }
  data_augmentation_options {
    random_adjust_hue {
    }
  }
  data_augmentation_options {
    random_adjust_contrast {
    }
  }
  data_augmentation_options {
    random_adjust_saturation {
    }
  }
  data_augmentation_options {
    random_adjust_brightness {
    }
  }
  data_augmentation_options {
    random_absolute_pad_image {
      max_height_padding: 200
      max_width_padding: 200
      pad_color: 0.0
      pad_color: 0.0
      pad_color: 0.0
    }
  }
  optimizer {
    adam_optimizer {
      learning_rate {
        cosine_decay_learning_rate {
          learning_rate_base: 5e-3
          total_steps: 50000
          warmup_learning_rate: 1e-4
          warmup_steps: 2000
        }
      }
    }
    use_moving_average: false
  }
  num_steps: 6000
  max_number_of_boxes: 4
  unpad_groundtruth_tensors: false
  fine_tune_checkpoint_version: V2
  fine_tune_checkpoint: "pre-trained-models/centernet_mobilenetv2_fpn_od/checkpoint/ckpt-301"
  fine_tune_checkpoint_type: "detection"
}
train_input_reader {
  label_map_path: "annotations/label_map.pbtxt"
  tf_record_input_reader {
    input_path: "annotations/train_c.record"
  }
  filenames_shuffle_buffer_size: 256
}
eval_config {
  num_visualizations: 10
  metrics_set: "coco_detection_metrics"
  use_moving_averages: false
  min_score_threshold: 0.20000000298023224
  max_num_boxes_to_visualize: 20
  batch_size: 1
}
eval_input_reader {
  label_map_path: "annotations/label_map.pbtxt"
  shuffle: false
  num_epochs: 1
  tf_record_input_reader {
    input_path: "annotations/test_c.record"
  }
}
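After editing the config, it can be sanity-checked before relaunching training. The snippet below is just an illustrative sketch: it assumes the object_detection package is importable and that the config lives at the path passed to model_main_tf2.py above. It parses the pipeline.config and prints any remaining keypoint settings, which should now be zero / empty:

# Parse the edited pipeline.config and check that no keypoint fields remain.
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2

config_path = "models/my_centernet_mn_fpn/pipeline.config"  # path used in the training command

pipeline = pipeline_pb2.TrainEvalPipelineConfig()
with open(config_path, "r") as f:
    text_format.Parse(f.read(), pipeline)

# num_keypoints should be 0 (i.e. unset) in both input readers.
print("train_input_reader.num_keypoints:", pipeline.train_input_reader.num_keypoints)
for reader in pipeline.eval_input_reader:
    print("eval_input_reader.num_keypoints:", reader.num_keypoints)

# If a random_horizontal_flip step is kept, its keypoint_flip_permutation must be empty.
for step in pipeline.train_config.data_augmentation_options:
    if step.HasField("random_horizontal_flip"):
        print("keypoint_flip_permutation:",
              list(step.random_horizontal_flip.keypoint_flip_permutation))

With the keypoint entries gone, the same model_main_tf2.py command trains without the error.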