TensorFlow自定义层:获取实际批处理大小



我想实现一个自定义tf层,它执行涉及输入张量的实际批处理大小的数学运算:

import tensorflow as tf
from   tensorflow import keras
class MyLayer(keras.layers.Layer):
def build(self, input_shape):
self.batch_size = input_shape[0]
super().build(input_shape)
def call(self,input):
self.batch_size + 1 # do something with the batch size
return input

然而,当构建一个图时,它的值最初是None,这破坏了MyLayer的功能:

input = keras.Input(shape=(10,))
x     = MyLayer()(input)
TypeError: in user code:
<ipython-input-41-98e23e82198d>:11 call  *
self.batch_size + 1 # do something with the batch size
TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'

是否有任何方法使这些层工作后,模型已经建成?

使用tf.shape在你的call方法中获取批大小

:

import tensorflow as tf

# custom layer
class MyLayer(tf.keras.layers.Layer):
def __init__(self):
super().__init__()

def call(self, x):
bs = tf.shape(x)[0]
return x, tf.add(bs, 1)


# network
x_in = tf.keras.Input(shape=(None, 10,))
x = MyLayer()(x_in)
# model def
model = tf.keras.models.Model(x_in, x)
# forward pass
_, shp = model(tf.random.normal([5, 10]))
# shape value
print(shp)
# tf.Tensor(6, shape=(), dtype=int32)

最新更新