Improve torch.cuda.amp type hints (#108630)

Fixes #108629

1. Add the following to their modules' `__all__` so that pyright considers them to be publicly exported:
* [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast)
* [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler)
* [`torch.cuda.amp.autocast`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast)
* [`torch.cuda.amp.custom_fwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd)
* [`torch.cuda.amp.custom_bwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_bwd)
2. Add `overload`s for `torch.cuda.amp.GradScaler.scale` so that the return type (`torch.Tensor` vs. `Iterable[torch.Tensor]`) is inferred from the type of the `outputs` argument; a usage sketch follows below.
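To illustrate the intent, here is a minimal sketch (not part of the diff below) of an ordinary AMP training step; the model, optimizer, and tensor names are hypothetical. With the `__all__` entries, pyright accepts the `torch.cuda.amp` imports as public, and with the new overloads it infers `scaler.scale(loss)` as `torch.Tensor` rather than `Union[torch.Tensor, Iterable[torch.Tensor]]`:

```python
import torch
from torch.cuda.amp import GradScaler, autocast  # now listed in __all__, so pyright treats them as public

model = torch.nn.Linear(8, 1).cuda()                      # hypothetical model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # hypothetical optimizer
scaler = GradScaler()

inputs = torch.randn(4, 8, device="cuda")
targets = torch.randn(4, 1, device="cuda")

with autocast():
    loss = torch.nn.functional.mse_loss(model(inputs), targets)

# The Tensor overload lets pyright infer `scaled` as torch.Tensor,
# so calling .backward() type-checks without a cast or type: ignore.
scaled = scaler.scale(loss)
scaled.backward()

scaler.step(optimizer)
scaler.update()
```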

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108630
Approved by: https://github.com/ezyang
Author: Matthew Hoffman, 2023-09-08 06:06:21 +00:00; committed by PyTorch MergeBot
parent 6c7260407b
commit e40d6ae0a7
5 changed files with 157 additions and 109 deletions


@@ -332,7 +332,7 @@ def generate_tensor_like_torch_implementations():
     # the problem. A more proper fix is to make the "not tested" check
     # a test on its own, and to make sure the monkeypatch is only installed
     # for the span of the relevant test (and deleted afterwards)
-    testing_ignore = {"sample_functional"}
+    testing_ignore = {"sample_functional", "autocast"}
     for namespace, funcs in get_overridable_functions().items():
         for func in funcs:
             if func not in testing_overrides and func.__name__ not in testing_ignore:


@@ -55,7 +55,7 @@ __all__ = [
     'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
     'SymBool', 'sym_not',
     'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
-    'export',
+    'export', 'autocast',
 ]

 ################################################################################


@@ -1,2 +1,9 @@
-from .autocast_mode import autocast, custom_bwd, custom_fwd  # noqa: F401
-from .grad_scaler import GradScaler  # noqa: F401
+from .autocast_mode import autocast, custom_bwd, custom_fwd
+from .grad_scaler import GradScaler
+
+__all__ = [
+    "autocast",
+    "custom_bwd",
+    "custom_fwd",
+    "GradScaler",
+]


@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import inspect
 import warnings
 from collections import abc, defaultdict
 from enum import Enum
-from typing import Any, cast, Dict, List, Optional, Tuple
+from typing import Any, cast, Dict, Iterable, List, Optional, overload, Tuple, Union

 import torch

 from .common import amp_definitely_not_available
@@ -21,7 +23,7 @@ class _MultiDeviceReplicator:
         self.master = master_tensor
         self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

-    def get(self, device) -> torch.Tensor:
+    def get(self, device: torch.device) -> torch.Tensor:
         retval = self._per_device_tensors.get(device, None)
         if retval is None:
             retval = self.master.to(device=device, non_blocking=True, copy=True)
@@ -40,14 +42,11 @@ class OptState(Enum):
     STEPPED = 2


-def _refresh_per_optimizer_state():
+def _refresh_per_optimizer_state() -> Dict[str, Any]:
     return {"stage": OptState.READY, "found_inf_per_device": {}}


 class GradScaler:
-    _scale: Optional[torch.Tensor]
-    _grows_tracker: Optional[torch.Tensor]
-    _per_optimizer_states: Dict[int, Dict[str, Any]]
     """
     An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
     conveniently.
@@ -115,12 +114,12 @@ class GradScaler:
     def __init__(
         self,
-        init_scale=2.0**16,
-        growth_factor=2.0,
-        backoff_factor=0.5,
-        growth_interval=2000,
-        enabled=True,
-    ):
+        init_scale: float = 2.0**16,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+        enabled: bool = True,
+    ) -> None:
         if enabled and amp_definitely_not_available():
             warnings.warn(
                 "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling."
@@ -135,17 +134,19 @@ class GradScaler:
             self._init_scale = init_scale
             # self._scale will be lazily initialized during the first call to scale()
-            self._scale = None
+            self._scale: Optional[torch.Tensor] = None
             self._growth_factor = growth_factor
             self._backoff_factor = backoff_factor
             self._growth_interval = growth_interval
             self._init_growth_tracker = 0
             # self._growth_tracker will be lazily initialized during the first call to scale()
-            self._growth_tracker = None
-            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
+            self._growth_tracker: Optional[torch.Tensor] = None
+            self._per_optimizer_states: Dict[int, Dict[str, Any]] = defaultdict(
+                _refresh_per_optimizer_state
+            )

     def _check_scale_growth_tracker(
-        self, funcname
+        self, funcname: str
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
         assert self._scale is not None, (
@@ -156,14 +157,33 @@ class GradScaler:
         )
         return (self._scale, self._growth_tracker)

-    def _lazy_init_scale_growth_tracker(self, dev):
+    def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None:
         assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
         self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
         self._growth_tracker = torch.full(
             (), self._init_growth_tracker, dtype=torch.int32, device=dev
         )

-    def scale(self, outputs):
+    @overload
+    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
+        ...
+
+    @overload
+    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
+        ...
+
+    @overload
+    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
+        ...
+
+    @overload
+    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
+        ...
+
+    def scale(
+        self,
+        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
+    ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
         """
         Multiplies ('scales') a tensor or list of tensors by the scale factor.
@@ -189,7 +209,7 @@ class GradScaler:
             _MultiDeviceReplicator
         ] = []  # holds a reference that can be overwritten by apply_scale

-        def apply_scale(val):
+        def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
             if isinstance(val, torch.Tensor):
                 assert val.is_cuda or val.device.type == "xla"
                 if len(stash) == 0:
@@ -198,18 +218,22 @@ class GradScaler:
                     assert self._scale is not None
                     stash.append(_MultiDeviceReplicator(self._scale))
                 return val * stash[0].get(val.device)
-            elif isinstance(val, abc.Iterable):
+            if isinstance(val, abc.Iterable):
                 iterable = map(apply_scale, val)
                 if isinstance(val, (list, tuple)):
                     return type(val)(iterable)
-                else:
-                    return iterable
-            else:
-                raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+                return iterable
+            raise ValueError("outputs must be a Tensor or an iterable of Tensors")

         return apply_scale(outputs)

-    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
+    def _unscale_grads_(
+        self,
+        optimizer: torch.optim.Optimizer,
+        inv_scale: torch.Tensor,
+        found_inf: torch.Tensor,
+        allow_fp16: bool,
+    ) -> Dict[torch.device, torch.Tensor]:
         per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
         per_device_found_inf = _MultiDeviceReplicator(found_inf)
@@ -219,10 +243,13 @@ class GradScaler:
         # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
         # Google says mypy struggles with defaultdicts type annotations.
-        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
+        per_device_and_dtype_grads: Dict[
+            torch.device, Dict[torch.dtype, List[torch.Tensor]]
+        ] = defaultdict(lambda: defaultdict(list))
         with torch.no_grad():
             for group in optimizer.param_groups:
                 for param in group["params"]:
+                    assert isinstance(param, torch.Tensor)
                     if param.grad is None:
                         continue
                     if (not allow_fp16) and param.grad.dtype == torch.float16:
@@ -253,7 +280,7 @@ class GradScaler:
         return per_device_found_inf._per_device_tensors

-    def unscale_(self, optimizer):
+    def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
         """
         Divides ("unscales") the optimizer's gradient tensors by the scale factor.
@@ -309,13 +336,21 @@ class GradScaler:
         )
         optimizer_state["stage"] = OptState.UNSCALED

-    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
-        retval = None
+    def _maybe_opt_step(
+        self,
+        optimizer: torch.optim.Optimizer,
+        optimizer_state: Dict[str, Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Optional[float]:
+        retval: Optional[float] = None
         if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
             retval = optimizer.step(*args, **kwargs)
         return retval

-    def step(self, optimizer, *args, **kwargs):
+    def step(
+        self, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any
+    ) -> Optional[float]:
         """
         :meth:`step` carries out the following two operations:
@@ -353,12 +388,9 @@ class GradScaler:
                 "step() has already been called since the last update()."
             )

-        retval = None
+        retval: Optional[float] = None

-        if (
-            hasattr(optimizer, "_step_supports_amp_scaling")
-            and optimizer._step_supports_amp_scaling
-        ):
+        if getattr(optimizer, "_step_supports_amp_scaling", False):
             # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
             # The contract with custom optimizers is that their step() should accept an additional,
             # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
@@ -386,6 +418,7 @@ class GradScaler:
                 if optimizer_state["stage"] is OptState.READY:
                     self._check_inf_per_device(optimizer)
                 scaler = self._get_scale_async()
+                assert scaler is not None
                 found_inf = cast(
                     torch.Tensor,
                     sum(
@@ -395,15 +428,15 @@ class GradScaler:
                         ]
                     ),
                 )
-                optimizer.grad_scale = (
+                optimizer.grad_scale = (  # type: ignore[attr-defined]
                     None if optimizer_state["stage"] == OptState.UNSCALED else scaler
                 )
-                optimizer.found_inf = found_inf
+                optimizer.found_inf = found_inf  # type: ignore[attr-defined]
             retval = optimizer.step(*args, **kwargs_)
             optimizer_state["stage"] = OptState.STEPPED
             if not has_grad_scaler_kwarg:
-                del optimizer.grad_scale
-                del optimizer.found_inf
+                del optimizer.grad_scale  # type: ignore[attr-defined]
+                del optimizer.found_inf  # type: ignore[attr-defined]
             return retval

         if optimizer_state["stage"] is OptState.READY:
@@ -419,7 +452,7 @@ class GradScaler:
         return retval

-    def update(self, new_scale=None):
+    def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
         """
         Updates the scale factor.
@@ -451,15 +484,16 @@ class GradScaler:
         _scale, _growth_tracker = self._check_scale_growth_tracker("update")

         if new_scale is not None:
+            assert self._scale is not None
             # Accept a new user-defined scale.
             if isinstance(new_scale, float):
-                self._scale.fill_(new_scale)  # type: ignore[union-attr]
+                self._scale.fill_(new_scale)
             else:
                 reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                 assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                 assert new_scale.numel() == 1, reason
                 assert new_scale.requires_grad is False, reason
-                self._scale.copy_(new_scale)  # type: ignore[union-attr]
+                self._scale.copy_(new_scale)
         else:
             # Consume shared inf/nan data collected from optimizers to update the scale.
             # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
@@ -488,10 +522,10 @@ class GradScaler:
         # To prepare for next iteration, clear the data collected from optimizers this iteration.
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

-    def _get_scale_async(self):
+    def _get_scale_async(self) -> Optional[torch.Tensor]:
         return self._scale

-    def get_scale(self):
+    def get_scale(self) -> float:
         """
         Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
@@ -501,68 +535,66 @@ class GradScaler:
         if self._enabled:
             return (
                 self._init_scale
-                if self._scale is None
-                else self._get_scale_async().item()
+                if (scale := self._get_scale_async()) is None
+                else cast(float, scale.item())
             )
-        else:
-            return 1.0
+        return 1.0

-    def get_growth_factor(self):
+    def get_growth_factor(self) -> float:
         r"""
         Returns a Python float containing the scale growth factor.
         """
         return self._growth_factor

-    def set_growth_factor(self, new_factor):
+    def set_growth_factor(self, new_factor: float) -> None:
         r"""
         Args:
             new_scale (float): Value to use as the new scale growth factor.
         """
         self._growth_factor = new_factor

-    def get_backoff_factor(self):
+    def get_backoff_factor(self) -> float:
         r"""
         Returns a Python float containing the scale backoff factor.
         """
         return self._backoff_factor

-    def set_backoff_factor(self, new_factor):
+    def set_backoff_factor(self, new_factor: float) -> None:
        r"""
         Args:
             new_scale (float): Value to use as the new scale backoff factor.
         """
         self._backoff_factor = new_factor

-    def get_growth_interval(self):
+    def get_growth_interval(self) -> int:
         r"""
         Returns a Python int containing the growth interval.
         """
         return self._growth_interval

-    def set_growth_interval(self, new_interval):
+    def set_growth_interval(self, new_interval: int) -> None:
         r"""
         Args:
             new_interval (int): Value to use as the new growth interval.
         """
         self._growth_interval = new_interval

-    def _get_growth_tracker(self):
+    def _get_growth_tracker(self) -> int:
         if self._enabled:
             return (
                 self._init_growth_tracker
                 if self._growth_tracker is None
-                else self._growth_tracker.item()
+                else cast(int, self._growth_tracker.item())
             )
-        else:
-            return 0
+        return 0

-    def is_enabled(self):
+    def is_enabled(self) -> bool:
         r"""
         Returns a bool indicating whether this instance is enabled.
         """
         return self._enabled

-    def state_dict(self):
+    def state_dict(self) -> Dict[str, Any]:
         r"""
         Returns the state of the scaler as a :class:`dict`. It contains five entries:
@@ -578,19 +610,17 @@ class GradScaler:
         If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
         should be called after :meth:`update`.
         """
-        return (
-            {
+        if self._enabled:
+            return {
                 "scale": self.get_scale(),
                 "growth_factor": self._growth_factor,
                 "backoff_factor": self._backoff_factor,
                 "growth_interval": self._growth_interval,
                 "_growth_tracker": self._get_growth_tracker(),
             }
-            if self._enabled
-            else {}
-        )
+        return {}

-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         r"""
         Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
@@ -606,17 +636,17 @@ class GradScaler:
                 "from a disabled instance of GradScaler."
             )

-        self._init_scale = state_dict["scale"]
+        self._init_scale = cast(float, state_dict["scale"])
         if self._scale is not None:
             self._scale.fill_(state_dict["scale"])
-        self._growth_factor = state_dict["growth_factor"]
-        self._backoff_factor = state_dict["backoff_factor"]
-        self._growth_interval = state_dict["growth_interval"]
-        self._init_growth_tracker = state_dict["_growth_tracker"]
+        self._growth_factor = cast(float, state_dict["growth_factor"])
+        self._backoff_factor = cast(float, state_dict["backoff_factor"])
+        self._growth_interval = cast(int, state_dict["growth_interval"])
+        self._init_growth_tracker = cast(int, state_dict["_growth_tracker"])
         if self._growth_tracker is not None:
             self._growth_tracker.fill_(state_dict["_growth_tracker"])

-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
         if self._enabled:
             assert len(self._per_optimizer_states) == 0, (
@@ -632,10 +662,10 @@ class GradScaler:
             state["_growth_tracker"] = None
         return state

-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         self.__dict__.update(state)

-    def _check_inf_per_device(self, optimizer):
+    def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
         _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
         dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
@@ -647,5 +677,5 @@ class GradScaler:
         return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

-    def _found_inf_per_device(self, optimizer):
+    def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
         return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]


@@ -1,21 +1,20 @@
 import logging
 from collections import abc, defaultdict
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union

 import torch
 import torch.distributed as dist
-from torch.cuda import FloatTensor  # type: ignore[attr-defined]
 from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
 from torch.distributed.distributed_c10d import ProcessGroup

 log = logging.getLogger(__name__)


-def _refresh_per_optimizer_state():
+def _refresh_per_optimizer_state() -> Dict[str, Any]:
     return {"stage": OptState.READY, "found_inf_per_device": {}}


-def _is_supported_device(tensor: torch.Tensor):
+def _is_supported_device(tensor: torch.Tensor) -> bool:
     return tensor.is_cuda or tensor.device.type in ("xla", "cpu")
@@ -88,7 +87,7 @@ class ShardedGradScaler(GradScaler):
         growth_interval: int = 2000,
         enabled: bool = True,
         process_group: Optional[ProcessGroup] = dist.group.WORLD,
-    ):
+    ) -> None:
         super().__init__(
             init_scale=init_scale,
             backoff_factor=backoff_factor,
@@ -100,9 +99,25 @@ class ShardedGradScaler(GradScaler):
         self.process_group = process_group
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

+    @overload
+    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
+        ...
+
+    @overload
+    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
+        ...
+
+    @overload
+    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
+        ...
+
+    @overload
+    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
+        ...
+
     def scale(
-        self, outputs: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
+    ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
         if not self._enabled:
             return outputs
@@ -121,9 +136,7 @@ class ShardedGradScaler(GradScaler):
         stash: List[_GeneralMultiDeviceReplicator] = []

-        def apply_scale(
-            val: Union[torch.Tensor, abc.Iterable]
-        ) -> Union[torch.Tensor, abc.Iterable]:
+        def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
             if isinstance(val, torch.Tensor):
                 assert _is_supported_device(val)
                 if len(stash) == 0:
@@ -136,19 +149,20 @@ class ShardedGradScaler(GradScaler):
                 # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
                 # format (fp16, bf16) and so the scaled loss should be of the same dtype.
                 return scaled_val.type(val.dtype)
-            elif isinstance(val, abc.Iterable):
+            if isinstance(val, abc.Iterable):
                 iterator = map(apply_scale, val)
                 if isinstance(val, (list, tuple)):
                     return type(val)(iterator)
-                else:
-                    return iterator
-            else:
-                raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+                return iterator
+            raise ValueError("outputs must be a Tensor or an iterable of Tensors")

-        return apply_scale(outputs)  # type: ignore[return-value]
+        return apply_scale(outputs)

     def _foreach_non_finite_check_and_unscale_cpu_(
-        self, grads: List, found_inf: torch.Tensor, inv_scale: torch.Tensor
+        self,
+        grads: Sequence[torch.Tensor],
+        found_inf: torch.Tensor,
+        inv_scale: torch.Tensor,
     ) -> None:
         if len(grads) == 0:
             return
@@ -288,28 +302,25 @@ class ShardedGradScaler(GradScaler):
         if future_handles:
             torch.futures.wait_all(future_handles)

-    def step(
-        self, optimizer: torch.optim.Optimizer, *args, **kwargs
-    ) -> Optional[float]:
-        return super().step(optimizer, *args, **kwargs)
-
-    def _amp_update_scale_cpu_(self, found_inf) -> None:
+    def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None:
         """
         If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
         Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
         """
+        assert self._scale is not None and self._growth_tracker is not None
+
         if found_inf.item() >= 1.0:
-            self._scale *= self._backoff_factor  # type: ignore[arg-type]
-            self._growth_tracker = 0
+            self._scale *= self._backoff_factor
+            self._growth_tracker.fill_(0)
         else:
-            successful = self._growth_tracker + 1  # type: ignore[operator]
-            if successful == self._growth_interval:  # type: ignore[arg-type]
-                self._scale *= self._growth_factor  # type: ignore[arg-type]
-                self._growth_tracker = 0
+            successful = self._growth_tracker + 1
+            if successful == self._growth_interval:
+                self._scale *= self._growth_factor
+                self._growth_tracker.fill_(0)
             else:
                 self._growth_tracker = successful

-    def update(self, new_scale: Optional[Union[float, FloatTensor]] = None) -> None:
+    def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
         """
         Updates the scale factor.
         If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``