在多个线程中重用 Tensorflow 会话会导致崩溃



背景:

我有一些复杂的强化学习算法,我想在多个线程中运行。

问题

尝试在线程中调用sess.run时,我收到以下错误消息:

RuntimeError: The Session graph is empty. Add operations to the graph before calling run().

重现错误的代码:

import tensorflow as tf
import threading
def thread_function(sess, i):
inn = [1.3, 4.5]
A = tf.placeholder(dtype=float, shape=(None), name="input")
P = tf.Print(A, [A])
Q = tf.add(A, P)
sess.run(Q, feed_dict={A: inn})
def main(sess):
thread_list = []
for i in range(0, 4):
t = threading.Thread(target=thread_function, args=(sess, i))
thread_list.append(t)
t.start()
for t in thread_list:
t.join()
if __name__ == '__main__':
sess = tf.Session()
main(sess)

如果我在线程外部运行相同的代码,它可以正常工作。

有人可以就如何在python线程中正确使用Tensorflow会话提供一些见解吗?

会话不仅可以是当前线程默认值,还可以是图形。 当您传入会话并对其调用run时,默认图形将是不同的图形。

您可以像这样修改thread_function以使其正常工作:

def thread_function(sess, i):
with sess.graph.as_default():
inn = [1.3, 4.5]
A = tf.placeholder(dtype=float, shape=(None), name="input")
P = tf.Print(A, [A])
Q = tf.add(A, P)
sess.run(Q, feed_dict={A: inn})

但是,我不希望有任何显着的加速。Python 线程在其他一些语言中并不是它的意思,只有某些操作(如 io(会并行运行。对于 CPU 密集型操作,它不是很有用。多处理可以真正并行运行代码,但您不会共享同一个会话。

用 github 上的另一个资源扩展 de1 的答案: 张量流/张量流#28287 (评论(

以下内容为我解决了 tf 的多线程兼容性:

# on thread 1
session = tf.Session(graph=tf.Graph())
with session.graph.as_default():
k.backend.set_session(session)
model = k.models.load_model(filepath)
# on thread 2
with session.graph.as_default():
k.backend.set_session(session)
model.predict(x)

这会保留其他线程的SessionGraph
模型加载到它们的"上下文"(而不是默认上下文(中,并保留以供其他线程使用。
(默认情况下,模型加载到默认Session和默认Graph(
另一个优点是它们保存在同一个对象中 - 更容易处理。

最新更新