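An MWE is below.
My code spawns processes with torch.multiprocessing.Pool, and I use a JoinableQueue to manage communication with the parent process. I followed some online guides to handle CTRL+C gracefully, and that all works. However, in some cases (my real code does much more than the MWE), the function run by the child process (online_test()) raises an error. When that happens the code hangs forever, because the child process never tells the parent what happened. I tried wrapping the main child loop in try ... except ... finally, with queue.task_done() in the finally clause, but nothing changed.
I need the parent process to be notified of any error in a child process and to terminate everything gracefully. How can I do that? Thanks!
Edit
The suggested solution does not work. The handler catches the exception, but the main code still hangs because it is waiting for the queue to be empty.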
import signal
import numpy as np
import multiprocessing as mp

STOP = 'STOP'

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def error_handler(exception):
    print(f'{exception} occurred, terminating pool.')
    pool.terminate()

def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError
        queue.task_done()

if __name__ == '__main__':
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test,
                     args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))
    try:
        for epoch in range(10):
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))
        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))
    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()
    pool.join()
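A note on why the hang happens: join() does not wait for the queue to be empty, it waits until every item that was put() has been acknowledged with task_done(). The single-process sketch below illustrates this with a plain queue.Queue, on the assumption that it behaves like the manager's queue here (in CPython, manager.JoinableQueue() is a proxy around a queue.Queue living in the manager process). The helper join_returns() is a hypothetical name introduced only for this illustration, and the "crash" is simulated by simply skipping task_done().

```python
import queue
import threading


def join_returns(q, timeout=0.5):
    """Return True if q.join() completes within `timeout` seconds.

    Hypothetical helper for this sketch: join() itself takes no timeout,
    so it is run in a daemon thread and we check whether that thread ended.
    """
    waiter = threading.Thread(target=q.join, daemon=True)
    waiter.start()
    waiter.join(timeout)
    return not waiter.is_alive()


q = queue.Queue()
for data_id in ['a', 'b', 'c']:
    q.put(data_id)

# Two items are fetched and acknowledged normally.
q.get()
q.task_done()
q.get()
q.task_done()

# The third item is fetched, but the "worker" dies before task_done().
q.get()

queue_is_empty = q.empty()           # nothing is left in the queue ...
join_hangs = not join_returns(q)     # ... yet join() is still blocked

# Acknowledging the outstanding item is what releases join().
q.task_done()
join_after_ack = join_returns(q)

print(f'{queue_is_empty=} {join_hangs=} {join_after_ack=}')
```

So terminating the pool alone is not enough: someone still has to account for the items the dead worker never acknowledged, or the parent has to stop calling join().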
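I got the code in your edit to work by doing these two things in the error handler function:
- Clear test_queue.
- Set a global flag variable, aborted, to True to indicate that processing should stop.
Then, in the __main__ process, I added code that checks the aborted flag before waiting for the previous epoch to finish and starting another.
Using a global seems a bit hacky, but it works because the error handler function is executed as part of the main process and so has access to its globals. I remember that detail occurring to me when I was working on the linked answer, and, as you can see, it can prove important and useful.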
import signal
import numpy as np
import multiprocessing as mp

STOP = 'STOP'

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def error_handler(exception):
    print(f'{exception=} occurred, terminating pool.')
    pool.terminate()
    print('pool terminated.')
    # task_done() removes nothing; it only lowers the unfinished-task count.
    # Loop until it raises ValueError (count already zero) so that join()
    # returns even if the failing item was the last one taken from the queue.
    while True:
        try:
            test_queue.task_done()
        except ValueError:
            break
    print(f'test_queue cleaned.')
    global aborted
    aborted = True  # Indicate an error occurred to the main process.

def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError('epoch == 1')  # Fake error for testing.
        queue.task_done()

if __name__ == '__main__':
    aborted = False
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test, args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))
    try:
        for epoch in range(10):
            if aborted:  # Error occurred?
                print('ABORTED by error_handler!')
                break
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))
        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))
    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()
    pool.join()
Output of a sample run:
training epoch 0
... waiting for testing before moving on to next epoch ...
testing function for a has started for epoch 0
testing function for b has started for epoch 0
testing function for c has started for epoch 0
... epoch 0 testing is done
testing function for a has started for epoch 1
training epoch 1
... waiting for testing before moving on to next epoch ...
exception=NotImplementedError('epoch == 1') occurred, terminating pool.
pool terminated.
... epoch 1 testing is done
test_queue cleaned.
ABORTED by error_handler!
Press any key to continue . . .
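To make the "clearing" step in the error handler more concrete, here is a minimal single-process sketch of what repeated task_done() calls actually do. It uses a plain queue.Queue as a stand-in for the manager's queue, and release_joiners() is a hypothetical name introduced only for this illustration. It loops until task_done() raises ValueError rather than testing empty(), because a queue can be empty while an item that was fetched is still unacknowledged.

```python
import queue


def release_joiners(q):
    """Call task_done() until the unfinished-task count reaches zero.

    Returns the number of calls that succeeded. task_done() raises
    ValueError once it has been called more times than items were put.
    """
    released = 0
    while True:
        try:
            q.task_done()
        except ValueError:
            return released
        released += 1


q = queue.Queue()
for data_id in ['a', 'b', 'c']:
    q.put((1, data_id))

# The "worker" fetches one item and then crashes without acknowledging it.
q.get()

released = release_joiners(q)   # one call per item ever put: 3
remaining = q.qsize()           # the two unfetched items are still queued: 2

q.join()                        # returns immediately, nothing left to wait for
print(f'{released=} {remaining=}')
```

The leftover items are why the main loop also needs the aborted check: join() is released, but the queue still holds work that nobody will process.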