mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
add new API torch.amp.is_autocast_available (#124938)
# Motivation expose `torch._is_autocast_available` to `torch.amp.is_autocast_available` as a public api. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124938 Approved by: https://github.com/albanD
This commit is contained in:
parent
a46c27d961
commit
19a83eacb5
|
|
@ -40,6 +40,10 @@ For CUDA and CPU, APIs are also provided separately:
|
|||
|
||||
Autocasting
|
||||
^^^^^^^^^^^
|
||||
.. currentmodule:: torch.amp.autocast_mode
|
||||
|
||||
.. autofunction:: is_autocast_available
|
||||
|
||||
.. currentmodule:: torch
|
||||
|
||||
.. autoclass:: autocast
|
||||
|
|
|
|||
|
|
@ -340,6 +340,8 @@ class TestTorchAutocast(TestCase):
|
|||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
with torch.autocast(device_type=dev):
|
||||
_ = torch.tensor(1)
|
||||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
assert torch.amp.is_autocast_available(device_type=dev)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,2 +1,7 @@
|
|||
from .autocast_mode import _enter_autocast, _exit_autocast, autocast
|
||||
from .autocast_mode import (
|
||||
_enter_autocast,
|
||||
_exit_autocast,
|
||||
autocast,
|
||||
is_autocast_available,
|
||||
)
|
||||
from .grad_scaler import GradScaler
|
||||
|
|
|
|||
|
|
@ -6,7 +6,19 @@ from typing import Any, Optional
|
|||
import torch
|
||||
from torch.types import _dtype
|
||||
|
||||
__all__ = ["autocast_decorator", "autocast"]
|
||||
__all__ = ["autocast_decorator", "autocast", "is_autocast_available"]
|
||||
|
||||
|
||||
def is_autocast_available(device_type: str) -> bool:
|
||||
r"""
|
||||
Return a bool indicating if autocast is available on :attr:`device_type`.
|
||||
|
||||
Args:
|
||||
device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and so on.
|
||||
The type is the same as the `type` attribute of a :class:`torch.device`.
|
||||
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
|
||||
"""
|
||||
return torch._C._is_autocast_available(device_type)
|
||||
|
||||
|
||||
def autocast_decorator(autocast_instance, func):
|
||||
|
|
@ -199,7 +211,7 @@ class autocast:
|
|||
assert dtype is not None
|
||||
return
|
||||
self.device = device_type
|
||||
if not torch._C._is_autocast_available(self.device):
|
||||
if not is_autocast_available(self.device):
|
||||
raise RuntimeError(
|
||||
f"User specified an unsupported autocast device_type '{self.device}'"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ def set_device_states(devices, states) -> None:
|
|||
|
||||
|
||||
def _get_autocast_kwargs(device="cuda"):
|
||||
if torch._C._is_autocast_available(device):
|
||||
if torch.amp.is_autocast_available(device):
|
||||
device_autocast_kwargs = {
|
||||
"enabled": torch.is_autocast_enabled(device),
|
||||
"dtype": torch.get_autocast_dtype(device),
|
||||
|
|
@ -289,7 +289,7 @@ class CheckpointFunction(torch.autograd.Function):
|
|||
|
||||
device_autocast_ctx = device_module.amp.autocast(
|
||||
**ctx.device_autocast_kwargs
|
||||
) if torch._C._is_autocast_available(ctx.device) else contextlib.nullcontext()
|
||||
) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
|
||||
with torch.enable_grad(), device_autocast_ctx, \
|
||||
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
|
||||
outputs = ctx.run_function(*detached_inputs)
|
||||
|
|
@ -1396,7 +1396,7 @@ def _checkpoint_without_reentrant_generator(
|
|||
|
||||
device_autocast_ctx = device_module.amp.autocast(
|
||||
**device_autocast_kwargs
|
||||
) if torch._C._is_autocast_available(device) else contextlib.nullcontext()
|
||||
) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
|
||||
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
|
||||
recompute_context:
|
||||
fn(*args, **kwargs)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user