在 TensorFlow 函数式(函数式 API)模型中访问'training'属性



正如标题所述,我想知道当我使用函数API时,如何访问特权"training"参数。

因此,如果我使用子类化,我可以写这样的东西:

class MyLayer(tf.keras.layers.Layer):
def __init__(self):
...
self.BN = tf.keras.Layers.BatchNormalization()

def call(self,inputs, training=None):
self.BN(inputs, training=training)

因此,我可以控制我的batchnorm在训练和预测过程中的行为。但是如果我想使用功能API:

input = tf.Input(someshape)
normalized = tf.keras.layers.BatchNormalization()(input)
tf.keras.Model(inputs=input, outputs=normalized)

现在我真的不能再为我的batch_norm设置有特权的"training"参数了。我喜欢功能性的API,它使用起来真的很有趣,但必须围绕这种类型进行构建通常是一个破坏交易的因素。我觉得我一定错过了一些关于如何在这里解决这个问题的重要想法。我知道我可以创建一个tf.Input,它可以容纳"training"参数。但这会将其从关键字arg更改为列表中的某个元素,从而产生非常不一致的代码。有什么更聪明的解决方案吗?

编辑:应该清楚地表明,我正在寻找一个可用于"training"参数的总体想法,而不仅仅是处理BatchNormalization。

实例化模型model = tf.keras.Model(inputs=input, outputs=normalized)时,模型尚未构建。您将需要调用构建方法,通常是在使用渐变带手动完成所有操作时,或者在第一次调用拟合方法时。此时,权重将被初始化。现在,如果使用fit方法或调用模型output_tensors = mymodel(input_tensors, training=True),或者相反,如果使用predict法或use output_tensors = mymodel(input_tensors, training=False),则训练标志将设置为True或False(如果直接调用模型,这一点很明显(。

最新更新