pytorch/torch/cpu/amp/grad_scaler.py
2024-05-31 19:47:24 +00:00

33 lines
898 B
Python

import warnings
import torch
__all__ = ["GradScaler"]
class GradScaler(torch.amp.GradScaler):
r"""
See :class:`torch.amp.GradScaler`.
``torch.cpu.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cpu", args...)`` instead.
"""
def __init__(
self,
init_scale: float = 2.0**16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
enabled: bool = True,
) -> None:
warnings.warn(
"torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead."
)
super().__init__(
"cpu",
init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
enabled=enabled,
)