Synchronizing multiple processes that run 'while True' loops



I need the following:

  • The main function trains a model.
  • At every epoch, its parameters are copied to a test model.
  • The test model is used for testing on several datasets.
  • Testing must run in parallel, while training continues with the next epoch.
  • Before moving on to the next training epoch, wait until testing has finished on all datasets.
  • The test functions report some statistics, which are read by the main function.

The following code uses a single Queue and tests only the first dataset. I need to extend it to all datasets.

import signal
import numpy as np
import multiprocessing as mp

STOP = -1
data = {'x': np.random.rand(), 'y': np.random.rand(), 'z': np.random.rand()}

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def online_test(p2c_queue, c2p_queue, data_id, model, shared_stats):
    print(f'testing function for {data_id} has started')
    while True: # keep process alive for testing
        print(f'... {data_id} waiting ...')
        epoch = p2c_queue.get()
        if epoch == STOP:
            print(f'... testing {data_id} is over, function is ending ...')
            break
        shared_stats.update({data_id: {k: [] for k in ['prediction', 'error']}})
        print(f'... {data_id} evaluation ...')
        pred = model.value
        err = pred - data[data_id] # simplified version, the real one takes some time
        # shared_stats.update({data_id: {'prediction': pred, 'error': err}})
        shared_stats.update({data_id: {'prediction': epoch, 'error': -epoch}}) # to debug if order of calls is correct
        c2p_queue.put(True) # notify parent that testing is done for requested epoch

if __name__ == '__main__':
    stats = {**{'epoch': []},
             **{data_id: {k: [] for k in ['prediction', 'error']} for data_id in data.keys()}}
    mp.set_start_method('spawn')
    manager = mp.Manager()
    p2c_queue = manager.Queue() # parent-to-child: parent tells child to start testing
    c2p_queue = manager.Queue() # child-to-parent: child tells parent that testing is done
    test_model = manager.Value('d', 10.0)
    shared_stats = manager.dict()
    pool = mp.Pool(initializer=initializer)
    p2c_queue.put(0) # testing can start for raw model
    pool.apply_async(online_test,
                     args=(p2c_queue, c2p_queue, 'x', test_model, shared_stats))
    try: # wrap all in a try-except to handle KeyboardInterrupt
        for epoch in range(10):
            print('training epoch', epoch)
            # ... here I do some training and then copy my parameters to test_model
            test_model.value = np.random.rand() # simplified version
            print('... waiting for testing before moving on to next epoch ...')
            if c2p_queue.get(): # keep training only if previous eval is done
                print(f'... epoch {epoch} testing is done, stats are')
                for data_id in shared_stats.keys(): # but first copy stats here
                    for k in stats[data_id].keys():
                        mu = np.mean(shared_stats[data_id][k])
                        stats[data_id][k].append(mu)
                        print('  ', data_id, k, mu)
                p2c_queue.put(epoch + 1)
        p2c_queue.put(STOP)
        print(stats)
    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()
        pool.join()
  1. How do I synchronize multiple processes? The example here spawns only one, for dataset 'x'. I have tried:
  • Multiple queues, but my code hangs. Each queue would hold a single item, corresponding to one test dataset.
  • A single queue holding as many items as there are test datasets. The idea was to check when the queue becomes empty, but I have read that empty() is unreliable.
  2. Do I need a lock? shared_stats is accessed by all processes, possibly at the same time, but each process only sets a specific key of the dictionary, so that should not be a problem, right? (See the sketch after this list.)
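
Regarding the lock question, here is a minimal sketch (my own illustration, not part of the original code) of why per-key writes to a Manager dict need no lock, and of a related pitfall: in-place mutation of a nested value does not propagate back through the proxy, which is why the code above always replaces the whole value via update().

import multiprocessing as mp

def worker(shared, key):
    # OK: replaces the entire value stored under this worker's own key
    shared[key] = {'prediction': 1.0, 'error': -1.0}
    # NOT propagated: this mutates a local copy returned by the proxy
    shared[key]['error'] = 0.0  # silently lost

if __name__ == '__main__':
    with mp.Manager() as manager:
        shared = manager.dict()
        procs = [mp.Process(target=worker, args=(shared, k)) for k in 'xyz']
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        print(dict(shared))  # 'error' is still -1.0 for every key

Since each worker touches only its own top-level key, and assignments go through the single manager process, no explicit lock should be needed here; the pitfall above is the only thing to watch for.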

I solved it with a JoinableQueue; the code is below.

However, this version tests differently from what I originally planned: here any process can end up testing any dataset, whereas I wanted each process to always test the same dataset (a sketch of that variant follows the code). Any suggestions/comments are welcome.

import signal
import numpy as np
import multiprocessing as mp

STOP = 'STOP'
data = {'x': np.random.rand(), 'y': np.random.rand(), 'z': np.random.rand()}
debug_value = {'x': 1, 'y': 10, 'z': 100}

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def online_test(queue, model, shared_stats):
    while True: # keep process alive for testing
        epoch, data_id = queue.get()
        if data_id == STOP:
            print('... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        shared_stats.update({data_id: {k: [] for k in ['prediction', 'error']}})
        # print(f'... evaluation ...')
        pred = model.value
        err = pred - data[data_id] # simplified version, the real one takes some time
        checker = debug_value[data_id]
        shared_stats.update({data_id: {'prediction': epoch * checker, 'error': -epoch * checker}}) # to debug if order of calls is correct
        # shared_stats.update({data_id: {'prediction': pred, 'error': err}})
        queue.task_done() # notify parent that testing is done for requested epoch

if __name__ == '__main__':
    stats = {**{'epoch': []},
             **{data_id: {k: [] for k in ['prediction', 'error']} for data_id in data.keys()}}
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    test_model = manager.Value('d', 10.0)
    shared_stats = manager.dict()
    pool = mp.Pool(initializer=initializer)
    for data_id in data.keys():
        pool.apply_async(online_test,
                         args=(test_queue, test_model, shared_stats))
        test_queue.put((0, data_id)) # testing can start
    try: # wrap all in a try-except to handle KeyboardInterrupt
        for epoch in range(10):
            print('training epoch', epoch)
            # ... here I do some training and then copy my parameters to test_model
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join() # keep training only if previous eval is done
            stats['epoch'].append(epoch + 1)
            test_model.value = np.random.rand() # simplified version
            print(f'... epoch {epoch} testing is done, stats are')
            for data_id in shared_stats.keys(): # but first copy stats here
                for k in stats[data_id].keys():
                    mu = np.mean(shared_stats[data_id][k])
                    stats[data_id][k].append(mu)
                    # print('  ', data_id, k, mu)
                test_queue.put((epoch + 1, data_id))
        for data_id in shared_stats.keys(): # notify all procs to end
            test_queue.put((-1, STOP))
        print(stats)
    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()
        pool.join()
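
For the pinned-worker variant mentioned above, here is a minimal sketch (hypothetical; names like in_queues and done_queue are my own, and the SIGINT handling is omitted for brevity). Each dataset gets a dedicated process with its own queue, so 'x' is always tested by the same worker. The parent broadcasts the epoch to every queue and then collects exactly one acknowledgement per dataset, so it never needs to rely on empty().

import multiprocessing as mp

STOP = 'STOP'

def online_test(in_queue, done_queue, data_id, shared_stats):
    while True:  # dedicated worker: always serves the same data_id
        epoch = in_queue.get()
        if epoch == STOP:
            break
        shared_stats[data_id] = {'prediction': epoch, 'error': -epoch}
        done_queue.put(data_id)  # one acknowledgement per dataset per epoch

if __name__ == '__main__':
    mp.set_start_method('spawn')
    manager = mp.Manager()
    shared_stats = manager.dict()
    done_queue = manager.Queue()
    in_queues = {d: manager.Queue() for d in ('x', 'y', 'z')}  # one queue per dataset
    pool = mp.Pool(processes=len(in_queues))
    for d, q in in_queues.items():
        pool.apply_async(online_test, args=(q, done_queue, d, shared_stats))
    for epoch in range(10):
        for q in in_queues.values():
            q.put(epoch)  # broadcast the new epoch to every worker
        for _ in in_queues:
            done_queue.get()  # wait for exactly one ack per dataset
        print(epoch, dict(shared_stats))
    for q in in_queues.values():
        q.put(STOP)
    pool.close()
    pool.join()

The count-based wait replaces test_queue.join(): since the parent knows how many datasets there are, it simply gets that many acknowledgements before the next epoch, which keeps the hang-free synchronization while pinning each dataset to one process.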
