Link: https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/eager_few_shot_od_training_tf2_colab.ipynb
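I have tried the Google Colab above to train an object detection model with 1 class, as in the example.
I am trying to understand how to modify this code so that it can train 2 classes.
In the example above, after I annotate the images with boxes, it runs the following code to create the category_index and the image/box tensors. Suppose I change num_classes = 2 and add another class to category_index: how do I go on from there? For example, I think the one-hot encoding only works for 1 class. How do I modify the code so that it works with 2 classes?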
# By convention, our non-background classes start counting at 1. Given
# that we will be predicting just one class, we will therefore assign it a
# `class id` of 1.
duck_class_id = 1
num_classes = 1
category_index = {duck_class_id: {'id': duck_class_id, 'name': 'rubber_ducky'}}
# Convert class labels to one-hot; convert everything to tensors.
# The `label_id_offset` here shifts all classes by a certain number of indices;
# we do this here so that the model receives one-hot labels where non-background
# classes start counting at the zeroth index. This is ordinarily just handled
# automatically in our training binaries, but we need to reproduce it here.
label_id_offset = 1
train_image_tensors = []
gt_classes_one_hot_tensors = []
gt_box_tensors = []
for (train_image_np, gt_box_np) in zip(
    train_images_np, gt_boxes):
  train_image_tensors.append(tf.expand_dims(tf.convert_to_tensor(
      train_image_np, dtype=tf.float32), axis=0))
  gt_box_tensors.append(tf.convert_to_tensor(gt_box_np, dtype=tf.float32))
  zero_indexed_groundtruth_classes = tf.convert_to_tensor(
      np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) - label_id_offset)
  gt_classes_one_hot_tensors.append(tf.one_hot(
      zero_indexed_groundtruth_classes, num_classes))
print('Done prepping data.')
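To see why this loop cannot express a second class, here is a small NumPy re-creation of its labelling step. It assumes np.eye(depth)[indices] as a stand-in for tf.one_hot(indices, depth) so the sketch runs without TensorFlow; num_classes = 2 and the two boxes are made-up values. Because the class ids come from np.ones, every box lands in the first one-hot column (class id 1), whatever num_classes is.

```python
import numpy as np


def one_hot(indices, depth):
    """NumPy stand-in for tf.one_hot: row i is zeros except a 1.0 at indices[i]."""
    return np.eye(depth, dtype=np.float32)[np.asarray(indices)]


label_id_offset = 1
num_classes = 2  # hypothetical: the value proposed in the question

# Two made-up boxes for a single image, in [ymin, xmin, ymax, xmax] order.
gt_box_np = np.array([[0.1, 0.2, 0.5, 0.6],
                      [0.3, 0.3, 0.9, 0.8]], dtype=np.float32)

# Same expression as the tutorial: every box gets class id 1, i.e. index 0.
zero_indexed = np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) - label_id_offset
labels = one_hot(zero_indexed, num_classes)
print(labels)
# [[1. 0.]
#  [1. 0.]]
```

Raising num_classes only widens each row; the second column stays empty because no per-box class information ever enters the loop.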
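This is for the mono-class detection tutorials (the Rubber Ducky detector or the Zombie detector). Changing them to use multiple classes requires changes like the following (a solution found after two weeks :)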
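The category_index variable must look like this, together with a gt_classes list that holds one class id per training image: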
gt_classes = [1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3]
# gt_classes = [[1],[1],[1],[1],[1], [2],[2],[2],[2],[2],[2],[2],[2], [3],[3],[3],[3],[3],[3],[3],[3,2]]
zombie_CLASS_ID = 1
cat_CLASS_ID = 2
dog_CLASS_ID = 3
category_index = {
    zombie_CLASS_ID: {'id': zombie_CLASS_ID, 'name': 'zombie'},
    cat_CLASS_ID: {'id': cat_CLASS_ID, 'name': 'cat'},
    dog_CLASS_ID: {'id': dog_CLASS_ID, 'name': 'dog'},
}
NUM_CLASSES = len(category_index)
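A quick sanity check that gt_classes and category_index agree, and a look at which one-hot row each 1-based id maps to, can be sketched in NumPy. The helper check_and_encode is hypothetical, and np.eye again stands in for tf.one_hot. The sketch assumes the ids are contiguous from 1 to NUM_CLASSES, which holds for the three constants above.

```python
import numpy as np

category_index = {1: {'id': 1, 'name': 'zombie'},
                  2: {'id': 2, 'name': 'cat'},
                  3: {'id': 3, 'name': 'dog'}}
NUM_CLASSES = len(category_index)
label_id_offset = 1  # class ids start at 1, one-hot columns start at 0

# One-hot depth is NUM_CLASSES, so ids must be exactly 1..NUM_CLASSES.
assert sorted(category_index) == list(range(1, NUM_CLASSES + 1))


def check_and_encode(gt_classes):
    """Hypothetical helper: reject unknown ids, return float32 one-hot rows."""
    unknown = sorted(set(gt_classes) - set(category_index))
    if unknown:
        raise ValueError(f"class ids not in category_index: {unknown}")
    zero_indexed = np.asarray(gt_classes) - label_id_offset
    return np.eye(NUM_CLASSES, dtype=np.float32)[zero_indexed]


rows = check_and_encode([1, 2, 3])  # zombie, cat, dog
print(rows)
```

Each id k becomes a row with its single 1.0 in column k - 1, which is exactly what subtracting label_id_offset achieves.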
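- Here, np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) makes no sense for more than one class (the Rubber Ducky detector has the same line): it is a very awkward way the tutorial author found to turn the ground-truth classes into a tensor, and it simply gives every box class id 1. Each gt_classes entry must instead have the format Tensor("Const:0", shape=(1, NUM_CLASSES), dtype=float32) and be one-hot encoded (float32 is important). To get that, the line must be replaced with a combination of tf.one_hot and tf.reshape. Example that builds a correct gt_classes_one_hot_tensors: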
label_id_offset = 1  # class ids start at 1, but TF one-hot indices start at 0
train_image_tensors = []
gt_classes_one_hot_tensors = []
gt_box_tensors = []
for (train_image_np, gt_box_np, gt_class) in zip(list_train_images_np, gt_boxes, gt_classes):
    train_image_tensors.append(tf.expand_dims(tf.convert_to_tensor(train_image_np, dtype=tf.float32), axis=0))
    gt_box_tensors.append(tf.convert_to_tensor(gt_box_np, dtype=tf.float32))
    # HERE is the most critical change in gt_classes: tf.reshape keeps the format (1, NUM_CLASSES)
    gt_class_hot = tf.one_hot(indices=(gt_class - label_id_offset), depth=NUM_CLASSES, dtype=tf.float32)
    gt_classes_one_hot_tensors.append(tf.reshape(gt_class_hot, [-1, NUM_CLASSES]))
print('Done prepping data. Num_loaded:', len(list_train_images_np))
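The loop above assumes exactly one box per image: each gt_class is a single id, so the reshape yields one row. If an image holds several boxes (as in the commented-out [3,2] entry), a per-image list of ids gives one row per box. A NumPy sketch of that shape logic follows; np.eye stands in for tf.one_hot plus tf.reshape, and the nested list gt_classes_per_image is a made-up example.

```python
import numpy as np

NUM_CLASSES = 3
label_id_offset = 1

# Hypothetical per-image class lists: the last image has two boxes (dog, cat).
gt_classes_per_image = [[1], [2], [3, 2]]

gt_classes_one_hot = []
for gt_class in gt_classes_per_image:
    zero_indexed = np.asarray(gt_class, dtype=np.int32) - label_id_offset
    one_hot = np.eye(NUM_CLASSES, dtype=np.float32)[zero_indexed]
    # Same role as tf.reshape(..., [-1, NUM_CLASSES]): one row per box.
    gt_classes_one_hot.append(one_hot.reshape(-1, NUM_CLASSES))

for t in gt_classes_one_hot:
    print(t.shape)
# (1, 3)
# (1, 3)
# (2, 3)
```

In TensorFlow the per-image step would be tf.one_hot(tf.constant(gt_class) - label_id_offset, NUM_CLASSES); whichever form is used, the number of rows must match the number of boxes in the corresponding gt_box_np.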
If you are starting from these tutorials, I recommend reading: Does tensorflow's object detection api support multi-class multi-label detection?
More info: https://github.com/tensorflow/models/issues/9655#issuecomment-1460289284