BatchNormalization在tensorflow模型中



我想在tensorflow中制作自定义模型。我创建了一些函数来创建基本的图层,比如Conv2D, Dense, Flatten。我坚持批处理规范化实现。

我想把所有的trainable_variables(参数)放在一个列表self.parameters中。因为我的"习惯"层是建立在tf上的。模块1假设self.trainable_parameters中所有可训练的变量都可用。当前self.trainable_variables包含MyBatchNormalization可训练变量

下面Colab示例:

Colab Example - Section "创建模型"打印图层名称和可训练参数。

我想有工作BatchNormalization层,其中可训练的变量将在train_on_batch方法中更新(训练)。

我找到了解决方案-我只是添加了__build(self, input_shape)函数,它调用super(MyBatchNormalization, self).build(input_shape)

Implemetation如下:

class MyBatchNormalization(tf.keras.layers.BatchNormalization):
def __init__(self, input_shape, name=None):
super().__init__(name=name)
self.out_shape = input_shape
self.__build(input_shape)
def __build(self, input_shape):
super(MyBatchNormalization, self).build(input_shape)

我还没有完全测试过。但是我可以访问MyBatchNormalization的trainable_variables,它看起来很有希望。

相关内容

  • 没有找到相关文章

最新更新