对 Keras 网络使用多处理时获取错误"The Session graph is empty."



我想评估并得到每个样本的损失。所以我想用多处理来加速它。但是它显示错误"会话图是空的。在调用run()之前向图中添加操作。">

model.fit(x=X_measured, y=y_train, batch_size=batch_size, epochs=epochs, verbose=0, 
validation_data=(X_measured_test,y_test), shuffle=True)

def get_loss(i, model, X_measured, y_train):
samples_loss=model.evaluate(x=X_measured[i:i+1,:],y=y_train[i:i+1,:],batch_size=None,verbose=0,steps=1)
return samples_loss

pool = mp.Pool(mp.cpu_count())
samples_loss=pool.starmap(get_loss, [(j, model, X_measured, y_train) for j in range(X_measured.shape[0])])
pool.close()

根据这个关于Keras中多处理的优秀答案,最好的经验法则是"在单独的进程中运行与每个模型相关的工作"。

因此,你构建事物的方式——在主环境上训练模型,然后在单独的进程上计算损失——不能完成,因为Keras/Tensorflow将一堆东西加载/配置到主进程中,而不会向下传播到派生进程。

如果你有单独的模型要训练,似乎最好的方法是为每个模型的训练和评估生成一个新的过程。

最新更新