mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
add new API torch.amp.is_autocast_available (#124938)
# Motivation expose `torch._is_autocast_available` to `torch.amp.is_autocast_available` as a public api. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124938 Approved by: https://github.com/albanD
This commit is contained in:
parent
a46c27d961
commit
19a83eacb5
|
|
@ -40,6 +40,10 @@ For CUDA and CPU, APIs are also provided separately:
|
||||||
|
|
||||||
Autocasting
|
Autocasting
|
||||||
^^^^^^^^^^^
|
^^^^^^^^^^^
|
||||||
|
.. currentmodule:: torch.amp.autocast_mode
|
||||||
|
|
||||||
|
.. autofunction:: is_autocast_available
|
||||||
|
|
||||||
.. currentmodule:: torch
|
.. currentmodule:: torch
|
||||||
|
|
||||||
.. autoclass:: autocast
|
.. autoclass:: autocast
|
||||||
|
|
|
||||||
|
|
@ -340,6 +340,8 @@ class TestTorchAutocast(TestCase):
|
||||||
with self.assertRaisesRegex(RuntimeError, msg):
|
with self.assertRaisesRegex(RuntimeError, msg):
|
||||||
with torch.autocast(device_type=dev):
|
with torch.autocast(device_type=dev):
|
||||||
_ = torch.tensor(1)
|
_ = torch.tensor(1)
|
||||||
|
with self.assertRaisesRegex(RuntimeError, msg):
|
||||||
|
assert torch.amp.is_autocast_available(device_type=dev)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,7 @@
|
||||||
from .autocast_mode import _enter_autocast, _exit_autocast, autocast
|
from .autocast_mode import (
|
||||||
|
_enter_autocast,
|
||||||
|
_exit_autocast,
|
||||||
|
autocast,
|
||||||
|
is_autocast_available,
|
||||||
|
)
|
||||||
from .grad_scaler import GradScaler
|
from .grad_scaler import GradScaler
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,19 @@ from typing import Any, Optional
|
||||||
import torch
|
import torch
|
||||||
from torch.types import _dtype
|
from torch.types import _dtype
|
||||||
|
|
||||||
__all__ = ["autocast_decorator", "autocast"]
|
__all__ = ["autocast_decorator", "autocast", "is_autocast_available"]
|
||||||
|
|
||||||
|
|
||||||
|
def is_autocast_available(device_type: str) -> bool:
|
||||||
|
r"""
|
||||||
|
Return a bool indicating if autocast is available on :attr:`device_type`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and so on.
|
||||||
|
The type is the same as the `type` attribute of a :class:`torch.device`.
|
||||||
|
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
|
||||||
|
"""
|
||||||
|
return torch._C._is_autocast_available(device_type)
|
||||||
|
|
||||||
|
|
||||||
def autocast_decorator(autocast_instance, func):
|
def autocast_decorator(autocast_instance, func):
|
||||||
|
|
@ -199,7 +211,7 @@ class autocast:
|
||||||
assert dtype is not None
|
assert dtype is not None
|
||||||
return
|
return
|
||||||
self.device = device_type
|
self.device = device_type
|
||||||
if not torch._C._is_autocast_available(self.device):
|
if not is_autocast_available(self.device):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"User specified an unsupported autocast device_type '{self.device}'"
|
f"User specified an unsupported autocast device_type '{self.device}'"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -194,7 +194,7 @@ def set_device_states(devices, states) -> None:
|
||||||
|
|
||||||
|
|
||||||
def _get_autocast_kwargs(device="cuda"):
|
def _get_autocast_kwargs(device="cuda"):
|
||||||
if torch._C._is_autocast_available(device):
|
if torch.amp.is_autocast_available(device):
|
||||||
device_autocast_kwargs = {
|
device_autocast_kwargs = {
|
||||||
"enabled": torch.is_autocast_enabled(device),
|
"enabled": torch.is_autocast_enabled(device),
|
||||||
"dtype": torch.get_autocast_dtype(device),
|
"dtype": torch.get_autocast_dtype(device),
|
||||||
|
|
@ -289,7 +289,7 @@ class CheckpointFunction(torch.autograd.Function):
|
||||||
|
|
||||||
device_autocast_ctx = device_module.amp.autocast(
|
device_autocast_ctx = device_module.amp.autocast(
|
||||||
**ctx.device_autocast_kwargs
|
**ctx.device_autocast_kwargs
|
||||||
) if torch._C._is_autocast_available(ctx.device) else contextlib.nullcontext()
|
) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
|
||||||
with torch.enable_grad(), device_autocast_ctx, \
|
with torch.enable_grad(), device_autocast_ctx, \
|
||||||
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
|
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
|
||||||
outputs = ctx.run_function(*detached_inputs)
|
outputs = ctx.run_function(*detached_inputs)
|
||||||
|
|
@ -1396,7 +1396,7 @@ def _checkpoint_without_reentrant_generator(
|
||||||
|
|
||||||
device_autocast_ctx = device_module.amp.autocast(
|
device_autocast_ctx = device_module.amp.autocast(
|
||||||
**device_autocast_kwargs
|
**device_autocast_kwargs
|
||||||
) if torch._C._is_autocast_available(device) else contextlib.nullcontext()
|
) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
|
||||||
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
|
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
|
||||||
recompute_context:
|
recompute_context:
|
||||||
fn(*args, **kwargs)
|
fn(*args, **kwargs)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user