Improve torch.cuda.amp type hints (#108630)

Fixes #108629

1. Add the following to their modules' `__all__` so that pyright considers them to be publicly exported:
   * [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast)
   * [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler)
   * [`torch.cuda.amp.autocast`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast)
   * [`torch.cuda.amp.custom_fwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd)
   * [`torch.cuda.amp.custom_bwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_bwd)
2. Add `overload`s for `torch.cuda.amp.GradScaler.scale` to differentiate between the case where a `torch.Tensor` is returned and the case where an `Iterable[torch.Tensor]` is returned, based on the type of the `outputs` parameter.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108630
Approved by: https://github.com/ezyang
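For a sense of what this buys a static type checker, here is a minimal sketch (the variable names and tensor values are illustrative only, and the scaler is disabled automatically when CUDA is absent; the same public-export fix applies to `torch.cuda.amp.autocast`, `custom_fwd`, and `custom_bwd`):

```python
import torch
from torch.cuda.amp import GradScaler  # now recognized as a public re-export

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
scaler = GradScaler(enabled=use_cuda)  # a disabled scaler passes values through unchanged

loss = torch.tensor(1.0, device=device)
losses = [torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)]

# With the new overloads, pyright/mypy infer:
#   scaler.scale(loss)   -> torch.Tensor
#   scaler.scale(losses) -> List[torch.Tensor]
scaled_loss = scaler.scale(loss)
scaled_losses = scaler.scale(losses)
```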
This commit is contained in:
Parent: 6c7260407b
Commit: e40d6ae0a7
@@ -332,7 +332,7 @@ def generate_tensor_like_torch_implementations():
     # the problem. A more proper fix is to make the "not tested" check
     # a test on its own, and to make sure the monkeypatch is only installed
     # for the span of the relevant test (and deleted afterwards)
-    testing_ignore = {"sample_functional"}
+    testing_ignore = {"sample_functional", "autocast"}
     for namespace, funcs in get_overridable_functions().items():
         for func in funcs:
             if func not in testing_overrides and func.__name__ not in testing_ignore:

@@ -55,7 +55,7 @@ __all__ = [
     'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
     'SymBool', 'sym_not',
     'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
-    'export',
+    'export', 'autocast',
 ]

 ################################################################################

@@ -1,2 +1,9 @@
-from .autocast_mode import autocast, custom_bwd, custom_fwd  # noqa: F401
-from .grad_scaler import GradScaler  # noqa: F401
+from .autocast_mode import autocast, custom_bwd, custom_fwd
+from .grad_scaler import GradScaler
+
+__all__ = [
+    "autocast",
+    "custom_bwd",
+    "custom_fwd",
+    "GradScaler",
+]
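As a quick sanity check on a build that includes this change, the re-exported surface can be inspected at runtime (a small sketch, assuming a recent enough PyTorch install):

```python
import torch.cuda.amp as amp

# These names are now declared public via __all__, so both runtime access and
# type checkers such as pyright treat them as part of the package API.
assert {"autocast", "custom_fwd", "custom_bwd", "GradScaler"} <= set(amp.__all__)
print(sorted(amp.__all__))
```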
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import inspect
 import warnings
 from collections import abc, defaultdict
 from enum import Enum
-from typing import Any, cast, Dict, List, Optional, Tuple
+from typing import Any, cast, Dict, Iterable, List, Optional, overload, Tuple, Union

 import torch
 from .common import amp_definitely_not_available
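The new `from __future__ import annotations` line enables postponed evaluation of annotations (PEP 563): every annotation in the module is stored as a string and only resolved on demand, which makes forward references painless. A minimal, self-contained illustration (not PyTorch code):

```python
from __future__ import annotations


class Node:
    # Without the future import, the return annotation would have to be written
    # as the string "Node", because the class is not yet defined at this point.
    def clone(self) -> Node:
        return Node()


print(Node().clone())
```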
@@ -21,7 +23,7 @@ class _MultiDeviceReplicator:
         self.master = master_tensor
         self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

-    def get(self, device) -> torch.Tensor:
+    def get(self, device: torch.device) -> torch.Tensor:
         retval = self._per_device_tensors.get(device, None)
         if retval is None:
             retval = self.master.to(device=device, non_blocking=True, copy=True)

@@ -40,14 +42,11 @@ class OptState(Enum):
     STEPPED = 2


-def _refresh_per_optimizer_state():
+def _refresh_per_optimizer_state() -> Dict[str, Any]:
     return {"stage": OptState.READY, "found_inf_per_device": {}}


 class GradScaler:
-    _scale: Optional[torch.Tensor]
-    _grows_tracker: Optional[torch.Tensor]
-    _per_optimizer_states: Dict[int, Dict[str, Any]]
     """
     An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
     conveniently.

@@ -115,12 +114,12 @@ class GradScaler:

     def __init__(
         self,
-        init_scale=2.0**16,
-        growth_factor=2.0,
-        backoff_factor=0.5,
-        growth_interval=2000,
-        enabled=True,
-    ):
+        init_scale: float = 2.0**16,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+        enabled: bool = True,
+    ) -> None:
         if enabled and amp_definitely_not_available():
             warnings.warn(
                 "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling."

@@ -135,17 +134,19 @@ class GradScaler:

             self._init_scale = init_scale
             # self._scale will be lazily initialized during the first call to scale()
-            self._scale = None
+            self._scale: Optional[torch.Tensor] = None
             self._growth_factor = growth_factor
             self._backoff_factor = backoff_factor
             self._growth_interval = growth_interval
             self._init_growth_tracker = 0
             # self._growth_tracker will be lazily initialized during the first call to scale()
-            self._growth_tracker = None
-            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
+            self._growth_tracker: Optional[torch.Tensor] = None
+            self._per_optimizer_states: Dict[int, Dict[str, Any]] = defaultdict(
+                _refresh_per_optimizer_state
+            )

     def _check_scale_growth_tracker(
-        self, funcname
+        self, funcname: str
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
         assert self._scale is not None, (
@@ -156,14 +157,33 @@ class GradScaler:
         )
         return (self._scale, self._growth_tracker)

-    def _lazy_init_scale_growth_tracker(self, dev):
+    def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None:
         assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
         self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
         self._growth_tracker = torch.full(
             (), self._init_growth_tracker, dtype=torch.int32, device=dev
         )

-    def scale(self, outputs):
+    @overload
+    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
+        ...
+
+    @overload
+    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
+        ...
+
+    @overload
+    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
+        ...
+
+    @overload
+    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
+        ...
+
+    def scale(
+        self,
+        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
+    ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
         """
         Multiplies ('scales') a tensor or list of tensors by the scale factor.
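For readers less familiar with `typing.overload`: the decorated signatures above exist only for the type checker; the final undecorated `def scale(...)` is the single implementation that runs and must handle every declared form. A small standalone sketch of the same pattern, unrelated to GradScaler:

```python
from typing import Iterable, List, Union, overload


@overload
def double(x: int) -> int:
    ...


@overload
def double(x: List[int]) -> List[int]:
    ...


def double(x: Union[int, Iterable[int]]) -> Union[int, List[int]]:
    # One runtime implementation backs both overload signatures above.
    if isinstance(x, int):
        return x * 2
    return [v * 2 for v in x]


print(double(3))       # a type checker infers int here
print(double([1, 2]))  # ...and List[int] here
```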
@@ -189,7 +209,7 @@ class GradScaler:
             _MultiDeviceReplicator
         ] = []  # holds a reference that can be overwritten by apply_scale

-        def apply_scale(val):
+        def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
             if isinstance(val, torch.Tensor):
                 assert val.is_cuda or val.device.type == "xla"
                 if len(stash) == 0:

@@ -198,18 +218,22 @@ class GradScaler:
                     assert self._scale is not None
                     stash.append(_MultiDeviceReplicator(self._scale))
                 return val * stash[0].get(val.device)
-            elif isinstance(val, abc.Iterable):
+            if isinstance(val, abc.Iterable):
                 iterable = map(apply_scale, val)
                 if isinstance(val, (list, tuple)):
                     return type(val)(iterable)
-                else:
-                    return iterable
-            else:
-                raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+                return iterable
+            raise ValueError("outputs must be a Tensor or an iterable of Tensors")

         return apply_scale(outputs)

-    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
+    def _unscale_grads_(
+        self,
+        optimizer: torch.optim.Optimizer,
+        inv_scale: torch.Tensor,
+        found_inf: torch.Tensor,
+        allow_fp16: bool,
+    ) -> Dict[torch.device, torch.Tensor]:
         per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
         per_device_found_inf = _MultiDeviceReplicator(found_inf)
@@ -219,10 +243,13 @@ class GradScaler:

         # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
         # Google says mypy struggles with defaultdicts type annotations.
-        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
+        per_device_and_dtype_grads: Dict[
+            torch.device, Dict[torch.dtype, List[torch.Tensor]]
+        ] = defaultdict(lambda: defaultdict(list))
         with torch.no_grad():
             for group in optimizer.param_groups:
                 for param in group["params"]:
+                    assert isinstance(param, torch.Tensor)
                     if param.grad is None:
                         continue
                     if (not allow_fp16) and param.grad.dtype == torch.float16:
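The dropped `# type: ignore[var-annotated]` worked around mypy's difficulty inferring nested `defaultdict` types; spelling the full mapping type out on the variable, as the new code does, is the usual fix. A standalone sketch of the same technique with made-up key types:

```python
from collections import defaultdict
from typing import Dict, List

# Annotate the variable with the fully nested mapping type; the defaultdict
# factories then only have to produce values compatible with it.
grads_by_device_and_dtype: Dict[str, Dict[str, List[float]]] = defaultdict(
    lambda: defaultdict(list)
)

grads_by_device_and_dtype["cuda:0"]["float16"].append(0.5)
print(grads_by_device_and_dtype["cuda:0"]["float16"])
```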
@@ -253,7 +280,7 @@ class GradScaler:

         return per_device_found_inf._per_device_tensors

-    def unscale_(self, optimizer):
+    def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
         """
         Divides ("unscales") the optimizer's gradient tensors by the scale factor.

@@ -309,13 +336,21 @@ class GradScaler:
         )
         optimizer_state["stage"] = OptState.UNSCALED

-    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
-        retval = None
+    def _maybe_opt_step(
+        self,
+        optimizer: torch.optim.Optimizer,
+        optimizer_state: Dict[str, Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Optional[float]:
+        retval: Optional[float] = None
         if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
             retval = optimizer.step(*args, **kwargs)
         return retval

-    def step(self, optimizer, *args, **kwargs):
+    def step(
+        self, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any
+    ) -> Optional[float]:
         """
         :meth:`step` carries out the following two operations:

@@ -353,12 +388,9 @@ class GradScaler:
                 "step() has already been called since the last update()."
             )

-        retval = None
+        retval: Optional[float] = None

-        if (
-            hasattr(optimizer, "_step_supports_amp_scaling")
-            and optimizer._step_supports_amp_scaling
-        ):
+        if getattr(optimizer, "_step_supports_amp_scaling", False):
             # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
             # The contract with custom optimizers is that their step() should accept an additional,
             # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
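The `hasattr`-then-access pair collapses into a single `getattr` with a default, which reads the attribute once and gives the type checker a plain boolean expression. A tiny illustration with a hypothetical stand-in object:

```python
class _FakeOptimizer:
    """Hypothetical stand-in for an optimizer that opts in to AMP scale handling."""

    _step_supports_amp_scaling = True


opt = _FakeOptimizer()

# Old style: two lookups plus an implicit truthiness check.
old_style = hasattr(opt, "_step_supports_amp_scaling") and opt._step_supports_amp_scaling
# New style: one lookup with an explicit default.
new_style = getattr(opt, "_step_supports_amp_scaling", False)

assert old_style is True and new_style is True
```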
@@ -386,6 +418,7 @@ class GradScaler:
             if optimizer_state["stage"] is OptState.READY:
                 self._check_inf_per_device(optimizer)
             scaler = self._get_scale_async()
+            assert scaler is not None
             found_inf = cast(
                 torch.Tensor,
                 sum(

@@ -395,15 +428,15 @@ class GradScaler:
                     ]
                 ),
             )
-            optimizer.grad_scale = (
+            optimizer.grad_scale = (  # type: ignore[attr-defined]
                 None if optimizer_state["stage"] == OptState.UNSCALED else scaler
             )
-            optimizer.found_inf = found_inf
+            optimizer.found_inf = found_inf  # type: ignore[attr-defined]
             retval = optimizer.step(*args, **kwargs_)
             optimizer_state["stage"] = OptState.STEPPED
             if not has_grad_scaler_kwarg:
-                del optimizer.grad_scale
-                del optimizer.found_inf
+                del optimizer.grad_scale  # type: ignore[attr-defined]
+                del optimizer.found_inf  # type: ignore[attr-defined]
             return retval

         if optimizer_state["stage"] is OptState.READY:

@@ -419,7 +452,7 @@ class GradScaler:

         return retval

-    def update(self, new_scale=None):
+    def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
         """
         Updates the scale factor.

@@ -451,15 +484,16 @@ class GradScaler:
         _scale, _growth_tracker = self._check_scale_growth_tracker("update")

         if new_scale is not None:
+            assert self._scale is not None
             # Accept a new user-defined scale.
             if isinstance(new_scale, float):
-                self._scale.fill_(new_scale)  # type: ignore[union-attr]
+                self._scale.fill_(new_scale)
             else:
                 reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                 assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                 assert new_scale.numel() == 1, reason
                 assert new_scale.requires_grad is False, reason
-                self._scale.copy_(new_scale)  # type: ignore[union-attr]
+                self._scale.copy_(new_scale)
         else:
             # Consume shared inf/nan data collected from optimizers to update the scale.
             # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
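The signatures annotated above line up with the standard AMP training recipe; a compact sketch of that loop follows (the tiny model, data, and hyperparameters are placeholders, and both the scaler and autocast are simply disabled when CUDA is absent):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler(enabled=device.type == "cuda")

for _ in range(3):
    inputs = torch.randn(4, 8, device=device)
    targets = torch.randn(4, 1, device=device)

    optimizer.zero_grad()
    with autocast(enabled=device.type == "cuda"):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)

    scaler.scale(loss).backward()  # scale() returns a torch.Tensor here
    scaler.step(optimizer)         # typed as Optional[float]
    scaler.update()
```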
@@ -488,10 +522,10 @@ class GradScaler:
         # To prepare for next iteration, clear the data collected from optimizers this iteration.
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

-    def _get_scale_async(self):
+    def _get_scale_async(self) -> Optional[torch.Tensor]:
         return self._scale

-    def get_scale(self):
+    def get_scale(self) -> float:
         """
         Returns a Python float containing the current scale, or 1.0 if scaling is disabled.

@@ -501,68 +535,66 @@ class GradScaler:
         if self._enabled:
             return (
                 self._init_scale
-                if self._scale is None
-                else self._get_scale_async().item()
+                if (scale := self._get_scale_async()) is None
+                else cast(float, scale.item())
             )
-        else:
-            return 1.0
+        return 1.0

-    def get_growth_factor(self):
+    def get_growth_factor(self) -> float:
         r"""
         Returns a Python float containing the scale growth factor.
         """
         return self._growth_factor

-    def set_growth_factor(self, new_factor):
+    def set_growth_factor(self, new_factor: float) -> None:
         r"""
         Args:
             new_scale (float): Value to use as the new scale growth factor.
         """
         self._growth_factor = new_factor

-    def get_backoff_factor(self):
+    def get_backoff_factor(self) -> float:
         r"""
         Returns a Python float containing the scale backoff factor.
         """
         return self._backoff_factor

-    def set_backoff_factor(self, new_factor):
+    def set_backoff_factor(self, new_factor: float) -> None:
         r"""
         Args:
             new_scale (float): Value to use as the new scale backoff factor.
         """
         self._backoff_factor = new_factor

-    def get_growth_interval(self):
+    def get_growth_interval(self) -> int:
         r"""
         Returns a Python int containing the growth interval.
         """
         return self._growth_interval

-    def set_growth_interval(self, new_interval):
+    def set_growth_interval(self, new_interval: int) -> None:
         r"""
         Args:
             new_interval (int): Value to use as the new growth interval.
         """
         self._growth_interval = new_interval

-    def _get_growth_tracker(self):
+    def _get_growth_tracker(self) -> int:
         if self._enabled:
             return (
                 self._init_growth_tracker
                 if self._growth_tracker is None
-                else self._growth_tracker.item()
+                else cast(int, self._growth_tracker.item())
             )
-        else:
-            return 0
+        return 0

-    def is_enabled(self):
+    def is_enabled(self) -> bool:
         r"""
         Returns a bool indicating whether this instance is enabled.
         """
         return self._enabled

-    def state_dict(self):
+    def state_dict(self) -> Dict[str, Any]:
         r"""
         Returns the state of the scaler as a :class:`dict`. It contains five entries:

@@ -578,19 +610,17 @@ class GradScaler:
         If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
         should be called after :meth:`update`.
         """
-        return (
-            {
-                "scale": self.get_scale(),
-                "growth_factor": self._growth_factor,
-                "backoff_factor": self._backoff_factor,
-                "growth_interval": self._growth_interval,
-                "_growth_tracker": self._get_growth_tracker(),
-            }
-            if self._enabled
-            else {}
-        )
+        if self._enabled:
+            return {
+                "scale": self.get_scale(),
+                "growth_factor": self._growth_factor,
+                "backoff_factor": self._backoff_factor,
+                "growth_interval": self._growth_interval,
+                "_growth_tracker": self._get_growth_tracker(),
+            }
+        return {}

-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         r"""
         Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
@@ -606,17 +636,17 @@ class GradScaler:
                 "from a disabled instance of GradScaler."
             )

-        self._init_scale = state_dict["scale"]
+        self._init_scale = cast(float, state_dict["scale"])
         if self._scale is not None:
             self._scale.fill_(state_dict["scale"])
-        self._growth_factor = state_dict["growth_factor"]
-        self._backoff_factor = state_dict["backoff_factor"]
-        self._growth_interval = state_dict["growth_interval"]
-        self._init_growth_tracker = state_dict["_growth_tracker"]
+        self._growth_factor = cast(float, state_dict["growth_factor"])
+        self._backoff_factor = cast(float, state_dict["backoff_factor"])
+        self._growth_interval = cast(int, state_dict["growth_interval"])
+        self._init_growth_tracker = cast(int, state_dict["_growth_tracker"])
         if self._growth_tracker is not None:
             self._growth_tracker.fill_(state_dict["_growth_tracker"])

-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
         if self._enabled:
             assert len(self._per_optimizer_states) == 0, (
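With the annotations above, `state_dict()` produces a plain `Dict[str, Any]` that round-trips through the usual checkpoint flow; a short sketch (the temporary file is just for illustration):

```python
import tempfile

import torch
from torch.cuda.amp import GradScaler

enabled = torch.cuda.is_available()
scaler = GradScaler(enabled=enabled)

with tempfile.NamedTemporaryFile(suffix=".pt") as f:
    torch.save({"scaler": scaler.state_dict()}, f.name)
    checkpoint = torch.load(f.name)

# load_state_dict is a no-op on a disabled scaler, so this is safe either way.
restored = GradScaler(enabled=enabled)
restored.load_state_dict(checkpoint["scaler"])
print(restored.get_scale())
```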
@@ -632,10 +662,10 @@ class GradScaler:
             state["_growth_tracker"] = None
         return state

-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         self.__dict__.update(state)

-    def _check_inf_per_device(self, optimizer):
+    def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
         _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

         dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)

@@ -647,5 +677,5 @@ class GradScaler:

         return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

-    def _found_inf_per_device(self, optimizer):
+    def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
         return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
@@ -1,21 +1,20 @@
 import logging
 from collections import abc, defaultdict
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union

 import torch
 import torch.distributed as dist
-from torch.cuda import FloatTensor  # type: ignore[attr-defined]
 from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
 from torch.distributed.distributed_c10d import ProcessGroup

 log = logging.getLogger(__name__)


-def _refresh_per_optimizer_state():
+def _refresh_per_optimizer_state() -> Dict[str, Any]:
     return {"stage": OptState.READY, "found_inf_per_device": {}}


-def _is_supported_device(tensor: torch.Tensor):
+def _is_supported_device(tensor: torch.Tensor) -> bool:
     return tensor.is_cuda or tensor.device.type in ("xla", "cpu")


@@ -88,7 +87,7 @@ class ShardedGradScaler(GradScaler):
         growth_interval: int = 2000,
         enabled: bool = True,
         process_group: Optional[ProcessGroup] = dist.group.WORLD,
-    ):
+    ) -> None:
         super().__init__(
             init_scale=init_scale,
             backoff_factor=backoff_factor,

@@ -100,9 +99,25 @@ class ShardedGradScaler(GradScaler):
         self.process_group = process_group
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

+    @overload
+    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
+        ...
+
+    @overload
+    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
+        ...
+
+    @overload
+    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
+        ...
+
+    @overload
+    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
+        ...
+
     def scale(
-        self, outputs: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
+    ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
         if not self._enabled:
             return outputs

@@ -121,9 +136,7 @@ class ShardedGradScaler(GradScaler):

         stash: List[_GeneralMultiDeviceReplicator] = []

-        def apply_scale(
-            val: Union[torch.Tensor, abc.Iterable]
-        ) -> Union[torch.Tensor, abc.Iterable]:
+        def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
             if isinstance(val, torch.Tensor):
                 assert _is_supported_device(val)
                 if len(stash) == 0:

@@ -136,19 +149,20 @@ class ShardedGradScaler(GradScaler):
                 # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
                 # format (fp16, bf16) and so the scaled loss should be of the same dtype.
                 return scaled_val.type(val.dtype)
-            elif isinstance(val, abc.Iterable):
+            if isinstance(val, abc.Iterable):
                 iterator = map(apply_scale, val)
                 if isinstance(val, (list, tuple)):
                     return type(val)(iterator)
-                else:
-                    return iterator
-            else:
-                raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+                return iterator
+            raise ValueError("outputs must be a Tensor or an iterable of Tensors")

-        return apply_scale(outputs)  # type: ignore[return-value]
+        return apply_scale(outputs)

     def _foreach_non_finite_check_and_unscale_cpu_(
-        self, grads: List, found_inf: torch.Tensor, inv_scale: torch.Tensor
+        self,
+        grads: Sequence[torch.Tensor],
+        found_inf: torch.Tensor,
+        inv_scale: torch.Tensor,
     ) -> None:
         if len(grads) == 0:
             return
@@ -288,28 +302,25 @@ class ShardedGradScaler(GradScaler):
         if future_handles:
             torch.futures.wait_all(future_handles)

-    def step(
-        self, optimizer: torch.optim.Optimizer, *args, **kwargs
-    ) -> Optional[float]:
-        return super().step(optimizer, *args, **kwargs)
-
-    def _amp_update_scale_cpu_(self, found_inf) -> None:
+    def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None:
         """
         If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
         Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
         """
+        assert self._scale is not None and self._growth_tracker is not None
+
         if found_inf.item() >= 1.0:
-            self._scale *= self._backoff_factor  # type: ignore[arg-type]
-            self._growth_tracker = 0
+            self._scale *= self._backoff_factor
+            self._growth_tracker.fill_(0)
         else:
-            successful = self._growth_tracker + 1  # type: ignore[operator]
-            if successful == self._growth_interval:  # type: ignore[arg-type]
-                self._scale *= self._growth_factor  # type: ignore[arg-type]
-                self._growth_tracker = 0
+            successful = self._growth_tracker + 1
+            if successful == self._growth_interval:
+                self._scale *= self._growth_factor
+                self._growth_tracker.fill_(0)
             else:
                 self._growth_tracker = successful

-    def update(self, new_scale: Optional[Union[float, FloatTensor]] = None) -> None:
+    def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
         """
         Updates the scale factor.
         If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
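The CPU path above implements the usual dynamic loss-scaling rule: back off immediately when an inf/nan is found, and grow only after `growth_interval` consecutive clean steps. A plain-Python sketch of that arithmetic, using the GradScaler defaults as constants:

```python
from typing import Tuple


def update_scale(
    scale: float,
    growth_tracker: int,
    found_inf: bool,
    growth_factor: float = 2.0,
    backoff_factor: float = 0.5,
    growth_interval: int = 2000,
) -> Tuple[float, int]:
    """Return the (scale, growth_tracker) pair after one update step."""
    if found_inf:
        return scale * backoff_factor, 0
    growth_tracker += 1
    if growth_tracker == growth_interval:
        return scale * growth_factor, 0
    return scale, growth_tracker


scale, tracker = 2.0**16, 0
scale, tracker = update_scale(scale, tracker, found_inf=True)   # 65536.0 -> 32768.0
scale, tracker = update_scale(scale, tracker, found_inf=False)  # unchanged, tracker == 1
print(scale, tracker)
```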