pytorch/torch/nn/parallel/parallel_apply.py
Matthew Hoffman 6cabc105bb Merge type stubs for torch.nn.parallel (#101528)
Fixes #91648

As explained in the tracking issue, the incomplete type stubs in `torch/nn/parallel` mask `DataParallel` methods relevant for subclassing and also mask type issues present in the code.

One notable change here is the addition of [`allow_redefinition = True`](https://mypy.readthedocs.io/en/stable/config_file.html#confval-allow_redefinition) in `mypy.ini`, which allows for a common pattern:

> Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition.

This is added specifically to allow for the type narrowing of `device_ids` in `torch.nn.parallel.data_parallel.data_parallel` from `Sequence[Union[int, torch.device]]` to `Sequence[int]`.

Other than this, there are various renamings and `type: ignore` comments added to bypass errors that arose from the merging.

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101528
Approved by: https://github.com/ezyang
2023-05-24 16:52:13 +00:00

113 lines
4.3 KiB
Python

import threading
import torch
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from ..modules import Module
from torch.cuda._utils import _get_device_index
from torch.cuda.amp import autocast
from torch._utils import ExceptionWrapper
# Names exported by `from ... import *`; helpers prefixed with `_` stay private.
__all__ = ['parallel_apply']
def _get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]:
    """Return the first tensor found in ``obj``, searching nested containers.

    Lists, tuples and dicts are searched depth-first in iteration order.
    Any other non-tensor object is treated as a leaf that holds no tensor.

    Args:
        obj: a tensor, or an arbitrarily nested list/tuple/dict structure.

    Returns:
        The first ``torch.Tensor`` encountered, or ``None`` if there is none.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for item in obj:
            found = _get_a_var(item)
            # Compare against None: truth-testing a tensor is ambiguous.
            if found is not None:
                return found

    if isinstance(obj, dict):
        # Each ``(key, value)`` pair is a tuple, so the recursive call goes
        # through the list/tuple branch and inspects both key and value.
        for pair in obj.items():
            found = _get_a_var(pair)
            if found is not None:
                return found

    return None
def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        kwargs_tup (sequence of dict, optional): keyword arguments for each
            module. When omitted, every module is called without keyword
            arguments.
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.

    Returns:
        A list holding the output of each module, in the same order as
        :attr:`modules`. An empty list is returned when :attr:`modules` is
        empty.

    Raises:
        AssertionError: if the given sequences do not all have the same length.
        Exception: a failure captured inside a replica is re-raised in the
            calling thread through ``ExceptionWrapper.reraise()``.
    """
    assert len(modules) == len(inputs), f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}'
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup), (
            f'The number of modules {len(modules)} is not equal to the number of kwargs {len(kwargs_tup)}'
        )
    else:
        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices), (
            f'The number of modules {len(modules)} is not equal to the number of devices {len(devices)}'
        )
    else:
        devices = [None] * len(modules)

    # Nothing to run: return early instead of indexing `modules[0]` below,
    # and avoid touching the CUDA runtime for an empty workload.
    if len(modules) == 0:
        return []

    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    lock = threading.Lock()
    # Maps replica index -> module output, or an ExceptionWrapper on failure.
    results: Dict[int, Any] = {}
    # Captured in the calling thread so that every worker can mirror them.
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

    def _worker(
        i: int,
        module: Module,
        input: Any,
        kwargs: Dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Run replica ``i`` and store its output (or failure) in ``results``."""
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            t = _get_a_var(input)
            if t is None:
                # NOTE(review): no exception is active at this point; confirm
                # that an ExceptionWrapper built here can be re-raised
                # meaningfully. Presumably unreachable when devices were
                # resolved through `_get_device_index(x, True)` above -- verify.
                with lock:
                    results[i] = ExceptionWrapper(
                        where=f"in replica {i}, no device was provided and no tensor input was found; "
                        "device cannot be resolved")
                return
            device = t.get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
        try:
            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            # Record the failure; it is re-raised in the calling thread below.
            with lock:
                results[i] = ExceptionWrapper(
                    where=f"in replica {i} on device {device}")

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # A single replica needs no extra thread; run it inline.
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])

    # Collect in replica order; the first captured failure is re-raised.
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs