From 14d3d29b169871b23e25fab3fa2f3944dda97d01 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 30 Dec 2021 09:08:30 -0800 Subject: [PATCH] make ProcessException pickleable (#70118) Summary: Fixes https://github.com/pytorch/pytorch/issues/70116 Happy to add tests if you let me know the best place to put them. cc VitalyFedyunin Pull Request resolved: https://github.com/pytorch/pytorch/pull/70118 Reviewed By: malfet Differential Revision: D33255899 Pulled By: ejguan fbshipit-source-id: 41d495374182eb28bb8bb421e890eca3bddc077b --- test/test_multiprocessing_spawn.py | 11 +++++++++++ torch/multiprocessing/spawn.py | 10 ++++++++++ 2 files changed, 21 insertions(+) diff --git a/test/test_multiprocessing_spawn.py b/test/test_multiprocessing_spawn.py index f085a5aed6c..7e6f8142839 100644 --- a/test/test_multiprocessing_spawn.py +++ b/test/test_multiprocessing_spawn.py @@ -1,6 +1,7 @@ # Owner(s): ["module: multiprocessing"] import os +import pickle import random import signal import sys @@ -218,5 +219,15 @@ class SpawnTest(TestCase, _TestMultiProcessing): class ForkTest(TestCase, _TestMultiProcessing): start_method = 'fork' + +class ErrorTest(TestCase): + def test_errors_pickleable(self): + for error in ( + mp.ProcessRaisedException("Oh no!", 1, 1), + mp.ProcessExitedException("Oh no!", 1, 1, 1), + ): + pickle.loads(pickle.dumps(error)) + + if __name__ == '__main__': run_tests() diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index 984f5dfc8f3..46df2d835e0 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -14,9 +14,13 @@ class ProcessException(Exception): def __init__(self, msg: str, error_index: int, pid: int): super().__init__(msg) + self.msg = msg self.error_index = error_index self.pid = pid + def __reduce__(self): + return type(self), (self.msg, self.error_index, self.pid) + class ProcessRaisedException(ProcessException): """ @@ -47,6 +51,12 @@ class ProcessExitedException(ProcessException): self.exit_code = exit_code self.signal_name = signal_name + def __reduce__(self): + return ( + type(self), + (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name), + ) + def _wrap(fn, i, args, error_queue): # prctl(2) is a Linux specific system call.