跟随 Keras 函数(预测)在同步调用时工作
pred = model.predict(x)
但是,从异步任务队列(Celery)中调用时,它不起作用。 Keras 预测函数在异步调用时不返回任何输出。
堆栈是:Django,Celery,Redis,Keras,TensorFlow
我遇到了完全相同的问题,伙计,这是一个兔子洞。想在这里发布我的解决方案,因为它可以为某人节省一天的工作:
TensorFlow 线程特定的数据结构
在 TensorFlow中,当您调用model.predict
(或keras.models.load_model
,或keras.backend.clear_session
,或几乎任何其他与 TensorFlow 后端交互的函数)时,有两个关键的数据结构在幕后工作:
- 一个 TensorFlow 图,它代表 Keras 模型的结构
- TensorFlow 会话,它是当前图形和 TensorFlow 运行时之间的连接
在没有挖掘的情况下,文档中没有明确清楚的是,会话和图形都是当前线程的属性。在此处和此处查看 API 文档。
在不同的线程中使用TensorFlow模型
很自然地想要加载一次模型,然后在以后多次调用.predict()
:
from keras.models import load_model
MY_MODEL = load_model('path/to/model/file')
def some_worker_function(inputs):
return MY_MODEL.predict(inputs)
在 Web 服务器或工作线程池上下文(如 Celery)中,这意味着您将在导入包含load_model
行的模块时加载模型,然后另一个线程将执行some_worker_function
,在包含 Keras 模型的全局变量上运行预测。但是,尝试在不同线程中加载的模型上运行预测会产生"张量不是此图的元素"错误。感谢几个涉及这个主题的SO帖子,例如ValueError:Tensor Tensor(...)不是这个图的一个元素。使用全局变量 keras 模型时。为了使它工作,你需要坚持使用的TensorFlow图——正如我们之前看到的,图是当前线程的属性。更新后的代码如下所示:
from keras.models import load_model
import tensorflow as tf
MY_MODEL = load_model('path/to/model/file')
MY_GRAPH = tf.get_default_graph()
def some_worker_function(inputs):
with MY_GRAPH.as_default():
return MY_MODEL.predict(inputs)
这里有点令人惊讶的转折是:如果您使用Thread
s,上面的代码就足够了,但如果您使用Process
es,则无限期挂起。默认情况下,Celery 使用进程来管理其所有工作线程池。所以在这一点上,芹菜的事情仍然不起作用。
为什么这仅适用于Thread
?
在 Python 中,Thread
与父进程共享相同的全局执行上下文。来自 Python _thread 文档:
该模块提供了用于处理多个线程(也称为轻量级进程或任务)的低级基元 - 多个控制线程共享其全局数据空间。
由于线程不是实际的独立进程,因此它们使用相同的python解释器,因此受到臭名昭著的全局交互锁(GIL)的约束。也许对于这项调查更重要的是,它们与父级共享全局数据空间。
与此相反,Process
es是程序产生的实际新进程。这意味着:
- 新的 Python 解释器实例(没有 GIL)
- 全局地址空间重复
请注意此处的区别。虽然Thread
可以访问共享的单个全局会话变量(存储在 Keras 的tensorflow_backend
模块内部),但Process
都有会话变量的副本。
我对这个问题的最佳理解是,Session 变量应该表示客户端(进程)和 TensorFlow 运行时之间的唯一连接,但由于在分叉过程中被复制,此连接信息没有得到适当的调整。这会导致 TensorFlow 在尝试使用在不同进程中创建的会话时挂起。如果有人对 TensorFlow 中的工作原理有更多了解,我很想听听!
解决方案/解决方法
我调整了芹菜,以便它使用Thread
s 而不是Process
es 进行池化。这种方法有一些缺点(请参阅上面的 GIL 注释),但这允许我们只加载一次模型。无论如何,我们并没有真正受到CPU的限制,因为TensorFlow运行时会最大化所有CPU内核(它可以避开GIL,因为它不是用Python编写的)。您必须为 Celery 提供一个单独的库来执行基于线程的池化;文档建议两个选项:gevent
或eventlet
。然后,通过--pool
命令行参数将所选库传递到工作线程中。
或者,似乎(正如您已经发现的那样@pX0r)其他 Keras 后端(如 Theano)没有这个问题。这是有道理的,因为这些问题与TensorFlow实现细节密切相关。我个人还没有尝试过Theano,所以你的里程可能会有所不同。
我知道这个问题是不久前发布的,但问题仍然存在,所以希望这会对某人有所帮助!
我从这个博客中得到了参考
- Tensorflow 是特定于线程的数据结构,当您调用model时,它们在幕后工作。
GRAPH = tf.get_default_graph() with GRAPH.as_default(): pred = model.predict return pred
但 Celery 使用流程来管理其所有工作线程池。所以在这一点上,事情仍然无法在 Celery 上运行,因为您需要使用 gevent 或 eventlet 库
点安装gevent
现在运行芹菜作为:
芹菜 -我的网站工人 --池 gevent -l 信息