我想在tensorflow中制作自定义模型。我创建了一些函数来创建基本的图层,比如Conv2D, Dense, Flatten。我坚持批处理规范化实现。
我想把所有的trainable_variables
(参数)放在一个列表self.parameters
中。因为我的"习惯"层是建立在tf上的。模块1假设self.trainable_parameters
中所有可训练的变量都可用。当前self.trainable_variables
不包含MyBatchNormalization
可训练变量
下面Colab示例:
Colab Example - Section "创建模型"打印图层名称和可训练参数。
我想有工作BatchNormalization层,其中可训练的变量将在train_on_batch
方法中更新(训练)。
我找到了解决方案-我只是添加了__build(self, input_shape)
函数,它调用super(MyBatchNormalization, self).build(input_shape)
。
Implemetation如下:
class MyBatchNormalization(tf.keras.layers.BatchNormalization):
def __init__(self, input_shape, name=None):
super().__init__(name=name)
self.out_shape = input_shape
self.__build(input_shape)
def __build(self, input_shape):
super(MyBatchNormalization, self).build(input_shape)
我还没有完全测试过。但是我可以访问MyBatchNormalization的trainable_variables
,它看起来很有希望。