如何处理多进程中的错误?



MWE如下。
我的代码与torch.multiprocessing.Pool产生进程,我用JoinableQueue管理与父进程的通信。我遵循一些在线指南来优雅地处理CTRL+C。一切正常。但是,在某些情况下(我的代码比MWE拥有更多的东西),我在子进程(online_test())运行的函数中遇到了错误。如果发生这种情况,代码就会永远挂起,因为子进程不会通知父进程发生了什么。我试着在主子循环中添加try ... except ... finally,在finally中添加queue.task_done(),但没有任何改变。

我需要父进程收到关于任何子进程错误的通知,并优雅地终止一切。我怎么能这么做呢?谢谢!

编辑
建议的解决方案不起作用。处理程序捕获异常,但主代码仍然挂起,因为它等待队列为空。

import signal
import numpy as np
import multiprocessing as mp
STOP = 'STOP'
def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)
def error_handler(exception):
    print(f'{exception} occurred, terminating pool.')
    pool.terminate()
def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError
        queue.task_done()

if __name__ == '__main__':
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test,
        args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))
    try:
        for epoch in range(10):
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))
        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))
    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()
    pool.join()

我得到的代码在你的编辑工作通过做这两件事在错误处理函数:

  1. 清空test_queue .
  2. 设置全局标志变量aborted为true,表示停止处理。

然后在__main__进程中,我添加了在等待前一个epoch完成并开始另一个epoch之前检查aborted标志的代码。

使用global似乎有点粗糙,但它是有效的,因为错误处理函数是作为主进程的一部分执行的,所以可以访问它的全局变量。我记得当我在做链接的答案时,我突然想到了这个细节——正如你所看到的——它可以证明是重要的/有用的。

import signal
import numpy as np
import multiprocessing as mp
STOP = 'STOP'
def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)
def error_handler(exception):
    print(f'{exception=} occurred, terminating pool.')
    pool.terminate()
    print('pool terminated.')
    while not test_queue.empty():
        try:
            test_queue.task_done()
        except ValueError:
            break
    print(f'test_queue cleaned.')
    global aborted
    aborted = True  # Indicate an error occurred to the main process.
def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError('epoch == 1')  # Fake error for testing.
        queue.task_done()

if __name__ == '__main__':
    aborted = False
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test,  args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))
    try:
        for epoch in range(10):
            if aborted:  # Error occurred?
                print('ABORTED by error_handler!')
                break
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))
        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))
    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()
    pool.join()

示例运行的输出:

training epoch 0
... waiting for testing before moving on to next epoch ...
testing function for a has started for epoch 0
testing function for b has started for epoch 0
testing function for c has started for epoch 0
... epoch 0 testing is done
testing function for a has started for epoch 1
training epoch 1
... waiting for testing before moving on to next epoch ...
exception=NotImplementedError('epoch == 1') occurred, terminating pool.
pool terminated.
... epoch 1 testing is done
test_queue cleaned.
ABORTED by error_handler!
Press any key to continue . . .

相关内容

  • 没有找到相关文章

最新更新