我正在写一个RESNET,但我不明白";呼叫";函数。
也许这是TensorFlow自动调用的,所以这意味着我们必须编写一个名为"的函数;呼叫";?如果是这样的话,对这个";呼叫";作用非常感谢。
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential
class BasicBlock(layers.Layer):
def __init__(self, filter_num, strides=1):
super(BasicBlock, self).__init__()
self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=strides, padding="same")
self.bn1 = layers.BatchNormalization()
self.relu = layers.Activation('relu')
self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding="same")
self.bn2 = layers.BatchNormalization()
if strides != 1:
self.downsample = layers.Conv2D(filter_num, (1, 1), strides=strides)
else:
self.downsample = lambda x:x
def call(self, inputs, training=None):
out = self.conv1(inputs)
out = self.bn1(out, training=training)
out = self.relu(out)
out = self.conv2(out, training=training)
out = self.bn2(out)
identity = self.downsample(inputs)
output = layers.add([out, identity])
output = tf.nn.relu(output)
return output
当您定义自定义层时,您将扩展基类tensorflow.keras.layers.Layer
并按如下方式使用它:
import tensorflow as tf
class BasicBlock(tf.keras.layers.Layer):
...
basic_block = BasicBlock()
basic_block(inputs)
上面片段的最后一行将调用类中的魔术方法__call__
(如果您感兴趣,请点击此处了解有关魔术方法的更多信息Python的魔术方法指南(
由于您没有在BasicBlock
中定义__call__
方法(您定义了不同的call
(,因此将使用tensorflow.keras.layers.Layer
中的__call__
。
根据Tensorflow文档,该方法具有以下文档
Wraps调用,应用预处理和后处理步骤。
粗略地说,你会有(如果你感兴趣,你可以检查源代码,但它要复杂得多(:
class Layer(...):
....
def __call__(self, ...):
# preprocessing steps
self.call(...)
# post processing steps
如果您熟悉继承,您应该猜测使用basic_block(inputs)
:时的不同步骤
- 检查
BasicBlock
是否具有名为__call__
=>否 - 检查基类CCD_ 11是否具有名为CCD_;是的,使用它并进入这个方法
- 应用预处理步骤
- 检查CCD_ 13是否具有名为CCD_;是,使用它并将其应用于输入
- 应用后期处理步骤
关于实现最佳资源的call
方法的要求,请参阅Tensorflow官方文档,在该文档中,您已经解释了预期输入数据结构+关键字参数的所有内容
调用函数的使用方式如下:
basic_block = BasicBlock()
basic_block(args)
所以它来代替:
basic_block.call(args)
在TensorFlow中,当您将模型实例作为带有输入数据的函数调用时,会自动调用自定义模型的call((方法。这通常使用圆括号((运算符来完成,这是调用call((方法的简写。
例如,在问题中提供的代码中:
basicblock = BasicBlock(filter_num)
basicblock(input_data) # this will invoke the call method
它看起来像这样:
BasicBlock(filter_num)(input_data)