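I have written a few lines of code that extract 5 bounding boxes/patches from a single image. When I run this code and print the shape of its output, I get (5, 256): five patches, each represented by a 256-dimensional vector. The problem is that the patches are extracted without any batch information, so when I feed 5000+ images to this code it produces 5000*5 patches that are all mixed together, and the patch/image relationship is lost. I would like to change this code so that its output carries a batch dimension, something like (1, 5, 256), so that each batch entry represents one image.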
def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    xception = keras.applications.Xception(
        include_top=False, weights="imagenet", pooling="avg"
    )
    for layer in xception.layers:
        layer.trainable = trainable
    inputs = layers.Input(shape=(299, 299, 3), name="image_input")
    NUM_BOXES = 5
    CHANNELS = 3
    CROP_SIZE = (200, 200)
    boxes = tf.random.uniform(shape=(NUM_BOXES, 4))
    box_indices = tf.random.uniform(shape=(NUM_BOXES,), minval=0,
                                    maxval=BATCH_SIZE, dtype=tf.int32)
    output = tf.image.crop_and_resize(inputs, boxes, box_indices, CROP_SIZE)
    xception_input = tf.keras.applications.xception.preprocess_input(output)
    embeddings = xception(xception_input)
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    return keras.Model(inputs, outputs, name="vision_encoder")
You can create an ImagePatchesAndEmbedding layer that crops the bounding boxes of each image, stacks them, and applies xception:
class ImagePatchesAndEmbedding(keras.layers.Layer):
    def __init__(self, crop_size, num_boxes=5, minval=0, maxval=1):
        super(ImagePatchesAndEmbedding, self).__init__()
        self.crop_size = crop_size
        self.boxes = tf.random.uniform(shape=(num_boxes, 4))
        # Each crop is taken from a batch of one image (see call), so every
        # box index must be 0; maxval is exclusive, so this yields all zeros.
        self.box_indices = tf.random.uniform(shape=(num_boxes,), minval=0,
                                             maxval=1, dtype=tf.int32)
        self.preprocess = tf.keras.applications.xception.preprocess_input

    def call(self, inputs):
        # Crop the same boxes from every image: (batch, num_boxes, H, W, C).
        patches = tf.map_fn(lambda img: tf.image.crop_and_resize(img[None, ...],
                            self.boxes, self.box_indices, self.crop_size), inputs)
        # `xception` is assumed to be the Xception model defined in the enclosing scope.
        embeddings = tf.map_fn(lambda patch: xception(self.preprocess(patch)), patches)
        return embeddings
The model:
inputs = layers.Input(shape=(299, 299, 3), name="image_input")
NUM_BOXES = 5
CHANNELS = 3
CROP_SIZE = (200, 200)
BATCH_SIZE = 3
output = ImagePatchesAndEmbedding(CROP_SIZE, num_boxes=5, maxval=BATCH_SIZE)(inputs)
model = keras.Model(inputs, output)
Calling the model:
model(tf.random.normal(shape=(BATCH_SIZE, 299, 299, 3))).shape
#[3, 5, 2048]
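If tf.map_fn turns out to be slow, one possible alternative (a sketch, not something tested against the model above) is to keep a single crop_and_resize call over the whole batch, but lay out the box indices image by image instead of drawing them at random. The flat (batch * num_boxes, dim) output can then be reshaped to (batch, num_boxes, dim) without losing the patch/image relationship. The snippet below only illustrates the index bookkeeping with NumPy stand-ins; the function names per_image_box_indices and group_by_image are made up for this example. In TensorFlow the corresponding calls would be tf.repeat, tf.range and tf.reshape, and the boxes would also have to be tiled to shape (batch * num_boxes, 4).

```python
import numpy as np


def per_image_box_indices(batch_size, num_boxes):
    """Box indices laid out image-major: [0, 0, ..., 1, 1, ..., 2, 2, ...]."""
    return np.repeat(np.arange(batch_size), num_boxes)


def group_by_image(flat_embeddings, batch_size, num_boxes):
    """Reshape (batch_size * num_boxes, dim) to (batch_size, num_boxes, dim)."""
    dim = flat_embeddings.shape[-1]
    return flat_embeddings.reshape(batch_size, num_boxes, dim)


batch_size, num_boxes, dim = 3, 5, 4
indices = per_image_box_indices(batch_size, num_boxes)

# Stand-in for the flat encoder output: every row is filled with the index
# of the image its patch was cropped from, so the grouping can be checked.
flat = np.repeat(indices[:, None], dim, axis=1).astype(np.float32)

grouped = group_by_image(flat, batch_size, num_boxes)
print(grouped.shape)  # (3, 5, 4)
```

Because the indices are image-major, grouped[b] contains exactly the rows that came from image b, which is the same layout the map_fn version returns.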