Improve torch.cuda.amp type hints (#108630)

Fixes #108629

1. Add the following to their modules' `__all__` so that pyright considers them publicly exported:
* [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast)
* [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler)
* [`torch.cuda.amp.autocast`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast)
* [`torch.cuda.amp.custom_fwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd)
* [`torch.cuda.amp.custom_bwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_bwd)
2. Add `overload`s for `torch.cuda.amp.GradScaler.scale` so that the return type tracks the type of the `outputs` argument: passing a `torch.Tensor` returns a `torch.Tensor`, while passing an `Iterable[torch.Tensor]` returns an `Iterable[torch.Tensor]` (see the sketch below).
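
A minimal usage sketch (illustrative only, not part of the PR; it assumes a CUDA-enabled build and uses made-up variable names) showing both effects: the imports type-check as public exports, and `scale`'s return type is narrowed from the argument type.

```python
import torch
from torch.cuda.amp import GradScaler, autocast  # now treated as public exports by pyright

scaler = GradScaler()
model = torch.nn.Linear(4, 4).cuda()
inputs = torch.randn(2, 4, device="cuda")

with autocast():
    loss = model(inputs).sum()

scaled_loss = scaler.scale(loss)           # inferred as torch.Tensor
scaled_loss.backward()                     # valid without a cast or type: ignore

scaled_list = scaler.scale([loss, loss])   # inferred as List[torch.Tensor]
scaled_tuple = scaler.scale((loss, loss))  # inferred as Tuple[torch.Tensor, ...]
```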

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108630
Approved by: https://github.com/ezyang
Matthew Hoffman 2023-09-08 06:06:21 +00:00 committed by PyTorch MergeBot
parent 6c7260407b
commit e40d6ae0a7
5 changed files with 157 additions and 109 deletions

View File

@ -332,7 +332,7 @@ def generate_tensor_like_torch_implementations():
# the problem. A more proper fix is to make the "not tested" check
# a test on its own, and to make sure the monkeypatch is only installed
# for the span of the relevant test (and deleted afterwards)
testing_ignore = {"sample_functional"}
testing_ignore = {"sample_functional", "autocast"}
for namespace, funcs in get_overridable_functions().items():
for func in funcs:
if func not in testing_overrides and func.__name__ not in testing_ignore:

View File

@ -55,7 +55,7 @@ __all__ = [
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
'SymBool', 'sym_not',
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
'export',
'export', 'autocast',
]
################################################################################

View File

@ -1,2 +1,9 @@
from .autocast_mode import autocast, custom_bwd, custom_fwd # noqa: F401
from .grad_scaler import GradScaler # noqa: F401
from .autocast_mode import autocast, custom_bwd, custom_fwd
from .grad_scaler import GradScaler
__all__ = [
"autocast",
"custom_bwd",
"custom_fwd",
"GradScaler",
]

View File

@ -1,8 +1,10 @@
from __future__ import annotations
import inspect
import warnings
from collections import abc, defaultdict
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Tuple
from typing import Any, cast, Dict, Iterable, List, Optional, overload, Tuple, Union
import torch
from .common import amp_definitely_not_available
@ -21,7 +23,7 @@ class _MultiDeviceReplicator:
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
def get(self, device) -> torch.Tensor:
def get(self, device: torch.device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
if retval is None:
retval = self.master.to(device=device, non_blocking=True, copy=True)
@ -40,14 +42,11 @@ class OptState(Enum):
STEPPED = 2
def _refresh_per_optimizer_state():
def _refresh_per_optimizer_state() -> Dict[str, Any]:
return {"stage": OptState.READY, "found_inf_per_device": {}}
class GradScaler:
_scale: Optional[torch.Tensor]
_grows_tracker: Optional[torch.Tensor]
_per_optimizer_states: Dict[int, Dict[str, Any]]
"""
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
conveniently.
@ -115,12 +114,12 @@ class GradScaler:
def __init__(
self,
init_scale=2.0**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True,
):
init_scale: float = 2.0**16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
enabled: bool = True,
) -> None:
if enabled and amp_definitely_not_available():
warnings.warn(
"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling."
@ -135,17 +134,19 @@ class GradScaler:
self._init_scale = init_scale
# self._scale will be lazily initialized during the first call to scale()
self._scale = None
self._scale: Optional[torch.Tensor] = None
self._growth_factor = growth_factor
self._backoff_factor = backoff_factor
self._growth_interval = growth_interval
self._init_growth_tracker = 0
# self._growth_tracker will be lazily initialized during the first call to scale()
self._growth_tracker = None
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
self._growth_tracker: Optional[torch.Tensor] = None
self._per_optimizer_states: Dict[int, Dict[str, Any]] = defaultdict(
_refresh_per_optimizer_state
)
def _check_scale_growth_tracker(
self, funcname
self, funcname: str
) -> Tuple[torch.Tensor, torch.Tensor]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, (
@ -156,14 +157,33 @@ class GradScaler:
)
return (self._scale, self._growth_tracker)
def _lazy_init_scale_growth_tracker(self, dev):
def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None:
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
self._growth_tracker = torch.full(
(), self._init_growth_tracker, dtype=torch.int32, device=dev
)
def scale(self, outputs):
@overload
def scale(self, outputs: torch.Tensor) -> torch.Tensor:
...
@overload
def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
...
@overload
def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
...
@overload
def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
...
def scale(
self,
outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
@ -189,7 +209,7 @@ class GradScaler:
_MultiDeviceReplicator
] = [] # holds a reference that can be overwritten by apply_scale
def apply_scale(val):
def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
if isinstance(val, torch.Tensor):
assert val.is_cuda or val.device.type == "xla"
if len(stash) == 0:
@ -198,18 +218,22 @@ class GradScaler:
assert self._scale is not None
stash.append(_MultiDeviceReplicator(self._scale))
return val * stash[0].get(val.device)
elif isinstance(val, abc.Iterable):
if isinstance(val, abc.Iterable):
iterable = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterable)
else:
return iterable
else:
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return iterable
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
def _unscale_grads_(
self,
optimizer: torch.optim.Optimizer,
inv_scale: torch.Tensor,
found_inf: torch.Tensor,
allow_fp16: bool,
) -> Dict[torch.device, torch.Tensor]:
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
@ -219,10 +243,13 @@ class GradScaler:
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
per_device_and_dtype_grads: Dict[
torch.device, Dict[torch.dtype, List[torch.Tensor]]
] = defaultdict(lambda: defaultdict(list))
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
assert isinstance(param, torch.Tensor)
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
@ -253,7 +280,7 @@ class GradScaler:
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer):
def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
@ -309,13 +336,21 @@ class GradScaler:
)
optimizer_state["stage"] = OptState.UNSCALED
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
retval = None
def _maybe_opt_step(
self,
optimizer: torch.optim.Optimizer,
optimizer_state: Dict[str, Any],
*args: Any,
**kwargs: Any,
) -> Optional[float]:
retval: Optional[float] = None
if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
retval = optimizer.step(*args, **kwargs)
return retval
def step(self, optimizer, *args, **kwargs):
def step(
self, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any
) -> Optional[float]:
"""
:meth:`step` carries out the following two operations:
@ -353,12 +388,9 @@ class GradScaler:
"step() has already been called since the last update()."
)
retval = None
retval: Optional[float] = None
if (
hasattr(optimizer, "_step_supports_amp_scaling")
and optimizer._step_supports_amp_scaling
):
if getattr(optimizer, "_step_supports_amp_scaling", False):
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
# The contract with custom optimizers is that their step() should accept an additional,
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
@ -386,6 +418,7 @@ class GradScaler:
if optimizer_state["stage"] is OptState.READY:
self._check_inf_per_device(optimizer)
scaler = self._get_scale_async()
assert scaler is not None
found_inf = cast(
torch.Tensor,
sum(
@ -395,15 +428,15 @@ class GradScaler:
]
),
)
optimizer.grad_scale = (
optimizer.grad_scale = ( # type: ignore[attr-defined]
None if optimizer_state["stage"] == OptState.UNSCALED else scaler
)
optimizer.found_inf = found_inf
optimizer.found_inf = found_inf # type: ignore[attr-defined]
retval = optimizer.step(*args, **kwargs_)
optimizer_state["stage"] = OptState.STEPPED
if not has_grad_scaler_kwarg:
del optimizer.grad_scale
del optimizer.found_inf
del optimizer.grad_scale # type: ignore[attr-defined]
del optimizer.found_inf # type: ignore[attr-defined]
return retval
if optimizer_state["stage"] is OptState.READY:
@ -419,7 +452,7 @@ class GradScaler:
return retval
def update(self, new_scale=None):
def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
"""
Updates the scale factor.
@ -451,15 +484,16 @@ class GradScaler:
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
assert self._scale is not None
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
self._scale.fill_(new_scale)
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
self._scale.copy_(new_scale)
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
@ -488,10 +522,10 @@ class GradScaler:
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _get_scale_async(self):
def _get_scale_async(self) -> Optional[torch.Tensor]:
return self._scale
def get_scale(self):
def get_scale(self) -> float:
"""
Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
@ -501,68 +535,66 @@ class GradScaler:
if self._enabled:
return (
self._init_scale
if self._scale is None
else self._get_scale_async().item()
if (scale := self._get_scale_async()) is None
else cast(float, scale.item())
)
else:
return 1.0
return 1.0
def get_growth_factor(self):
def get_growth_factor(self) -> float:
r"""
Returns a Python float containing the scale growth factor.
"""
return self._growth_factor
def set_growth_factor(self, new_factor):
def set_growth_factor(self, new_factor: float) -> None:
r"""
Args:
new_scale (float): Value to use as the new scale growth factor.
"""
self._growth_factor = new_factor
def get_backoff_factor(self):
def get_backoff_factor(self) -> float:
r"""
Returns a Python float containing the scale backoff factor.
"""
return self._backoff_factor
def set_backoff_factor(self, new_factor):
def set_backoff_factor(self, new_factor: float) -> None:
r"""
Args:
new_scale (float): Value to use as the new scale backoff factor.
"""
self._backoff_factor = new_factor
def get_growth_interval(self):
def get_growth_interval(self) -> int:
r"""
Returns a Python int containing the growth interval.
"""
return self._growth_interval
def set_growth_interval(self, new_interval):
def set_growth_interval(self, new_interval: int) -> None:
r"""
Args:
new_interval (int): Value to use as the new growth interval.
"""
self._growth_interval = new_interval
def _get_growth_tracker(self):
def _get_growth_tracker(self) -> int:
if self._enabled:
return (
self._init_growth_tracker
if self._growth_tracker is None
else self._growth_tracker.item()
else cast(int, self._growth_tracker.item())
)
else:
return 0
return 0
def is_enabled(self):
def is_enabled(self) -> bool:
r"""
Returns a bool indicating whether this instance is enabled.
"""
return self._enabled
def state_dict(self):
def state_dict(self) -> Dict[str, Any]:
r"""
Returns the state of the scaler as a :class:`dict`. It contains five entries:
@ -578,19 +610,17 @@ class GradScaler:
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`.
"""
return (
{
if self._enabled:
return {
"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker(),
}
if self._enabled
else {}
)
return {}
def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
r"""
Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
@ -606,17 +636,17 @@ class GradScaler:
"from a disabled instance of GradScaler."
)
self._init_scale = state_dict["scale"]
self._init_scale = cast(float, state_dict["scale"])
if self._scale is not None:
self._scale.fill_(state_dict["scale"])
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._growth_interval = state_dict["growth_interval"]
self._init_growth_tracker = state_dict["_growth_tracker"]
self._growth_factor = cast(float, state_dict["growth_factor"])
self._backoff_factor = cast(float, state_dict["backoff_factor"])
self._growth_interval = cast(int, state_dict["growth_interval"])
self._init_growth_tracker = cast(int, state_dict["_growth_tracker"])
if self._growth_tracker is not None:
self._growth_tracker.fill_(state_dict["_growth_tracker"])
def __getstate__(self):
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
if self._enabled:
assert len(self._per_optimizer_states) == 0, (
@ -632,10 +662,10 @@ class GradScaler:
state["_growth_tracker"] = None
return state
def __setstate__(self, state):
def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
def _check_inf_per_device(self, optimizer):
def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
@ -647,5 +677,5 @@ class GradScaler:
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
def _found_inf_per_device(self, optimizer):
def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

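An illustrative aside on the reworked `apply_scale` control flow above (not part of the diff; assumes a CUDA-enabled build): lists and tuples are rebuilt as the same container type, while any other iterable, such as a generator, comes back lazily as a `map` object that is scaled as it is consumed.

```python
import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler()
a = torch.ones((), device="cuda")
b = torch.ones((), device="cuda")

assert isinstance(scaler.scale([a, b]), list)    # List overload, list at runtime
assert isinstance(scaler.scale((a, b)), tuple)   # Tuple overload, tuple at runtime

lazy = scaler.scale(x for x in (a, b))           # falls through to the Iterable overload
scaled = list(lazy)                              # map object; scaling happens on consumption
```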
View File

@ -1,21 +1,20 @@
import logging
from collections import abc, defaultdict
from typing import Dict, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union
import torch
import torch.distributed as dist
from torch.cuda import FloatTensor # type: ignore[attr-defined]
from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
from torch.distributed.distributed_c10d import ProcessGroup
log = logging.getLogger(__name__)
def _refresh_per_optimizer_state():
def _refresh_per_optimizer_state() -> Dict[str, Any]:
return {"stage": OptState.READY, "found_inf_per_device": {}}
def _is_supported_device(tensor: torch.Tensor):
def _is_supported_device(tensor: torch.Tensor) -> bool:
return tensor.is_cuda or tensor.device.type in ("xla", "cpu")
@ -88,7 +87,7 @@ class ShardedGradScaler(GradScaler):
growth_interval: int = 2000,
enabled: bool = True,
process_group: Optional[ProcessGroup] = dist.group.WORLD,
):
) -> None:
super().__init__(
init_scale=init_scale,
backoff_factor=backoff_factor,
@ -100,9 +99,25 @@ class ShardedGradScaler(GradScaler):
self.process_group = process_group
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
@overload
def scale(self, outputs: torch.Tensor) -> torch.Tensor:
...
@overload
def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
...
@overload
def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
...
@overload
def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
...
def scale(
self, outputs: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
if not self._enabled:
return outputs
@ -121,9 +136,7 @@ class ShardedGradScaler(GradScaler):
stash: List[_GeneralMultiDeviceReplicator] = []
def apply_scale(
val: Union[torch.Tensor, abc.Iterable]
) -> Union[torch.Tensor, abc.Iterable]:
def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
if isinstance(val, torch.Tensor):
assert _is_supported_device(val)
if len(stash) == 0:
@ -136,19 +149,20 @@ class ShardedGradScaler(GradScaler):
# For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
# format (fp16, bf16) and so the scaled loss should be of the same dtype.
return scaled_val.type(val.dtype)
elif isinstance(val, abc.Iterable):
if isinstance(val, abc.Iterable):
iterator = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterator)
else:
return iterator
else:
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return iterator
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs) # type: ignore[return-value]
return apply_scale(outputs)
def _foreach_non_finite_check_and_unscale_cpu_(
self, grads: List, found_inf: torch.Tensor, inv_scale: torch.Tensor
self,
grads: Sequence[torch.Tensor],
found_inf: torch.Tensor,
inv_scale: torch.Tensor,
) -> None:
if len(grads) == 0:
return
@ -288,28 +302,25 @@ class ShardedGradScaler(GradScaler):
if future_handles:
torch.futures.wait_all(future_handles)
def step(
self, optimizer: torch.optim.Optimizer, *args, **kwargs
) -> Optional[float]:
return super().step(optimizer, *args, **kwargs)
def _amp_update_scale_cpu_(self, found_inf) -> None:
def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None:
"""
If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
"""
assert self._scale is not None and self._growth_tracker is not None
if found_inf.item() >= 1.0:
self._scale *= self._backoff_factor # type: ignore[arg-type]
self._growth_tracker = 0
self._scale *= self._backoff_factor
self._growth_tracker.fill_(0)
else:
successful = self._growth_tracker + 1 # type: ignore[operator]
if successful == self._growth_interval: # type: ignore[arg-type]
self._scale *= self._growth_factor # type: ignore[arg-type]
self._growth_tracker = 0
successful = self._growth_tracker + 1
if successful == self._growth_interval:
self._scale *= self._growth_factor
self._growth_tracker.fill_(0)
else:
self._growth_tracker = successful
def update(self, new_scale: Optional[Union[float, FloatTensor]] = None) -> None:
def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``