How to train an object detection model on more than one class



Link: https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/eager_few_shot_od_training_tf2_colab.ipynb

I have tried the Google Colab above to train an object detection model with 1 class, as shown in the example.

I am trying to understand how to modify this code so that it can train on 2 classes.

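In the example above, after I annotate the images with boxes, it runs the following code to create category_index and the image/box tensors. Suppose I set num_classes = 2 and add another class to category_index: how do I proceed from there? For example, I believe the one-hot encoding as written only works for 1 class. How do I modify the code so that it works with 2 classes?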

# By convention, our non-background classes start counting at 1.  Given
# that we will be predicting just one class, we will therefore assign it a
# `class id` of 1.
duck_class_id = 1
num_classes = 1
category_index = {duck_class_id: {'id': duck_class_id, 'name': 'rubber_ducky'}}
# Convert class labels to one-hot; convert everything to tensors.
# The `label_id_offset` here shifts all classes by a certain number of indices;
# we do this here so that the model receives one-hot labels where non-background
# classes start counting at the zeroth index.  This is ordinarily just handled
# automatically in our training binaries, but we need to reproduce it here.
label_id_offset = 1
train_image_tensors = []
gt_classes_one_hot_tensors = []
gt_box_tensors = []
for (train_image_np, gt_box_np) in zip(
        train_images_np, gt_boxes):
    train_image_tensors.append(tf.expand_dims(tf.convert_to_tensor(
        train_image_np, dtype=tf.float32), axis=0))
    gt_box_tensors.append(tf.convert_to_tensor(gt_box_np, dtype=tf.float32))
    zero_indexed_groundtruth_classes = tf.convert_to_tensor(
        np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) - label_id_offset)
    gt_classes_one_hot_tensors.append(tf.one_hot(
        zero_indexed_groundtruth_classes, num_classes))
print('Done prepping data.')

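The tutorial is written for mono-class detection: the rubber ducky detector or the zombie detector. To make it work with multiple classes, changes along the following lines are needed (a solution reached after two weeks):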

  • The category_index variable must look like this:
# One class id per training image, in the same order as the images.
gt_classes = [1,1,1,1,1,  2,2,2,2,2,2,2,2,  3,3,3,3,3,3,3,3]
# gt_classes = [[1],[1],[1],[1],[1], [2],[2],[2],[2],[2],[2],[2],[2], [3],[3],[3],[3],[3],[3],[3],[3,2]]
zombie_CLASS_ID = 1
cat_CLASS_ID = 2
dog_CLASS_ID = 3
category_index = {
    zombie_CLASS_ID: {'id': zombie_CLASS_ID, 'name': 'zombie'},
    cat_CLASS_ID: {'id': cat_CLASS_ID, 'name': 'cat'},
    dog_CLASS_ID: {'id': dog_CLASS_ID, 'name': 'dog'},
}
NUM_CLASSES = len(category_index)
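  • The np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) call is meaningless here (and in the Rubber Ducky detector too): it is a very awkward way the author found to format the ground-truth classes as a tensor, and it assigns class id 1 to every box. Each gt_classes entry must end up as a one-hot encoded tensor of the form Tensor("Const:0", shape=(1, NUM_CLASSES), dtype=float32) (float32 is important).
  • To get that, replace it with tf.one_hot together with tf.reshape. An example that creates the correct gt_classes_one_hot_tensors: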
label_id_offset = 1  # TF actually starts with 0
train_image_tensors = []
gt_classes_one_hot_tensors = []
gt_box_tensors = []
for (train_image_np, gt_box_np, gt_class) in zip(list_train_images_np, gt_boxes, gt_classes):
    train_image_tensors.append(tf.expand_dims(tf.convert_to_tensor(train_image_np, dtype=tf.float32), axis=0))
    gt_box_tensors.append(tf.convert_to_tensor(gt_box_np, dtype=tf.float32))
    # HERE is the most critical change in gt_classes: tf.reshape keeps the format (1, NUM_CLASSES)
    gt_class_hot = tf.one_hot(indices=(gt_class - label_id_offset), depth=NUM_CLASSES, dtype=tf.float32)
    gt_classes_one_hot_tensors.append(tf.reshape(gt_class_hot, [-1, NUM_CLASSES]))
print('Done prepping data  Num_loaded : ', len(list_train_images_np))
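
To make the offset and the expected shape easier to check without running the Colab, here is a minimal plain-Python sketch of the same bookkeeping. The helpers build_category_index and one_hot_rows are hypothetical names introduced only for this illustration; they are not part of the tutorial or of the TensorFlow Object Detection API. The sketch also assumes that an image with several boxes is described by a list of class ids, one per box, as the commented-out gt_classes line above hints.

```python
# Hypothetical helpers for illustration only; not part of the Colab
# or of the TensorFlow Object Detection API.

def build_category_index(class_names):
    """Map 1-based class ids to {'id', 'name'} dicts, as the tutorial does."""
    return {i: {'id': i, 'name': name}
            for i, name in enumerate(class_names, start=1)}


def one_hot_rows(class_ids, num_classes, label_id_offset=1):
    """Return one float one-hot row per box: num_boxes rows of num_classes."""
    if isinstance(class_ids, int):
        class_ids = [class_ids]  # a single box in the image
    rows = []
    for class_id in class_ids:
        index = class_id - label_id_offset  # shift ids so they start at 0
        if not 0 <= index < num_classes:
            raise ValueError(
                f'class id {class_id} is out of range for {num_classes} classes')
        row = [0.0] * num_classes
        row[index] = 1.0
        rows.append(row)
    return rows


category_index = build_category_index(['zombie', 'cat', 'dog'])
num_classes = len(category_index)

# One entry per image: a single class id, or a list with one id per box.
gt_classes = [1, 2, [3, 2]]
gt_classes_one_hot = [one_hot_rows(c, num_classes) for c in gt_classes]

for rows in gt_classes_one_hot:
    print(rows)
```

Each nested list can then be passed to tf.convert_to_tensor(rows, dtype=tf.float32), which should give a float32 tensor with one row per box. This only lines up with the box tensors if the box array for the same image has the same number of rows.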

If you are starting out with these tutorials, I suggest you read: Does TensorFlow's object detection API support multi-class multi-label detection?

More information: https://github.com/tensorflow/models/issues/9655#issuecomment-1460289284
