[AMP][Refactor] Autocast dtype handling to simplify device-specific c… (#165221)

This PR refactors the `autocast` context manager in `autocast_mode.py` to simplify and centralize the logic for checking the supported dtypes of each device. The previous implementation repeated nearly identical checks for each device type; now a single mapping, `device_supported_dtypes`, associates device types with their supported dtypes, and the validation logic is unified.
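Conceptually, the per-device branches collapse into a single shared check along these lines (a minimal sketch of the idea, not the exact code merged in this PR; the helper name and structure are illustrative):

```python
import warnings

import torch

# Default dtypes accepted by autocast on most built-in backends; a custom
# backend can substitute its own list (illustrative sketch, not the PR's code).
DEFAULT_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]


def validate_autocast_dtype(device_type, fast_dtype, supported_dtypes=None):
    """Return True if autocast should stay enabled for this device/dtype pair."""
    supported_dtypes = supported_dtypes or DEFAULT_SUPPORTED_DTYPES
    if fast_dtype in supported_dtypes:
        return True
    # Unified warning message shared by all devices.
    warnings.warn(
        f"In {device_type} autocast, but the target dtype is not supported. "
        "Disabling autocast.\n"
        f"{device_type} Autocast only supports dtypes of "
        + ", ".join(map(str, supported_dtypes))
        + " currently.",
        stacklevel=2,
    )
    return False
```

The device-specific special cases that remain (CUDA availability and bfloat16 support, and the MPS bfloat16-on-macOS<14 check) are kept outside this shared path, as the diff below shows.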

**The former PR #163446 was merged but then reverted due to failing CI on `openreg`-related tests.**

This PR additionally makes small adjustments to some test assertions so that the CI tests pass. CI failed because the assertions expected the exact previous error message. For example:
```
File "/var/lib/jenkins/workspace/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py", line 9, in test_autocast_with_unsupported_type
    with self.assertWarnsRegex(
        AssertionError: "In openreg autocast, but the target dtype torch.float32 is not supported." does not match "In openreg autocast, but the target dtype is not supported. Disabling autocast."
```
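For reference, `assertWarnsRegex` runs a regular-expression search against the warning message rather than an exact string comparison, so the test expectations only needed to be updated to match the new, unified wording. A minimal, self-contained illustration (not taken from the PR):

```python
import unittest
import warnings


class WarnsRegexDemo(unittest.TestCase):
    def test_pattern_is_searched_not_compared(self):
        # The pattern only has to occur somewhere in the warning message.
        with self.assertWarnsRegex(UserWarning, "target dtype is not supported"):
            warnings.warn(
                "In openreg autocast, but the target dtype is not supported. "
                "Disabling autocast.",
                UserWarning,
            )


if __name__ == "__main__":
    unittest.main()
```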

Sorry for the inconvenience again.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165221
Approved by: https://github.com/albanD
Authored by KarhouTam on 2025-10-28 06:21:29 +00:00; committed by PyTorch MergeBot
parent 5d0b3e28dc
commit 13413b3b07
3 changed files with 47 additions and 98 deletions


```diff
@@ -8,7 +8,8 @@ class TestAutocast(TestCase):
     def test_autocast_with_unsupported_type(self):
         with self.assertWarnsRegex(
             UserWarning,
-            "In openreg autocast, but the target dtype torch.float32 is not supported.",
+            "In openreg autocast, but the target dtype is not supported. Disabling autocast.\n"
+            "openreg Autocast only supports dtypes of torch.float16, torch.bfloat16 currently.",
         ):
             with torch.autocast(device_type="openreg", dtype=torch.float32):
                 _ = torch.ones(10)
```


```diff
@@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
     def test_mps_autocast_error_message(self):
         with self.assertWarnsRegex(
             UserWarning,
-            "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
+            "MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
         ):
             with torch.autocast(device_type="mps", dtype=torch.float32):
                 _ = torch.ones(10)
```


```diff
@@ -230,9 +230,9 @@ class autocast:
             raise ValueError(
                 f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
             )
-        if dtype is None:
-            dtype = torch.get_autocast_dtype(device_type)
-        self.fast_dtype = dtype
+        self.fast_dtype = (
+            torch.get_autocast_dtype(device_type) if dtype is None else dtype
+        )
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type
@@ -243,6 +243,9 @@ class autocast:
             raise RuntimeError(
                 f"User specified an unsupported autocast device_type '{self.device}'"
             )
+        device_supported_dtypes = [torch.bfloat16, torch.float16]
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
         if self.device == self.custom_backend_name:
             necessary_funcs = [
@@ -259,111 +262,56 @@ class autocast:
                 assert hasattr(self.custom_device_mod, func), (
                     message + f"But the func `{func}` is missing. \n"
                 )
+            device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()
-        self._cache_enabled = torch.is_autocast_cache_enabled()
-        if (
-            enabled
-            and self.device == "cuda"
-            and torch.cuda.amp.common.amp_definitely_not_available()
-        ):
-            warnings.warn(
-                "User provided device_type of 'cuda', but CUDA is not available. Disabling",
-                stacklevel=2,
-            )
-            enabled = False
-        if cache_enabled is not None:
-            self._cache_enabled = cache_enabled
-        if self.device == "cpu":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype and enabled:
-                error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "CPU Autocast only supports dtype of "
-                error_message += (
-                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
-                )
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "mtia":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "maia":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "xpu":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "ipu":
-            supported_dtypes = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtypes:
-                error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "hpu":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == self.custom_backend_name:
-            supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
-            if self.fast_dtype not in supported_dtype:
-                error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. "
-                error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
-                error_message += (
-                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
-                )
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "cuda":
-            if (
-                enabled
-                and self.fast_dtype == torch.bfloat16
-                and not torch.cuda.is_bf16_supported()
-            ):
-                raise RuntimeError(
-                    "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
-                )
-        elif self.device == "mps":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = (
-                    "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
-                    "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
-                )
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-            elif self.fast_dtype == torch.bfloat16:
-                if not torch.backends.mps.is_macos_or_newer(14, 0):
-                    error_message = (
-                        "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
-                        "on macOS versions below 14. Disabling autocast."
-                    )
-                    warnings.warn(error_message, stacklevel=2)
-                    enabled = False
-        elif self.device == "xla":
-            supported_dtype = [torch.float16, torch.bfloat16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += (
-                    "XLA Autocast only supports dtype of torch.bfloat16 currently."
-                )
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
+        self._cache_enabled = (
+            torch.is_autocast_cache_enabled()
+            if cache_enabled is None
+            else cache_enabled
+        )
+        device_name = (
+            self.device
+            if self.device == self.custom_backend_name
+            else self.device.upper()
+        )
+        if enabled:
+            # Special case for CUDA AMP and bfloat16 support
+            if self.device == "cuda":
+                if torch.cuda.amp.common.amp_definitely_not_available():
+                    warnings.warn(
+                        "CUDA is not available or torch_xla is imported. Disabling autocast.",
+                        stacklevel=2,
+                    )
+                    enabled = False
+                elif (
+                    self.fast_dtype == torch.bfloat16
+                    and not torch.cuda.is_bf16_supported()
+                ):
+                    raise RuntimeError(
+                        "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
+                    )
+            elif self.fast_dtype not in device_supported_dtypes:
+                error_message = (
+                    f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
+                    f"{device_name} Autocast only supports dtypes of "
+                    + ", ".join(map(str, device_supported_dtypes))
+                    + " currently."
+                )
+                warnings.warn(error_message, stacklevel=2)
+                enabled = False
+            # Special case for MPS bfloat16 support on macOS < 14
+            if (
+                self.device == "mps"
+                and self.fast_dtype == torch.bfloat16
+                and not torch.backends.mps.is_macos_or_newer(14, 0)
+            ):
+                error_message = (
+                    "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
+                    "on macOS versions below 14. Disabling autocast."
+                )
+                warnings.warn(error_message, stacklevel=2)
+                enabled = False
         self._enabled = enabled

     def __enter__(self):
```
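As a quick sanity check of the new behavior (a hypothetical snippet, not part of the PR), requesting an unsupported dtype on a non-CUDA device now goes through the shared code path, emits the unified warning, and silently disables autocast instead of raising:

```python
import torch

# The context manager warns at construction time with the unified message
# ("In CPU autocast, but the target dtype is not supported. Disabling autocast. ...")
# and then runs the enclosed ops without autocasting.
with torch.autocast(device_type="cpu", dtype=torch.float64):
    y = torch.mm(torch.randn(2, 2), torch.randn(2, 2))

print(y.dtype)  # torch.float32 -- autocast was disabled, so mm ran in the input dtype
```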