call()如何为自定义Keras层工作?



我试图通过继承tf.keras.layers.Layer来构建自己的Keras层,我真的不明白call()方法在做什么。我已经将call()方法设置为:

def call(self,inputs): 
print('call')
return inputs

当我运行网络时,我希望'call'被打印多次(对于100个示例和10个epoch的训练集,我希望它被打印1000次)。然而,"call"在模型构建时打印一次,然后在第一个epoch中打印3次,然后再也不打印了。我的网络在随后的时代中不使用这一层吗?为什么它在第一个阶段只被调用了3次尽管有100个训练样本?

Call方法由@tf.function自动装饰。这意味着keras在第一次调用时构建数据流图,并在下一次调用时运行该图。

只在第一次调用时调用python函数。详情请点击这里- https://www.tensorflow.org/guide/function#debugging.

相关内容

  • 没有找到相关文章

最新更新