How to handle errors in multiprocessing?

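The MWE is below.
My code spawns processes with torch.multiprocessing.Pool, and I use a JoinableQueue to manage communication with the parent. I followed some online guides to handle CTRL+C gracefully, and everything works. However, in some cases (my real code does more than the MWE) an error is raised in the function run by the children (online_test()). When that happens, the code hangs forever, because the children never tell the parent that something went wrong. I tried adding try ... except ... finally to the children's main loop, with queue.task_done() in the finally clause, but nothing changed.

I need the parent to be notified of any error in a child and to terminate everything gracefully. How can I do that? Thanks!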

EDIT
It doesn't work: the handler catches the exception, but the main code still hangs because it is waiting for the queue to be empty.

import signal
import numpy as np
import multiprocessing as mp

STOP = 'STOP'

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def error_handler(exception):
    print(f'{exception} occurred, terminating pool.')
    pool.terminate()

def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError
        queue.task_done()


if __name__ == '__main__':
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test,
        args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))

    try:
        for epoch in range(10):
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))

        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))

    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()

    pool.join()

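I got the code from your edit to work by doing the following two things in the error handler function:

  1. Emptying the test_queue.
  2. Setting a global flag variable named aborted to True to indicate that processing should stop.

Then, in the __main__ process, I added code that checks the aborted flag before waiting for the previous epoch to finish and starting another one.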

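Using a global seems a little hacky, but it works because the error handler function is executed as part of the main process and therefore has access to its globals. I remember that detail dawning on me while I was working on the linked answer, and as you can see, it can prove important and useful.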
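To check the claim that the error callback runs in the main process, here is a minimal sketch that is separate from the code in this answer. The names failing_task, run_demo and on_error are made up for illustration; only get_context(), Pool.apply_async(..., error_callback=...) and AsyncResult.wait() come from multiprocessing. It records os.getpid() inside the callback so it can be compared with the parent's pid.

```python
import multiprocessing as mp
import os


def failing_task():
    """Runs in a worker process and always raises (hypothetical task)."""
    raise NotImplementedError(f'raised in worker pid {os.getpid()}')


def run_demo(start_method='spawn'):
    """Return (parent_pid, callback_pid, exception) for one failing task."""
    ctx = mp.get_context(start_method)
    seen = {}

    def on_error(exc):
        # Invoked by the pool's result-handling thread, inside the parent
        # process, so it can write to the parent's variables directly.
        seen['pid'] = os.getpid()
        seen['exc'] = exc

    with ctx.Pool(processes=1) as pool:
        result = pool.apply_async(failing_task, error_callback=on_error)
        result.wait()  # Waits for the task to finish; does not re-raise.
    return os.getpid(), seen['pid'], seen['exc']
```

With the spawn start method used in the question, run_demo() should be called from inside an if __name__ == '__main__': guard. The first two values it returns should be equal, while the pid embedded in the exception message should be a different one, which is why setting a global from the callback is visible to the main loop.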

import signal
import numpy as np
import multiprocessing as mp

STOP = 'STOP'

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def error_handler(exception):
    print(f'{exception=} occurred, terminating pool.')
    pool.terminate()
    print('pool terminated.')
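    # Each task_done() lowers the queue's count of unfinished tasks; once it
    # reaches zero, test_queue.join() in the main process returns. One more
    # call then raises ValueError, which ends the loop.
    while not test_queue.empty():
        try:
            test_queue.task_done()
        except ValueError:
            break
    print('test_queue cleaned.')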
    global aborted
    aborted = True  # Indicate an error occurred to the main process.

def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError('epoch == 1')  # Fake error for testing.
        queue.task_done()


if __name__ == '__main__':
    aborted = False
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test,  args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))

    try:
        for epoch in range(10):
            if aborted:  # Error occurred?
                print('ABORTED by error_handler!')
                break
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))

        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))

    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()

    pool.join()

Output from a sample run:

training epoch 0
... waiting for testing before moving on to next epoch ...
testing function for a has started for epoch 0
testing function for b has started for epoch 0
testing function for c has started for epoch 0
... epoch 0 testing is done
testing function for a has started for epoch 1
training epoch 1
... waiting for testing before moving on to next epoch ...
exception=NotImplementedError('epoch == 1') occurred, terminating pool.
pool terminated.
... epoch 1 testing is done
test_queue cleaned.
ABORTED by error_handler!
Press any key to continue . . .
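The queue-cleaning loop in error_handler may look odd because it calls task_done() without a matching get(). The sketch below reproduces it on a plain queue.Queue, on the assumption that manager.JoinableQueue() behaves like one as far as task_done() and join() are concerned. The helper name acknowledge_outstanding is hypothetical. It suggests that the loop does not actually remove items; it only drives the unfinished-task counter to zero so that join() can return.

```python
import queue


def acknowledge_outstanding(q):
    """Call task_done() the way error_handler does; return successful calls."""
    calls = 0
    while not q.empty():
        try:
            q.task_done()
            calls += 1
        except ValueError:  # Counter is already zero: nothing left to acknowledge.
            break
    return calls


q = queue.Queue()
for item in ['a', 'b', 'c']:
    q.put((1, item))

q.get()  # A worker takes 'a' and then dies before calling task_done().

calls = acknowledge_outstanding(q)
q.join()  # Does not block: no unfinished tasks remain.
print(calls, q.qsize())  # 3 successful calls, 2 items still in the queue
```

One caveat that follows from this: if the failing worker had taken the last item, the queue would already be empty, the loop body would never run, and join() would keep blocking. In the example above the failure happens on the first item of the epoch, so the loop has items left to work with.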