Memory leak with Keras Lambda layers


# split the channels in two (first part for IN, second for BN)
x_in = Lambda(lambda x: x[:, :, :, :split_index])(x)
x_bn = Lambda(lambda x: x[:, :, :, split_index:])(x)
# apply IN and BN on their respective group of channels
x_in = InstanceNormalization(axis=3)(x_in)
x_bn = BatchNormalization(axis=3)(x_bn)
# concatenate outputs of IN and BN
x = Concatenate(axis=3)([x_in, x_bn])


Layer (type)                    Output Shape         Param #     Connected to
input_1 (InputLayer)            (None, 832, 832, 1)  0
conv1 (Conv2D)                  (None, 832, 832, 32) 320         input_1[0][0]
lambda_1 (Lambda)               (None, 832, 832, 16) 0           conv1[0][0]
lambda_2 (Lambda)               (None, 832, 832, 16) 0           conv1[0][0]
instance_normalization_1 (Insta (None, 832, 832, 16) 32          lambda_1[0][0]
batch_normalization_1 (BatchNor (None, 832, 832, 16) 64          lambda_2[0][0]
concatenate_1 (Concatenate)     (None, 832, 832, 32) 0           instance_normalization_1[0][0]
                                                                 batch_normalization_1[0][0]


# apply IN and BN on the input tensor independently
x_in = InstanceNormalization(axis=3)(x)
x_bn = BatchNormalization(axis=3)(x)
# add the feature maps output by IN and BN
x = Add()([x_in, x_bn])

Is there a way to fix this memory leak? I am using Keras 2.2.4 and TensorFlow 1.15.3, and I cannot upgrade to TF 2 or tf.keras for now.

Thibault Bacqueyrisses's answer was right: the memory leak disappeared once I used a custom layer!


import keras


class Crop(keras.layers.Layer):
    def __init__(self, dim, start, end, **kwargs):
        """
        Slice the tensor on the last dimension, keeping what is between start
        and end.

        dim   (int)   : dimension of the tensor (including the batch dim)
        start (int)   : index of where to start the cropping
        end   (int)   : index of where to stop the cropping
        """
        super(Crop, self).__init__(**kwargs)
        self.dimension = dim
        self.start = start
        self.end = end

    def call(self, inputs):
        # slice along the axis selected by self.dimension
        if self.dimension == 0:
            return inputs[self.start:self.end]
        if self.dimension == 1:
            return inputs[:, self.start:self.end]
        if self.dimension == 2:
            return inputs[:, :, self.start:self.end]
        if self.dimension == 3:
            return inputs[:, :, :, self.start:self.end]
        if self.dimension == 4:
            return inputs[:, :, :, :, self.start:self.end]

    def compute_output_shape(self, input_shape):
        # assumes the cropped axis is the last one
        return (input_shape[:-1] + (self.end - self.start,))

    def get_config(self):
        # expose the constructor arguments so the layer can be saved and reloaded
        config = {
            'dim': self.dimension,
            'start': self.start,
            'end': self.end,
        }
        base_config = super(Crop, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
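
As a quick sanity check of the shape bookkeeping in compute_output_shape, here is a framework-free sketch in plain Python. The helper name crop_output_shape and the value split_index = 16 are assumptions for illustration (16 matches the channel count of the two Lambda outputs in the model summary above). It only mirrors the tuple arithmetic and does not touch Keras.

```python
def crop_output_shape(input_shape, start, end):
    """Mirror Crop.compute_output_shape: the last dim becomes end - start."""
    return tuple(input_shape[:-1]) + (end - start,)


# Output shape of conv1 in the model summary (None is the batch dimension).
conv1_shape = (None, 832, 832, 32)
split_index = 16  # assumed: half of the 32 channels, as in the summary

# First group of channels (for IN), second group (for BN).
in_shape = crop_output_shape(conv1_shape, 0, split_index)
bn_shape = crop_output_shape(conv1_shape, split_index, conv1_shape[-1])

print(in_shape)
print(bn_shape)
```

With the real layer, the two Lambda calls from the question would presumably become Crop(3, 0, split_index)(x) and Crop(3, split_index, 32)(x), where 32 is the channel count of conv1.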

Keras Lambda layers may be somewhat buggy.
