Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
# Motivation
## For `torch.amp.GradScaler`
- `torch.cpu.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cpu", args...)`.
- `torch.cuda.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cuda", args...)`.

We therefore intend to deprecate them and **strongly recommend** that developers use `torch.amp.GradScaler` instead.
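The equivalence above follows a common deprecation-shim pattern: the device-specific class is a thin subclass that warns and then forwards its arguments to the device-agnostic base with the device fixed. The torch-free sketch below illustrates only that pattern; `DeviceAgnosticScaler` and `LegacyCpuScaler` are hypothetical stand-ins for `torch.amp.GradScaler` and `torch.cpu.amp.GradScaler`, and only a subset of the constructor arguments is modeled.

```python
import warnings


class DeviceAgnosticScaler:
    """Hypothetical stand-in for a device-agnostic class such as torch.amp.GradScaler."""

    def __init__(
        self,
        device: str = "cuda",
        init_scale: float = 2.0**16,
        growth_interval: int = 2000,
        enabled: bool = True,
    ) -> None:
        self.device = device
        self.init_scale = init_scale
        self.growth_interval = growth_interval
        self.enabled = enabled


class LegacyCpuScaler(DeviceAgnosticScaler):
    """Hypothetical deprecated alias: identical behavior with the device pinned to "cpu"."""

    def __init__(
        self,
        init_scale: float = 2.0**16,
        growth_interval: int = 2000,
        enabled: bool = True,
    ) -> None:
        # Emit the migration hint first; stacklevel=2 attributes it to the caller's line.
        warnings.warn(
            "LegacyCpuScaler(args...) is deprecated. "
            "Please use DeviceAgnosticScaler('cpu', args...) instead.",
            stacklevel=2,
        )
        # Forward every argument unchanged, fixing only the device.
        super().__init__(
            "cpu",
            init_scale=init_scale,
            growth_interval=growth_interval,
            enabled=enabled,
        )


# Usage: the old and new spellings produce equivalently configured objects.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old = LegacyCpuScaler(init_scale=1024.0)
new = DeviceAgnosticScaler("cpu", init_scale=1024.0)
print(old.device, old.init_scale == new.init_scale, len(caught))  # cpu True 1
```

The `stacklevel=2` argument is a design choice of this sketch; the shim in this commit calls `warnings.warn` with only a message, which defaults to `UserWarning` and `stacklevel=1`.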
## For `custom_fwd` and `custom_bwd`
These decorators are a good way to let a custom autograd function control whether autocast affects it, even inside an autocast-enabled region, and the same mechanism can be shared by other backends such as CPU and XPU.
We therefore generalize them to be device-agnostic, move them into `torch/amp/autocast_mode.py`, and re-expose them as `torch.amp.custom_fwd` and `torch.amp.custom_bwd`. Meanwhile, we deprecate `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd`.
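Generalizing the decorators mostly means replacing a hard-coded `"cuda"` with a `device_type` parameter, after which the old CUDA entry points can become thin wrappers that warn and pass `device_type="cuda"`. The torch-free sketch below shows only that structure, under stated assumptions: `AUTOCAST_ENABLED`, `custom_fwd_sketch`, and `cuda_custom_fwd_sketch` are hypothetical names, the per-device autocast state is a plain dict, and input casting and the matching backward decorator are omitted. In PyTorch itself the generalized decorators are `torch.amp.custom_fwd` and `torch.amp.custom_bwd`, which take a keyword-only `device_type`.

```python
import functools
import warnings
from types import SimpleNamespace
from typing import Callable, Dict, Optional

# Hypothetical per-device autocast flags; a real backend would query its runtime instead.
AUTOCAST_ENABLED: Dict[str, bool] = {"cpu": False, "cuda": False, "xpu": False}


def custom_fwd_sketch(fwd: Optional[Callable] = None, *, device_type: str) -> Callable:
    """Hypothetical device-agnostic forward decorator keyed by device_type."""
    if fwd is None:
        # Called as @custom_fwd_sketch(device_type=...): return the actual decorator.
        return functools.partial(custom_fwd_sketch, device_type=device_type)

    @functools.wraps(fwd)
    def wrapper(ctx, *args, **kwargs):
        # Record the forward's autocast state so a matching backward decorator could restore it.
        ctx.fwd_used_autocast = AUTOCAST_ENABLED[device_type]
        return fwd(ctx, *args, **kwargs)

    return wrapper


def cuda_custom_fwd_sketch(fwd: Optional[Callable] = None) -> Callable:
    """Hypothetical deprecated CUDA-only spelling that forwards with device_type="cuda"."""
    warnings.warn(
        "cuda_custom_fwd_sketch is deprecated. Use custom_fwd_sketch(device_type='cuda') instead.",
        stacklevel=2,
    )
    return custom_fwd_sketch(fwd, device_type="cuda")


# Usage: the same decorator serves any backend listed in AUTOCAST_ENABLED.
@custom_fwd_sketch(device_type="cpu")
def double_forward(ctx, x):
    return x * 2


ctx = SimpleNamespace()  # minimal stand-in for an autograd context object
print(double_forward(ctx, 3), ctx.fwd_used_autocast)  # 6 False
```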
# Additional Context
Add unit tests (UTs) covering the deprecation warnings.
No additional UTs are needed for the functionality of `torch.amp.custom_fwd`/`custom_bwd`: the existing UTs that covered `torch.cuda.amp.custom_fwd`/`custom_bwd` also cover them.
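A deprecation-warning test typically constructs the deprecated object inside an `assertWarns`-style context and matches the message. The sketch below uses the standard library's `unittest.TestCase.assertWarnsRegex`; `DeprecatedCpuScaler` is a hypothetical stand-in for the deprecated class so the test runs without PyTorch, and its warning text mirrors the `torch.cpu.amp.GradScaler` message. It is not the test added in this PR.

```python
import re
import unittest
import warnings


class DeprecatedCpuScaler:
    """Hypothetical stand-in for a deprecated class such as torch.cpu.amp.GradScaler."""

    def __init__(self, init_scale: float = 2.0**16) -> None:
        warnings.warn(
            "torch.cpu.amp.GradScaler(args...) is deprecated. "
            "Please use torch.amp.GradScaler('cpu', args...) instead."
        )
        self.device = "cpu"
        self.init_scale = init_scale


class TestDeprecationWarning(unittest.TestCase):
    def test_deprecated_scaler_warns(self) -> None:
        # Escape the literal message so "(", ")" and "." are not treated as regex syntax.
        pattern = re.escape("torch.cpu.amp.GradScaler(args...) is deprecated.")
        with self.assertWarnsRegex(UserWarning, pattern):
            scaler = DeprecatedCpuScaler(init_scale=1024.0)
        # Construction still succeeds and keeps the requested settings.
        self.assertEqual(scaler.device, "cpu")
        self.assertEqual(scaler.init_scale, 1024.0)
```

Run it with `python -m unittest <module_name>`; `UserWarning` is matched because `warnings.warn` uses that category when none is given.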
To ease review, we split these changes into two PRs. The first PR covers `torch.amp.GradScaler`; the follow-up covers `custom_fwd` and `custom_bwd`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126527
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/janeyx99, https://github.com/EikanWang
33 lines
898 B
Python
import warnings

import torch

__all__ = ["GradScaler"]


class GradScaler(torch.amp.GradScaler):
    r"""
    See :class:`torch.amp.GradScaler`.
    ``torch.cpu.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cpu", args...)`` instead.
    """

    def __init__(
        self,
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        enabled: bool = True,
    ) -> None:
        warnings.warn(
            "torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead."
        )
        super().__init__(
            "cpu",
            init_scale=init_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )