mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
make ProcessException pickleable (#70118)
Summary: Fixes https://github.com/pytorch/pytorch/issues/70116 Happy to add tests if you let me know the best place to put them. cc VitalyFedyunin Pull Request resolved: https://github.com/pytorch/pytorch/pull/70118 Reviewed By: malfet Differential Revision: D33255899 Pulled By: ejguan fbshipit-source-id: 41d495374182eb28bb8bb421e890eca3bddc077b
This commit is contained in:
parent
9c742bea59
commit
14d3d29b16
|
|
@ -1,6 +1,7 @@
|
||||||
# Owner(s): ["module: multiprocessing"]
|
# Owner(s): ["module: multiprocessing"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
import random
|
import random
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -218,5 +219,15 @@ class SpawnTest(TestCase, _TestMultiProcessing):
|
||||||
class ForkTest(TestCase, _TestMultiProcessing):
|
class ForkTest(TestCase, _TestMultiProcessing):
|
||||||
start_method = 'fork'
|
start_method = 'fork'
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorTest(TestCase):
|
||||||
|
def test_errors_pickleable(self):
|
||||||
|
for error in (
|
||||||
|
mp.ProcessRaisedException("Oh no!", 1, 1),
|
||||||
|
mp.ProcessExitedException("Oh no!", 1, 1, 1),
|
||||||
|
):
|
||||||
|
pickle.loads(pickle.dumps(error))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,13 @@ class ProcessException(Exception):
|
||||||
|
|
||||||
def __init__(self, msg: str, error_index: int, pid: int):
|
def __init__(self, msg: str, error_index: int, pid: int):
|
||||||
super().__init__(msg)
|
super().__init__(msg)
|
||||||
|
self.msg = msg
|
||||||
self.error_index = error_index
|
self.error_index = error_index
|
||||||
self.pid = pid
|
self.pid = pid
|
||||||
|
|
||||||
|
def __reduce__(self):
|
||||||
|
return type(self), (self.msg, self.error_index, self.pid)
|
||||||
|
|
||||||
|
|
||||||
class ProcessRaisedException(ProcessException):
|
class ProcessRaisedException(ProcessException):
|
||||||
"""
|
"""
|
||||||
|
|
@ -47,6 +51,12 @@ class ProcessExitedException(ProcessException):
|
||||||
self.exit_code = exit_code
|
self.exit_code = exit_code
|
||||||
self.signal_name = signal_name
|
self.signal_name = signal_name
|
||||||
|
|
||||||
|
def __reduce__(self):
|
||||||
|
return (
|
||||||
|
type(self),
|
||||||
|
(self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _wrap(fn, i, args, error_queue):
|
def _wrap(fn, i, args, error_queue):
|
||||||
# prctl(2) is a Linux specific system call.
|
# prctl(2) is a Linux specific system call.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user