Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00.
This reverts commit 749a132fb0.
Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith because switching typing-extensions from 4.3.0 to 4.9.0 caused an internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456)).
33 lines, 898 B, Python
import warnings
|
|
|
|
import torch
|
|
|
|
# Names exported by a star-import of this module: only the deprecated
# CPU ``GradScaler`` shim defined in this file is public.
__all__ = ["GradScaler"]
class GradScaler(torch.amp.GradScaler):
|
|
r"""
|
|
See :class:`torch.amp.GradScaler`.
|
|
``torch.cpu.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cpu", args...)`` instead.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
init_scale: float = 2.0**16,
|
|
growth_factor: float = 2.0,
|
|
backoff_factor: float = 0.5,
|
|
growth_interval: int = 2000,
|
|
enabled: bool = True,
|
|
) -> None:
|
|
warnings.warn(
|
|
"torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead."
|
|
)
|
|
super().__init__(
|
|
"cpu",
|
|
init_scale=init_scale,
|
|
growth_factor=growth_factor,
|
|
backoff_factor=backoff_factor,
|
|
growth_interval=growth_interval,
|
|
enabled=enabled,
|
|
)
|