diff --git a/test/test_multiprocessing_spawn.py b/test/test_multiprocessing_spawn.py index f085a5aed6c..7e6f8142839 100644 --- a/test/test_multiprocessing_spawn.py +++ b/test/test_multiprocessing_spawn.py @@ -1,6 +1,7 @@ # Owner(s): ["module: multiprocessing"] import os +import pickle import random import signal import sys @@ -218,5 +219,15 @@ class SpawnTest(TestCase, _TestMultiProcessing): class ForkTest(TestCase, _TestMultiProcessing): start_method = 'fork' + +class ErrorTest(TestCase): + def test_errors_pickleable(self): + for error in ( + mp.ProcessRaisedException("Oh no!", 1, 1), + mp.ProcessExitedException("Oh no!", 1, 1, 1), + ): + pickle.loads(pickle.dumps(error)) + + if __name__ == '__main__': run_tests() diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index 984f5dfc8f3..46df2d835e0 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -14,9 +14,13 @@ class ProcessException(Exception): def __init__(self, msg: str, error_index: int, pid: int): super().__init__(msg) + self.msg = msg self.error_index = error_index self.pid = pid + def __reduce__(self): + return type(self), (self.msg, self.error_index, self.pid) + class ProcessRaisedException(ProcessException): """ @@ -47,6 +51,12 @@ class ProcessExitedException(ProcessException): self.exit_code = exit_code self.signal_name = signal_name + def __reduce__(self): + return ( + type(self), + (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name), + ) + def _wrap(fn, i, args, error_queue): # prctl(2) is a Linux specific system call.