add new API torch.amp.is_autocast_available (#124938)

# Motivation
Expose `torch._C._is_autocast_available` as a public API, `torch.amp.is_autocast_available`.
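A minimal usage sketch of the new public API (device type and dtype are illustrative, assuming a PyTorch build that includes this change):

```python
import torch

# Query autocast support through the public API instead of the private
# torch._C._is_autocast_available binding.
if torch.amp.is_autocast_available("cpu"):
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        x = torch.randn(8, 8)
        y = x @ x  # matmul runs under the CPU autocast policy
```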

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124938
Approved by: https://github.com/albanD
Author: Yu, Guangye (2024-04-26 13:04:14 +00:00), committed by PyTorch MergeBot
parent a46c27d961
commit 19a83eacb5
5 changed files with 29 additions and 6 deletions

docs/source/amp.rst

@@ -40,6 +40,10 @@ For CUDA and CPU, APIs are also provided separately:
Autocasting
^^^^^^^^^^^
.. currentmodule:: torch.amp.autocast_mode
.. autofunction:: is_autocast_available
.. currentmodule:: torch
.. autoclass:: autocast

test/test_autocast.py

@@ -340,6 +340,8 @@ class TestTorchAutocast(TestCase):
        with self.assertRaisesRegex(RuntimeError, msg):
            with torch.autocast(device_type=dev):
                _ = torch.tensor(1)
        with self.assertRaisesRegex(RuntimeError, msg):
            assert torch.amp.is_autocast_available(device_type=dev)

if __name__ == "__main__":
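The added test checks that, for a device-type string PyTorch does not recognize, the new helper raises rather than returning False, mirroring `torch.autocast` itself. A small sketch of that behavior (the device string below is made up):

```python
import torch

# For an unrecognized device-type string, is_autocast_available raises
# RuntimeError (the same error torch.autocast would raise), instead of
# silently returning False.
try:
    torch.amp.is_autocast_available("not_a_real_device")
except RuntimeError as err:
    print(f"unsupported device type: {err}")
```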

torch/amp/__init__.py

@@ -1,2 +1,7 @@
from .autocast_mode import _enter_autocast, _exit_autocast, autocast
from .autocast_mode import (
    _enter_autocast,
    _exit_autocast,
    autocast,
    is_autocast_available,
)
from .grad_scaler import GradScaler

torch/amp/autocast_mode.py

@@ -6,7 +6,19 @@ from typing import Any, Optional
import torch
from torch.types import _dtype
__all__ = ["autocast_decorator", "autocast"]
__all__ = ["autocast_decorator", "autocast", "is_autocast_available"]
def is_autocast_available(device_type: str) -> bool:
    r"""
    Return a bool indicating if autocast is available on :attr:`device_type`.

    Args:
        device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and so on.
            The type is the same as the `type` attribute of a :class:`torch.device`.
            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
    """
    return torch._C._is_autocast_available(device_type)

def autocast_decorator(autocast_instance, func):
@@ -199,7 +211,7 @@ class autocast:
            assert dtype is not None
            return
        self.device = device_type
        if not torch._C._is_autocast_available(self.device):
        if not is_autocast_available(self.device):
            raise RuntimeError(
                f"User specified an unsupported autocast device_type '{self.device}'"
            )
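Since `autocast.__init__` now validates the device type through the same public helper, callers can probe support up front. A sketch, assuming the listed device-type strings are the ones of interest:

```python
import torch
from torch.amp import is_autocast_available

# Pick the first device type from a preference list that supports autocast.
# The strings below are examples; 'cuda' and 'cpu' are always recognized,
# other backends depend on the build.
for device_type in ("cuda", "xpu", "cpu"):
    try:
        if is_autocast_available(device_type):
            print(f"autocast is available for device_type='{device_type}'")
            break
    except RuntimeError:
        # Unknown device-type strings raise instead of returning False.
        continue
```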

torch/utils/checkpoint.py

@@ -194,7 +194,7 @@ def set_device_states(devices, states) -> None:
def _get_autocast_kwargs(device="cuda"):
    if torch._C._is_autocast_available(device):
    if torch.amp.is_autocast_available(device):
        device_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(device),
            "dtype": torch.get_autocast_dtype(device),
@@ -289,7 +289,7 @@ class CheckpointFunction(torch.autograd.Function):
            device_autocast_ctx = device_module.amp.autocast(
                **ctx.device_autocast_kwargs
            ) if torch._C._is_autocast_available(ctx.device) else contextlib.nullcontext()
            ) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
            with torch.enable_grad(), device_autocast_ctx, \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)
@@ -1396,7 +1396,7 @@ def _checkpoint_without_reentrant_generator(
            device_autocast_ctx = device_module.amp.autocast(
                **device_autocast_kwargs
            ) if torch._C._is_autocast_available(device) else contextlib.nullcontext()
            ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                 recompute_context:
                fn(*args, **kwargs)
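The checkpoint changes above all follow the same guard pattern: build an autocast context only when the device type supports it, otherwise fall back to a no-op context. A standalone sketch of that pattern (`make_autocast_ctx` is a hypothetical helper for illustration, not part of the checkpoint code):

```python
import contextlib

import torch

def make_autocast_ctx(device_type: str, **autocast_kwargs):
    # Only construct torch.amp.autocast when the backend supports autocast;
    # otherwise return a no-op context manager, as checkpoint.py does above.
    if torch.amp.is_autocast_available(device_type):
        return torch.amp.autocast(device_type=device_type, **autocast_kwargs)
    return contextlib.nullcontext()

with make_autocast_ctx("cpu", dtype=torch.bfloat16):
    out = torch.randn(4, 4) @ torch.randn(4, 4)
```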