How to handle errors in multiprocessing?
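The MWE is below.
My code uses torch.multiprocessing.Pool to spawn processes, and I use a JoinableQueue to manage communication with the parent. I followed some online guides to handle CTRL+C gracefully, and everything works. In some cases, however (my real code does more than the MWE), I get errors in the function run by the children (online_test()). When that happens, the code hangs forever because the children never notify the parent that something went wrong. I tried adding try ... except ... finally to the children's main loop, with queue.task_done() in the finally clause, but nothing changed.
I need the parent to be notified of any error in the children and to terminate everything gracefully. How can I do that? Thanks!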
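Edit
Adding an error_callback, as in the code below, does not work. The handler catches the exception, but the main code still hangs because it keeps waiting for the queue to be empty.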
import signal
import numpy as np
import multiprocessing as mp

STOP = 'STOP'

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def error_handler(exception):
    print(f'{exception} occurred, terminating pool.')
    pool.terminate()

def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError
        queue.task_done()

if __name__ == '__main__':
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test,
                     args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))

    try:
        for epoch in range(10):
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))

        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))

    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()

    pool.join()
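I got the code in your edit to work by doing the following two things in the error handler function:
- Emptying the test_queue.
- Setting a global flag variable named aborted to True to indicate that processing should stop.
Then, in the __main__ process, I added code that checks the aborted flag before waiting for the previous epoch to finish and starting another one.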
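Why the first step unblocks the parent may be easier to see in isolation. join() waits on a counter of unfinished tasks rather than on the queue's contents: each put() increments the counter, each task_done() decrements it, and task_done() raises ValueError once it has been called more often than put(). The sketch below uses a plain queue.Queue, on the assumption that it behaves like the manager's JoinableQueue (which, as far as I know, is a proxy for a queue.Queue living in the manager process). release_joiners is a hypothetical helper name, not something from the code in this answer.

```python
import queue


def release_joiners(q):
    """Call task_done() until the unfinished-task counter reaches zero.

    Returns how many outstanding tasks were marked as done.
    """
    released = 0
    while True:
        try:
            q.task_done()
        except ValueError:  # called more times than items were put
            return released
        released += 1


q = queue.Queue()
for item in ['a', 'b', 'c']:
    q.put(item)

q.get()  # a worker takes one item and then dies before calling task_done()

released = release_joiners(q)  # all three puts are still counted as unfinished
q.join()                       # returns at once: the counter is now zero
print(released, q.qsize())     # the two untaken items are still in the queue
```

Note that nothing is actually removed from the queue; only the counter that join() waits on is cleared.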
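Using a global seems a little hacky, but it works because the error handler function is executed as part of the main process, so it has access to the main process's globals. That detail occurred to me while I was working on the linked answer and, as you can see, it can prove important and useful.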
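To illustrate where the handler runs, here is a minimal sketch. It assumes multiprocessing.pool.ThreadPool as a stand-in for the process pool (same apply_async/error_callback interface, but no child processes, so it runs anywhere), and it swaps the global boolean for a threading.Event. That swap is my substitution for illustration, not what the code in this answer does. failing_task and seen are hypothetical names.

```python
import threading
from multiprocessing.pool import ThreadPool

aborted = threading.Event()  # stands in for the global boolean flag
seen = {}                    # records what the handler observed


def failing_task():
    raise NotImplementedError('epoch == 1')  # fake error, as in the MWE


def error_handler(exception):
    # Called by the pool's result-handling thread inside the parent process,
    # so it can touch the parent's module-level objects directly.
    seen['exception'] = exception
    seen['in_main_thread'] = (
        threading.current_thread() is threading.main_thread()
    )
    aborted.set()


with ThreadPool(processes=1) as pool:
    result = pool.apply_async(failing_task, error_callback=error_handler)
    result.wait()  # returns once the task has failed and the handler has run

if aborted.is_set():
    print(f'ABORTED by error_handler: {seen["exception"]!r}')
```

Because the handler runs in a helper thread of the parent, the main loop only notices the flag the next time it checks it, so the check has to sit somewhere the loop actually reaches.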
import signal
import numpy as np
import multiprocessing as mp

STOP = 'STOP'

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def error_handler(exception):
    print(f'{exception=} occurred, terminating pool.')
    pool.terminate()
    print('pool terminated.')
    # Nothing is removed from the queue here: task_done() is called until it
    # raises ValueError, which clears the unfinished-task count so that
    # test_queue.join() in the main process can return.
    while not test_queue.empty():
        try:
            test_queue.task_done()
        except ValueError:
            break
    print(f'test_queue cleaned.')
    global aborted
    aborted = True  # Indicate an error occurred to the main process.

def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError('epoch == 1')  # Fake error for testing.
        queue.task_done()

if __name__ == '__main__':
    aborted = False
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test, args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))

    try:
        for epoch in range(10):
            if aborted:  # Error occurred?
                print('ABORTED by error_handler!')
                break
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))

        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))

    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()

    pool.join()
Output from a sample run:
training epoch 0
... waiting for testing before moving on to next epoch ...
testing function for a has started for epoch 0
testing function for b has started for epoch 0
testing function for c has started for epoch 0
... epoch 0 testing is done
testing function for a has started for epoch 1
training epoch 1
... waiting for testing before moving on to next epoch ...
exception=NotImplementedError('epoch == 1') occurred, terminating pool.
pool terminated.
... epoch 1 testing is done
test_queue cleaned.
ABORTED by error_handler!
Press any key to continue . . .
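For comparison, a possible alternative (not what the code above does) is to keep the worker alive when an item fails: catch the exception per item, report it to the parent on a second queue, and call task_done() in a finally clause so that join() can never hang. The sketch below is only an illustration under stated assumptions. It uses a thread and queue.Queue objects as stand-ins so that it runs without spawning processes; with a pool you would presumably pass two manager queues instead. The names errors and run_epochs are hypothetical, and the failure is sent as repr(exc) because whatever crosses a process boundary has to be picklable.

```python
import queue
import threading

STOP = 'STOP'


def online_test(tasks, errors):
    """Consume (epoch, data_id) items and report failures instead of dying."""
    while True:
        epoch, data_id = tasks.get()
        try:
            if data_id == STOP:
                break
            if epoch == 1:
                raise NotImplementedError('epoch == 1')  # fake error, as in the MWE
        except Exception as exc:
            errors.put((epoch, data_id, repr(exc)))  # tell the parent what failed
        finally:
            tasks.task_done()  # runs on success, on failure and on STOP alike


def run_epochs(n_epochs=3):
    """Return the epochs that completed and the first reported failure, if any."""
    tasks, errors = queue.Queue(), queue.Queue()
    worker = threading.Thread(target=online_test, args=(tasks, errors))
    worker.start()
    completed, failure = [], None
    for epoch in range(n_epochs):
        for i in ['a', 'b', 'c']:
            tasks.put((epoch, i))
        tasks.join()            # every item is marked done, so this returns
        if not errors.empty():  # the worker reported a problem
            failure = errors.get()
            break
        completed.append(epoch)
    tasks.put((-1, STOP))       # shut the worker down in either case
    worker.join()
    return completed, failure


completed, failure = run_epochs()
print(completed, failure)
```

The trade-off is that the parent learns about a failure only after join() returns for that epoch, whereas the error_callback approach reacts as soon as the worker function raises.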