Python multiprocessing hangs when calling a TensorFlow model



I am trying to train several models in parallel using a loop. However, the process hangs at the .fit method and never continues.

# model_keras, x_train, x_test, y_train and y_test are defined at module level (not shown)
def crossover_during_training(pair):
    print("FOR PAIR NUMBER " + str(pair))
    for index in [1, 5, 10, 15]:
        print("one")
        model = model_keras(x_train, x_test, y_train, y_test, 0)
        model_one = model_keras(x_train, x_test, y_train, y_test, pair)
        model_two = model_keras(x_train, x_test, y_train, y_test, pair + 100)

        print(model, model_one, model_two)
        print(x_train.shape)
        model_information_parent_one = model_one.fit(x_train, y_train, epochs=index,
                                                     batch_size=128, verbose=True, validation_data=(x_test, y_test))
        print(model_information_parent_one)
        weights_nn_one = model_one.get_weights()
        model_information_parent_two = model_two.fit(x_train, y_train, epochs=index,
                                                     batch_size=128, verbose=True, validation_data=(x_test, y_test))
        weights_nn_two = model_two.get_weights()

        print("two")

This is how I use the multiprocessing module:

from multiprocessing import Pool

all_args = [pair for pair in range(2)]
pool = Pool(2)
results = pool.map(crossover_during_training, all_args)

The crossover_during_training function starts running but never gets past model.fit. In other words, it never reaches the print("two") line.

Is there something wrong with calling the fit method inside a worker process?

I ran into the same problem when running inside a Docker container, and the solution @thijsvdp pointed out in his comment worked for me:

multiprocessing.get_context('spawn').Pool(pool_size)
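A minimal sketch of the spawn-based setup, assuming the TensorFlow import and model construction are moved inside the worker function. A common explanation for the hang is that under the `fork` start method (the default on Linux, which is what Docker containers usually run), each worker inherits a copy of the parent's memory, including any TensorFlow runtime state such as thread pools and locks created before the pool was forked. The threads themselves are not copied, so a lock held at fork time may never be released and `fit` blocks. `spawn` starts each worker as a fresh interpreter, so nothing is inherited. The names `train_pair` and `run` are hypothetical, and the body of `train_pair` is a pure-Python stand-in for the question's `model_keras(...)`/`fit(...)` calls so the sketch runs without TensorFlow installed.

```python
import multiprocessing


def train_pair(pair):
    # In real code, `import tensorflow as tf` and the model_keras(...)/fit(...)
    # calls would go here, so every spawned worker builds its own
    # TensorFlow runtime from scratch.
    # Hypothetical pure-Python stand-in so the sketch runs without TensorFlow:
    score = sum(i * pair for i in range(10))
    return pair, score


def run(num_pairs=2, pool_size=2):
    # "spawn" starts each worker as a fresh interpreter instead of
    # fork()-ing a copy of the parent process and its thread/lock state.
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(pool_size) as pool:
        return pool.map(train_pair, range(num_pairs))


if __name__ == "__main__":
    # Required with "spawn": workers re-import this module, and the guard
    # keeps them from trying to create pools of their own.
    print(run())  # [(0, 0), (1, 45)]
```

Because spawned workers do not inherit the parent's globals, module-level data such as x_train is re-created when each worker re-imports the module, or can be passed to the worker explicitly as an argument.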
