I am working with a Keras model and I would like to generate the model's output with some dummy inputs before compiling and training it. Below I share the parts of my model code that I think are important for understanding the model; for the full code please see this colab file, and you can also check the official Keras code here.
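The snippets below assume the usual imports from the Keras example (they are not shown in the code above):

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers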
class ShiftViTModel(keras.Model):
    """The ShiftViT Model.

    Args:
        data_augmentation (keras.Model): A data augmentation model.
        projected_dim (int): The dimension to which the patches of the image are
            projected.
        patch_size (int): The patch size of the images.
        num_shift_blocks_per_stages (list[int]): A list of the number of shift
            blocks per stage.
        epsilon (float): The epsilon constant.
        mlp_dropout_rate (float): The dropout rate used in the MLP block.
        stochastic_depth_rate (float): The maximum drop rate probability.
        num_div (int): The number of divisions of the channels of the feature
            map. Defaults to 12.
        shift_pixel (int): The number of pixels to shift. Defaults to 1.
        mlp_expand_ratio (int): The ratio by which the initial dense layer of
            the MLP is expanded. Defaults to 2.
    """
    def __init__(
        self,
        data_augmentation,
        projected_dim,
        patch_size,
        num_shift_blocks_per_stages,
        epsilon,
        mlp_dropout_rate,
        stochastic_depth_rate,
        num_div=12,
        shift_pixel=1,
        mlp_expand_ratio=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.data_augmentation = data_augmentation
        self.patch_projection = layers.Conv2D(
            filters=projected_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding="same",
        )
        self.stages = list()
        for index, num_shift_blocks in enumerate(num_shift_blocks_per_stages):
            if index == len(num_shift_blocks_per_stages) - 1:
                # This is the last stage, do not use the patch merge here.
                is_merge = False
            else:
                is_merge = True

            # Build the stages.
            self.stages.append(
                StackedShiftBlocks(
                    epsilon=epsilon,
                    mlp_dropout_rate=mlp_dropout_rate,
                    num_shift_blocks=num_shift_blocks,
                    stochastic_depth_rate=stochastic_depth_rate,
                    is_merge=is_merge,
                    num_div=num_div,
                    shift_pixel=shift_pixel,
                    mlp_expand_ratio=mlp_expand_ratio,
                )
            )
        self.global_avg_pool = layers.GlobalAveragePooling2D()
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "data_augmentation": self.data_augmentation,
                "patch_projection": self.patch_projection,
                "stages": self.stages,
                "global_avg_pool": self.global_avg_pool,
            }
        )
        return config
    def _calculate_loss(self, data, training=False):
        (images, labels) = data

        # Augment the images.
        augmented_images = self.data_augmentation(images, training=training)

        # Create patches and project the patches.
        projected_patches = self.patch_projection(augmented_images)

        # Pass through the stages.
        x = projected_patches
        for stage in self.stages:
            x = stage(x, training=training)

        # Get the logits.
        logits = self.global_avg_pool(x)

        # Calculate the loss and return it.
        total_loss = self.compiled_loss(labels, logits)
        return total_loss, labels, logits
    def train_step(self, inputs):
        with tf.GradientTape() as tape:
            total_loss, labels, logits = self._calculate_loss(
                data=inputs, training=True
            )

        # Apply gradients.
        train_vars = [
            self.data_augmentation.trainable_variables,
            self.patch_projection.trainable_variables,
            self.global_avg_pool.trainable_variables,
        ]
        train_vars = train_vars + [stage.trainable_variables for stage in self.stages]

        # Optimize the gradients.
        grads = tape.gradient(total_loss, train_vars)
        trainable_variable_list = []
        for (grad, var) in zip(grads, train_vars):
            for g, v in zip(grad, var):
                trainable_variable_list.append((g, v))
        self.optimizer.apply_gradients(trainable_variable_list)

        # Update the metrics.
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}
    def test_step(self, data):
        _, labels, logits = self._calculate_loss(data=data, training=False)

        # Update the metrics.
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}
The second block, where the model is instantiated:
model = ShiftViTModel(
    data_augmentation=get_augmentation_model(),
    projected_dim=config.projected_dim,
    patch_size=config.patch_size,
    num_shift_blocks_per_stages=config.num_shift_blocks_per_stages,
    epsilon=config.epsilon,
    mlp_dropout_rate=config.mlp_dropout_rate,
    stochastic_depth_rate=config.stochastic_depth_rate,
    num_div=config.num_div,
    shift_pixel=config.shift_pixel,
    mlp_expand_ratio=config.mlp_expand_ratio,
)
What I am actually trying to do is produce an output from the model above, like this:
dummy_inputs = tf.ones((2, 32, 32, 3))
outputs = model(dummy_inputs, training=False)
print(outputs.shape)
But it raises an error:
Unimplemented `tf.keras.Model.call()`: if you intend to create a `Model` with the Functional API, please provide `inputs` and `outputs` arguments. Otherwise, subclass `Model` with an overridden `call()` method.
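This is the generic error Keras raises for any subclassed Model that does not override call(); a minimal sketch that reproduces it (the class below is only for illustration):

class NoCallModel(keras.Model):
    # Hypothetical subclass with no call() override, shown only to reproduce the error.
    def __init__(self):
        super().__init__()
        self.dense = layers.Dense(4)

NoCallModel()(tf.ones((1, 8)))  # raises: Unimplemented `tf.keras.Model.call()` ...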
The call method is not implemented, and it is required in this kind of subclassed model if we want to check the model with dummy data. You can implement the call method in the ShiftViTModel class as shown below, using the layers the model already uses (see the train_step method).
def call(self, images):
    # Augment the images.
    augmented_images = self.data_augmentation(images)
    # Project the augmented images into patches.
    x = self.patch_projection(augmented_images)
    # Pool the feature map to get the logits.
    logits = self.global_avg_pool(x)
    return logits
Now, if we do:
model = ShiftViTModel( ... )

x, y = next(iter(train_ds))
print(x.shape, y.shape)
model(x).shape

the output is:

(256, 32, 32, 3) (256, 1)
TensorShape([256, 96])
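The call above only runs the patch projection and the pooling, which is enough to get a dummy output of shape (batch, projected_dim). If you want the dummy forward pass to also go through the shift stages, mirroring what _calculate_loss does during training, a sketch could look like this (the output channel dimension will then come from the last stage rather than from projected_dim):

def call(self, images, training=False):
    # Same forward pass as _calculate_loss, minus the loss computation.
    augmented_images = self.data_augmentation(images, training=training)
    x = self.patch_projection(augmented_images)
    for stage in self.stages:
        x = stage(x, training=training)
    logits = self.global_avg_pool(x)
    return logits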