pytorch/torch/nn/parallel/parallel_apply.py
Jan Schlüter 0bc90194fb Catch and print exception traceback in parallel_apply() workers (#18055)
Summary:
When an exception occurs in one of the modules passed to `parallel_apply()`, it is caught and re-raised in the main thread. This preserves the original exception type and message, but has the traceback point at the position where it's re-raised, rather than the original point of failure.

This PR saves the exception information required to generate the traceback, and includes the original traceback in the message of the exception raised in the main thread.

Before:
```
  ...
  File ".../torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File ".../torch/nn/parallel/parallel_apply.py", line 84, in parallel_apply
    raise output
RuntimeError: expected type torch.FloatTensor but got torch.cuda.FloatTensor
```

After:
```
  ...
  File ".../torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File ".../torch/nn/parallel/parallel_apply.py", line 88, in parallel_apply
    ''.join(traceback.format_exception(*exc_info)))
RuntimeError: Caught exception in replica 0. Original traceback and message:
Traceback (most recent call last):
  ...
  File "../models/foo.py", line 319, in bar
    baz = asdf / ghij[:, np.newaxis]
RuntimeError: expected type torch.FloatTensor but got torch.cuda.FloatTensor
```

I took care to raise an exception of the original type (in case the main code checks for that), but replaced the message. It helped me find a bug that did not occur outside `data_parallel()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18055

Differential Revision: D16444972

Pulled By: zhangguanheng66

fbshipit-source-id: ec436c9d4677fad18106a8046cfa835a20a101ce
2019-07-26 11:41:22 -07:00

88 lines
3.0 KiB
Python

import threading
import torch
from torch.cuda._utils import _get_device_index
from torch._utils import ExceptionWrapper
def get_a_var(obj):
if isinstance(obj, torch.Tensor):
return obj
if isinstance(obj, list) or isinstance(obj, tuple):
for result in map(get_a_var, obj):
if isinstance(result, torch.Tensor):
return result
if isinstance(obj, dict):
for result in map(get_a_var, obj.items()):
if isinstance(result, torch.Tensor):
return result
return None
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
r"""Applies each `module` in :attr:`modules` in parallel on arguments
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
on each of :attr:`devices`.
Args:
modules (Module): modules to be parallelized
inputs (tensor): inputs to the modules
devices (list of int or torch.device): CUDA devices
:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
:attr:`devices` (if given) should all have same length. Moreover, each
element of :attr:`inputs` can either be a single object as the only argument
to a module, or a collection of positional arguments.
"""
assert len(modules) == len(inputs)
if kwargs_tup is not None:
assert len(modules) == len(kwargs_tup)
else:
kwargs_tup = ({},) * len(modules)
if devices is not None:
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)
devices = list(map(lambda x: _get_device_index(x, True), devices))
lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()
def _worker(i, module, input, kwargs, device=None):
torch.set_grad_enabled(grad_enabled)
if device is None:
device = get_a_var(input).get_device()
try:
with torch.cuda.device(device):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)
output = module(*input, **kwargs)
with lock:
results[i] = output
except Exception:
with lock:
results[i] = ExceptionWrapper(
where="in replica {} on device {}".format(i, device))
if len(modules) > 1:
threads = [threading.Thread(target=_worker,
args=(i, module, input, kwargs, device))
for i, (module, input, kwargs, device) in
enumerate(zip(modules, inputs, kwargs_tup, devices))]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, ExceptionWrapper):
output.reraise()
outputs.append(output)
return outputs