我正在尝试创建一个 Web 应用程序,该应用程序接收来自 POST 请求的输入,并根据该输入提供一些 ML 预测。
由于预测模型非常繁重,我不希望用户等待计算完成。相反,我将繁重的计算委托给 Celery 任务,用户稍后可以检查结果。
我正在使用简单的Flask应用程序与芹菜,Redis和Flower。
我的view.py
:
@ns.route('predict/')
class Predict(Resource):
...
def post(self):
...
do_categorize(data)
return jsonify(success=True)
我的tasks.py
文件看起来像这样:
from ai.categorizer import Categorizer
categorizer = Categorizer(
model_path='category_model.h5',
tokenizer_path='tokenize.joblib',
labels_path='labels.joblib'
)
@task()
def do_categorize(data):
result = categorizer.predict(data)
print(result)
# Write result to the DB
...
我在类Categorizer
predict()
方法:
def predict(self, value):
K.set_session(self.sess)
with self.sess.as_default():
with self.graph.as_default():
prediction = self.model.predict(np.asarray([value], dtype='int64'))
return prediction
我像这样运行芹菜:
celery worker -A app.celery --loglevel=DEBUG
过去几天我遇到的问题是categorizer.predict(data)
调用在执行过程中挂起。
我尝试在 post 方法中运行categorizer.predict(data)
并且可以工作。但是如果我把它放在芹菜任务中,它就会停止工作。没有控制台日志,如果我尝试调试它,它只会冻结在.predict()
上。
我的问题:
- 如何解决此问题?
- 工作线程是否有任何内存、CPU 限制?
- 芹菜任务是进行如此繁重计算的"正确"方法吗?
- 如何调试此问题?我做错了什么?
- 在文件顶部初始化模型是否正确?
多亏了这个SO问题,我找到了问题的答案:
事实证明,Keras最好使用Threads
池而不是默认Process
。
幸运的是,不久前,随着芹菜 4.4Threads
池化重新引入。 您可以在 Celery 4.4 更新日志中阅读更多内容:
线程任务池
我们重新引入了一个线程任务池,使用 concurrent.futures.ThreadPoolExecutor.
以前的线程任务池是实验性的。此外,它基于已过时的线程池包。
您可以通过将worker_pool设置为"线程"或将 –pool 线程传递给 celery worker 命令来使用新的线程任务池。
现在,您可以使用线程而不是进程进行池化。
celery worker -A your_application --pool threads --loginfo=INFO
如果您不能使用最新的 Celery 版本,您可以使用gevent
包:
pip install gevent
celery worker -A your_application --pool gevent --loginfo=INFO