diff --git a/docs/source/amp.rst b/docs/source/amp.rst
index d0ef865b370..a9d98a0aa09 100644
--- a/docs/source/amp.rst
+++ b/docs/source/amp.rst
@@ -40,6 +40,10 @@ For CUDA and CPU, APIs are also provided separately:
 Autocasting
 ^^^^^^^^^^^
 
+.. currentmodule:: torch.amp.autocast_mode
+
+.. autofunction:: is_autocast_available
+
 .. currentmodule:: torch
 
 .. autoclass:: autocast
diff --git a/test/test_autocast.py b/test/test_autocast.py
index 50549449328..5b1e38eff06 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -340,6 +340,8 @@ class TestTorchAutocast(TestCase):
         with self.assertRaisesRegex(RuntimeError, msg):
             with torch.autocast(device_type=dev):
                 _ = torch.tensor(1)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            assert torch.amp.is_autocast_available(device_type=dev)
 
 
 if __name__ == "__main__":
diff --git a/torch/amp/__init__.py b/torch/amp/__init__.py
index e0be6969755..2884dfeefe3 100644
--- a/torch/amp/__init__.py
+++ b/torch/amp/__init__.py
@@ -1,2 +1,7 @@
-from .autocast_mode import _enter_autocast, _exit_autocast, autocast
+from .autocast_mode import (
+    _enter_autocast,
+    _exit_autocast,
+    autocast,
+    is_autocast_available,
+)
 from .grad_scaler import GradScaler
diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py
index 87ff709fcfb..523d8dc34d5 100644
--- a/torch/amp/autocast_mode.py
+++ b/torch/amp/autocast_mode.py
@@ -6,7 +6,19 @@ from typing import Any, Optional
 import torch
 from torch.types import _dtype
 
-__all__ = ["autocast_decorator", "autocast"]
+__all__ = ["autocast_decorator", "autocast", "is_autocast_available"]
+
+
+def is_autocast_available(device_type: str) -> bool:
+    r"""
+    Return a bool indicating if autocast is available on :attr:`device_type`.
+
+    Args:
+        device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and so on.
+            The type is the same as the `type` attribute of a :class:`torch.device`.
+            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
+    """
+    return torch._C._is_autocast_available(device_type)
 
 
 def autocast_decorator(autocast_instance, func):
@@ -199,7 +211,7 @@ class autocast:
             assert dtype is not None
             return
         self.device = device_type
-        if not torch._C._is_autocast_available(self.device):
+        if not is_autocast_available(self.device):
             raise RuntimeError(
                 f"User specified an unsupported autocast device_type '{self.device}'"
             )
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index ca0e39d5379..7c74b4c6be7 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -194,7 +194,7 @@ def set_device_states(devices, states) -> None:
 
 
 def _get_autocast_kwargs(device="cuda"):
-    if torch._C._is_autocast_available(device):
+    if torch.amp.is_autocast_available(device):
         device_autocast_kwargs = {
             "enabled": torch.is_autocast_enabled(device),
             "dtype": torch.get_autocast_dtype(device),
@@ -289,7 +289,7 @@ class CheckpointFunction(torch.autograd.Function):
 
             device_autocast_ctx = device_module.amp.autocast(
                 **ctx.device_autocast_kwargs
-            ) if torch._C._is_autocast_available(ctx.device) else contextlib.nullcontext()
+            ) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
             with torch.enable_grad(), device_autocast_ctx, \
                  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                 outputs = ctx.run_function(*detached_inputs)
@@ -1396,7 +1396,7 @@ def _checkpoint_without_reentrant_generator(
 
             device_autocast_ctx = device_module.amp.autocast(
                 **device_autocast_kwargs
-            ) if torch._C._is_autocast_available(device) else contextlib.nullcontext()
+            ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
             with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                  recompute_context:
                 fn(*args, **kwargs)