"call"函数在TensorFlow中使用在哪里?



我正在写一个RESNET,但我不明白";呼叫";函数。

也许这是TensorFlow自动调用的,所以这意味着我们必须编写一个名为"的函数;呼叫";?如果是这样的话,对这个";呼叫";作用非常感谢。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential
class BasicBlock(layers.Layer):
def __init__(self, filter_num, strides=1):
super(BasicBlock, self).__init__()
self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=strides, padding="same")
self.bn1 = layers.BatchNormalization()
self.relu = layers.Activation('relu')
self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding="same")
self.bn2 = layers.BatchNormalization()
if strides != 1:
self.downsample = layers.Conv2D(filter_num, (1, 1), strides=strides)
else:
self.downsample = lambda x:x
def call(self, inputs, training=None): 
out = self.conv1(inputs)
out = self.bn1(out, training=training)
out = self.relu(out)
out = self.conv2(out, training=training)
out = self.bn2(out)
identity = self.downsample(inputs)
output = layers.add([out, identity])
output = tf.nn.relu(output)
return output

当您定义自定义层时,您将扩展基类tensorflow.keras.layers.Layer并按如下方式使用它:

import tensorflow as tf
class BasicBlock(tf.keras.layers.Layer):
...
basic_block = BasicBlock()
basic_block(inputs)

上面片段的最后一行将调用类中的魔术方法__call__(如果您感兴趣,请点击此处了解有关魔术方法的更多信息Python的魔术方法指南(

由于您没有在BasicBlock中定义__call__方法(您定义了不同的call(,因此将使用tensorflow.keras.layers.Layer中的__call__

根据Tensorflow文档,该方法具有以下文档

Wraps调用,应用预处理和后处理步骤。

粗略地说,你会有(如果你感兴趣,你可以检查源代码,但它要复杂得多(:

class Layer(...):
....
def __call__(self, ...):
# preprocessing steps
self.call(...)
# post processing steps

如果您熟悉继承,您应该猜测使用basic_block(inputs):时的不同步骤

  1. 检查BasicBlock是否具有名为__call__=>否
  2. 检查基类CCD_ 11是否具有名为CCD_;是的,使用它并进入这个方法
  3. 应用预处理步骤
  4. 检查CCD_ 13是否具有名为CCD_;是,使用它并将其应用于输入
  5. 应用后期处理步骤

关于实现最佳资源的call方法的要求,请参阅Tensorflow官方文档,在该文档中,您已经解释了预期输入数据结构+关键字参数的所有内容

调用函数的使用方式如下:

basic_block = BasicBlock()
basic_block(args)

所以它来代替:

basic_block.call(args)

在TensorFlow中,当您将模型实例作为带有输入数据的函数调用时,会自动调用自定义模型的call((方法。这通常使用圆括号((运算符来完成,这是调用call((方法的简写。

例如,在问题中提供的代码中:

basicblock = BasicBlock(filter_num)
basicblock(input_data) # this will invoke the call method 

它看起来像这样:

BasicBlock(filter_num)(input_data)

最新更新