线性内存增加与次数成正比模型.Fit叫做健康



我使用自定义模型的tensorflow来训练一些ai,我有记忆问题。我的模型由3个不同的层组成,每个层接收和处理不同的信息,这些信息的输出然后一起被馈送到一个密集的网络中,最终给我一些输出。

def call(self, x, training=False):
if training:
xs = []
for xx in x:
xx = tf.concat([self.com(xx[0], training), self.hand(xx[1], training), self.rec(xx[2], training)], 1)
xx = self.out_0(xx)
xx = self.out_1(xx)
xs.append(xx)
return tf.stack(xs)
else:
x = tf.concat([self.com(x[0], training), self.hand(x[1], training), self.rec(x[2], training)], 1)

x = self.out_0(x)
x = self.out_1(x)
return x

这就是我如何为我的模型编写调用函数。通常,x是一个由3个张量组成的列表,每个张量代表不同的输入,被送入不同的层。当训练时,我放入一个列表,其中列表中的每个元素都是我的正常输入,就像一个包含3个张量的列表。我是这样做的,而不是使用数据集,因为我找不到一种方法使它工作,除非我放弃批处理而不组合这3个张量,这是不可能的,由于形状的差异。

然而,由于某种原因,每当我在这个模型的实例上调用fit时,我的内存消耗就会异常地增加,并且不会消失。我必须多次调用fit,所以这是一个非常大的问题。

我怀疑这可能与我写的调用函数有关,特别是训练后的门控部分,因为当我以不批处理的代价调用数据集fit并删除由于不再需要的门控部分时,内存问题消失了。然而,我不知道为什么。

另外,我认为这可能与默认图形在某种程度上创建了一些东西,并且在调用fit时不清理它们有关,但我也不知道这种怀疑有多有效。

已解决。问题是我如何使用python列表作为输入调用fit。这导致了tf。函数在被调用的时候都是新建的,实际上,这使得内存消耗达到了惊人的水平。

相关内容

最新更新