Keras 使用 tf.data.Dataset 预测循环内存泄漏,但不使用 numpy 数组



在使用tf.data.Dataset馈送模型时循环遍历 Keras 模型predict时遇到内存泄漏和性能下降,但在使用 numpy 数组馈送模型时则不然。

有谁了解导致此问题的原因和/或如何解决问题?

最少的可重现代码片段(可复制/粘贴运行):

import tensorflow as tf
import numpy as np
import time
SIZE = 5000
inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)
model = tf.keras.Model(inputs=inp, outputs=x)
np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()
debug_time = time.time()
while True:
model.predict(x=ds, steps=1)
print('Processing {:.2f}'.format(time.time() - debug_time))
debug_time = time.time()

结果:预测循环计时每次迭代大约 0.04 秒,一两分钟内达到约 0.5 秒,进程内存继续从几百 MB 增加到接近 GB。


tf.data.Dataset换成等效的 numpy 数组,运行时间始终为 ~0.01 秒。

工作案例代码片段(可复制/粘贴运行):

import tensorflow as tf
import numpy as np
import time
SIZE = 5000
inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)
model = tf.keras.Model(inputs=inp, outputs=x)
np_data = np.random.rand(1, SIZE)
debug_time = time.time()
while True:
model.predict(x=np_data)  # using numpy array directly
print('Processing {:.2f}'.format(time.time() - debug_time))
debug_time = time.time()

相关讨论:

  • 内存泄漏tf.data + Keras - 似乎没有解决核心问题,但问题似乎相似。
  • https://github.com/tensorflow/tensorflow/issues/22098 - 可能是 Keras/Github 中的一个未解决的问题,但我无法确认,按照该线程中的建议更改inter_op_paralellism对此处发布的结果没有影响。

附加信息:

  • 我可以通过传入迭代器而不是数据集对象将性能下降的速度降低大约 10 倍。我注意到在training_utils.py:1314Keras 代码正在创建一个迭代器来预测每次调用。

TF 1.14.0

问题的根源似乎是 Keras 在每个predict循环中创建数据集操作。请注意,training_utils.py:1314在每个预测循环中创建一个数据集迭代器。

这个问题可以通过传入迭代器来降低严重性,并且可以通过传入迭代器get_next()张量来完全解决。

我已经在Tensorflow Github页面上发布了这个问题:https://github.com/tensorflow/tensorflow/issues/30448

这是解决方案,此示例使用 TF 数据集在恒定时间内运行,您只是无法传入数据集对象:

import tensorflow as tf
import numpy as np
import time
SIZE = 5000
inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)
model = tf.keras.Model(inputs=inp, outputs=x)
np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()
it = tf.data.make_one_shot_iterator(ds)
tensor = it.get_next()
debug_time = time.time()
while True:
model.predict(x=tensor, steps=1)
print('Processing {:.2f}'.format(time.time() - debug_time))
debug_time = time.time()

最新更新