make ProcessException pickleable (#70118)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/70116

Happy to add tests if you let me know the best place to put them.

cc VitalyFedyunin

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70118

Reviewed By: malfet

Differential Revision: D33255899

Pulled By: ejguan

fbshipit-source-id: 41d495374182eb28bb8bb421e890eca3bddc077b
This commit is contained in:
epwalsh 2021-12-30 09:08:30 -08:00 committed by Facebook GitHub Bot
parent 9c742bea59
commit 14d3d29b16
2 changed files with 21 additions and 0 deletions

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: multiprocessing"]
import os
import pickle
import random
import signal
import sys
@ -218,5 +219,15 @@ class SpawnTest(TestCase, _TestMultiProcessing):
class ForkTest(TestCase, _TestMultiProcessing):
start_method = 'fork'
class ErrorTest(TestCase):
def test_errors_pickleable(self):
for error in (
mp.ProcessRaisedException("Oh no!", 1, 1),
mp.ProcessExitedException("Oh no!", 1, 1, 1),
):
pickle.loads(pickle.dumps(error))
if __name__ == '__main__':
run_tests()

View File

@ -14,9 +14,13 @@ class ProcessException(Exception):
def __init__(self, msg: str, error_index: int, pid: int):
super().__init__(msg)
self.msg = msg
self.error_index = error_index
self.pid = pid
def __reduce__(self):
return type(self), (self.msg, self.error_index, self.pid)
class ProcessRaisedException(ProcessException):
"""
@ -47,6 +51,12 @@ class ProcessExitedException(ProcessException):
self.exit_code = exit_code
self.signal_name = signal_name
def __reduce__(self):
return (
type(self),
(self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
)
def _wrap(fn, i, args, error_queue):
# prctl(2) is a Linux specific system call.