Fixes #91648

As explained in the tracking issue, the incomplete type stubs in `torch/nn/parallel` mask `DataParallel` methods relevant for subclassing and also mask type issues present in the code itself. One notable change here is the addition of [`allow_redefinition = True`](https://mypy.readthedocs.io/en/stable/config_file.html#confval-allow_redefinition) in `mypy.ini`, which allows for a common pattern:

> Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition.

This is added specifically to allow the type narrowing of `device_ids` in `torch.nn.parallel.data_parallel.data_parallel` from `Sequence[Union[int, torch.device]]` to `Sequence[int]`. Other than this, there are various renamings and `type: ignore` comments added to bypass errors that arose from the merge.

@ezyang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101528
Approved by: https://github.com/ezyang
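
A minimal sketch of that narrowing pattern, for illustration only: the helper `_to_indices` below is hypothetical, while `_get_device_index` is the same utility imported by the file shown further down; the real narrowing happens on `device_ids` inside `torch.nn.parallel.data_parallel.data_parallel`.

from typing import List, Sequence, Union

import torch
from torch.cuda._utils import _get_device_index


def _to_indices(device_ids: Sequence[Union[int, torch.device]]) -> List[int]:
    # With `allow_redefinition = True`, mypy accepts rebinding `device_ids` to
    # the narrower List[int] inferred from this comprehension (same block and
    # nesting level as the parameter definition); without the option, the
    # variable keeps its declared element type and the return fails to check.
    device_ids = [_get_device_index(d, True) for d in device_ids]
    return device_ids


print(_to_indices([0, torch.device('cuda:1')]))  # [0, 1]
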
import threading
import torch
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from ..modules import Module
from torch.cuda._utils import _get_device_index
from torch.cuda.amp import autocast
from torch._utils import ExceptionWrapper

__all__ = ['parallel_apply']

def _get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]:
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(_get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(_get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None

def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs), f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}'
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    lock = threading.Lock()
    results = {}
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

    def _worker(
        i: int,
        module: Module,
        input: Any,
        kwargs: Dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            t = _get_a_var(input)
            if t is None:
                with lock:
                    results[i] = ExceptionWrapper(
                        where="in replica {}, no device was provided and no tensor input was found; "
                              "device cannot be resolved".format(i))
                return
            device = t.get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
        try:
            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
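
For reference, a minimal usage sketch of `parallel_apply` (illustrative, not part of the file above). It assumes at least two visible CUDA devices and uses `torch.nn.parallel.replicate` to create the per-device module copies, which is roughly how `DataParallel.forward` drives this function.

import torch
from torch.nn.parallel import parallel_apply, replicate

device_ids = [0, 1]                        # assumes >= 2 CUDA devices are visible
module = torch.nn.Linear(4, 2).cuda(device_ids[0])
replicas = replicate(module, device_ids)   # one copy of the module per device
inputs = [torch.randn(8, 4, device=f'cuda:{d}') for d in device_ids]
outputs = parallel_apply(replicas, inputs, devices=device_ids)
print([out.shape for out in outputs])      # two (8, 2) tensors, one per device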