Summary:
When an exception occurs in one of the modules passed to `parallel_apply()`, it is caught in the worker thread and re-raised in the main thread. This preserves the original exception type and message, but the resulting traceback points at the place where the exception is re-raised rather than at the original point of failure.
This PR saves the exception information required to generate the traceback, and includes the original traceback in the message of the exception raised in the main thread.
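The idea is roughly the following (a minimal sketch, not the actual `torch._utils.ExceptionWrapper` implementation): the worker thread captures `sys.exc_info()` inside its `except` block and formats the traceback there, so it can be carried back to the main thread as plain data.
```
import sys
import traceback

class CaughtException(object):
    # Hypothetical stand-in for torch._utils.ExceptionWrapper.
    def __init__(self, where=""):
        # Must be constructed inside an `except` block so that
        # sys.exc_info() still refers to the caught exception.
        exc_type, exc_value, tb = sys.exc_info()
        self.exc_type = exc_type
        self.where = where
        # Format the full worker-side traceback into a string now,
        # while the traceback object is still alive.
        self.exc_msg = "".join(
            traceback.format_exception(exc_type, exc_value, tb))
```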
Before:
```
...
File ".../torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File ".../torch/nn/parallel/parallel_apply.py", line 84, in parallel_apply
raise output
RuntimeError: expected type torch.FloatTensor but got torch.cuda.FloatTensor
```
After:
```
...
File ".../torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File ".../torch/nn/parallel/parallel_apply.py", line 88, in parallel_apply
''.join(traceback.format_exception(*exc_info)))
RuntimeError: Caught exception in replica 0. Original traceback and message:
Traceback (most recent call last):
...
File "../models/foo.py", line 319, in bar
baz = asdf / ghij[:, np.newaxis]
RuntimeError: expected type torch.FloatTensor but got torch.cuda.FloatTensor
```
I took care to raise an exception of the original type (in case the calling code checks for it), but with the message replaced. This already helped me track down a bug that only occurred inside `data_parallel()`.
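Continuing the sketch above, re-raising in the main thread could then look roughly like this (again an illustration, not the exact code that was merged):
```
    def reraise(self):
        # Build the new message from the formatted worker-side traceback.
        msg = "Caught exception {}. Original traceback and message:\n{}".format(
            self.where, self.exc_msg)
        # Raise an exception of the original type so that callers which check
        # the type still behave the same; real code would need a fallback for
        # exception types whose constructor does not accept a single string.
        raise self.exc_type(msg)
```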
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18055
Differential Revision: D16444972
Pulled By: zhangguanheng66
fbshipit-source-id: ec436c9d4677fad18106a8046cfa835a20a101ce
88 lines · 3.0 KiB · Python
```
import threading

import torch
from torch.cuda._utils import _get_device_index
from torch._utils import ExceptionWrapper


def get_a_var(obj):
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, list) or isinstance(obj, tuple):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None


def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
```
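For context, a hypothetical usage sketch of this function (assuming a machine with at least two CUDA devices; the module, shapes, and device ids are made up for illustration):
```
import torch
import torch.nn as nn
from torch.nn.parallel import replicate, parallel_apply

if torch.cuda.device_count() >= 2:
    module = nn.Linear(8, 4).cuda(0)
    # One copy of the module per device.
    replicas = replicate(module, [0, 1])
    # One input chunk per replica, already placed on the matching device.
    inputs = [torch.randn(16, 8, device="cuda:0"),
              torch.randn(16, 8, device="cuda:1")]
    outputs = parallel_apply(replicas, inputs, devices=[0, 1])
    print([o.shape for o in outputs])  # [torch.Size([16, 4]), torch.Size([16, 4])]
```
If one of the replicas raises, `parallel_apply()` re-raises the exception in the main thread via `output.reraise()`, with the worker-side traceback embedded in the message as shown in the "After" example above.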