This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings. I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, there seem to be no instances of it in our codebase, so I'm enabling it so that it stays that way. :) Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519 Approved by: https://github.com/ezyang
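For context, RUF017 flags list flattening via `sum(...)`, which is quadratic because every addition copies the accumulated list. A minimal, hypothetical illustration (not taken from the PyTorch codebase):

import itertools

lists = [[1, 2], [3, 4], [5, 6]]
flat_quadratic = sum(lists, [])  # the pattern RUF017 flags: each `+` copies the growing result
flat_linear = list(itertools.chain.from_iterable(lists))  # linear-time alternative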
import threading
import torch
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from ..modules import Module
from torch.cuda._utils import _get_device_index
from torch.cuda.amp import autocast
from torch._utils import ExceptionWrapper

__all__ = ['get_a_var', 'parallel_apply']

def get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]:
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None
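
# Editorial illustration (hypothetical values, not part of the original file):
# get_a_var picks out the first tensor it can find in a nested input, e.g.
#     get_a_var(({"step": 3}, [torch.zeros(1, device="cuda:1")]))
# returns the zeros tensor, whose device then identifies the replica's GPU.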

def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs), f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}'
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    lock = threading.Lock()
    results = {}
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()
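
    # Editorial note: grad mode and autocast state are thread-local in PyTorch,
    # so they are captured on the calling thread here and re-applied inside each
    # worker thread below.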
    def _worker(
        i: int,
        module: Module,
        input: Any,
        kwargs: Dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            t = get_a_var(input)
            if t is None:
                with lock:
                    results[i] = ExceptionWrapper(
                        where=f"in replica {i}, no device was provided and no tensor input was found; "
                        "device cannot be resolved")
                return
            device = t.get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
        try:
            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where=f"in replica {i} on device {device}")
    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs