From 13413b3b07cc72fa9c2671b2535f7e54c1b30ca2 Mon Sep 17 00:00:00 2001
From: KarhouTam
Date: Tue, 28 Oct 2025 06:21:29 +0000
Subject: [PATCH] =?UTF-8?q?[AMP][Refactor]=20Autocast=20dtype=20handling?=
 =?UTF-8?q?=20to=20simplify=20device-specific=20c=E2=80=A6=20(#165221)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR refactors the autocast context manager in `autocast_mode.py` to
simplify and centralize the logic that checks the supported dtypes for
each device. The previous implementation repeated nearly identical
checks for every device type. Now a single `device_supported_dtypes`
variable holds the supported dtypes for the active device (custom
backends supply their own via `get_amp_supported_dtype()`), and the
validation logic is unified.

**The former PR #163446 was merged but reverted because `openreg`-related
CI tests failed.**

This PR additionally makes small adjustments to some test assertions so
that the CI tests pass. CI failed because the assertions matched the
exact text of the old error messages. For example:

```
 File "/var/lib/jenkins/workspace/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py", line 9, in test_autocast_with_unsupported_type
   with self.assertWarnsRegex(
AssertionError: "In openreg autocast, but the target dtype torch.float32 is not supported." does not match "In openreg autocast, but the target dtype is not supported. Disabling autocast."
```

Sorry for the inconvenience again.
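For reference, the unified check at the heart of this refactor boils down
to something like the standalone sketch below. It is illustrative only,
not the code in this patch: plain strings stand in for torch dtypes so it
runs without torch, and the helper `check_autocast_dtype` is a made-up
name.

```
import warnings

# Hypothetical stand-ins for torch dtypes, so the sketch runs without torch.
BF16, FP16, FP32 = "torch.bfloat16", "torch.float16", "torch.float32"


def check_autocast_dtype(device_type, fast_dtype, supported_dtypes):
    # One shared check replaces the per-device if/elif chain: every backend
    # supplies its supported dtypes, and the warning is built uniformly.
    if fast_dtype not in supported_dtypes:
        warnings.warn(
            f"In {device_type} autocast, but the target dtype is not supported. "
            "Disabling autocast.\n"
            f"{device_type} Autocast only supports dtypes of "
            + ", ".join(supported_dtypes)
            + " currently."
        )
        return False  # caller disables autocast
    return True


print(check_autocast_dtype("openreg", FP32, [FP16, BF16]))  # warns, False
print(check_autocast_dtype("cpu", BF16, [BF16, FP16]))      # True
```

The real implementation folds this check into `autocast.__init__` and
special-cases CUDA and MPS, as the diff below shows.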
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165221
Approved by: https://github.com/albanD
---
 .../torch_openreg/tests/test_autocast.py |   3 +-
 test/test_autocast.py                    |   2 +-
 torch/amp/autocast_mode.py               | 140 ++++++------------
 3 files changed, 47 insertions(+), 98 deletions(-)

diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py
index 01423741660..6474a349ab4 100644
--- a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py
@@ -8,7 +8,8 @@ class TestAutocast(TestCase):
     def test_autocast_with_unsupported_type(self):
         with self.assertWarnsRegex(
             UserWarning,
-            "In openreg autocast, but the target dtype torch.float32 is not supported.",
+            "In openreg autocast, but the target dtype is not supported. Disabling autocast.\n"
+            "openreg Autocast only supports dtypes of torch.float16, torch.bfloat16 currently.",
         ):
             with torch.autocast(device_type="openreg", dtype=torch.float32):
                 _ = torch.ones(10)
diff --git a/test/test_autocast.py b/test/test_autocast.py
index 19e05dd0a9d..8e057c363cf 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
     def test_mps_autocast_error_message(self):
         with self.assertWarnsRegex(
             UserWarning,
-            "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
+            "MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
         ):
             with torch.autocast(device_type="mps", dtype=torch.float32):
                 _ = torch.ones(10)
diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py
index 5b4666fcb28..cb4212af784 100644
--- a/torch/amp/autocast_mode.py
+++ b/torch/amp/autocast_mode.py
@@ -230,9 +230,9 @@ class autocast:
             raise ValueError(
                 f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
             )
-        if dtype is None:
-            dtype = torch.get_autocast_dtype(device_type)
-        self.fast_dtype = dtype
+        self.fast_dtype = (
+            torch.get_autocast_dtype(device_type) if dtype is None else dtype
+        )
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type
@@ -243,6 +243,9 @@ class autocast:
             raise RuntimeError(
                 f"User specified an unsupported autocast device_type '{self.device}'"
             )
+
+        device_supported_dtypes = [torch.bfloat16, torch.float16]
+
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
         if self.device == self.custom_backend_name:
             necessary_funcs = [
@@ -259,111 +262,56 @@ class autocast:
                 assert hasattr(self.custom_device_mod, func), (
                     message + f"But the func `{func}` is missing. \n"
                 )
+            device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()

-        self._cache_enabled = torch.is_autocast_cache_enabled()
-        if (
-            enabled
-            and self.device == "cuda"
-            and torch.cuda.amp.common.amp_definitely_not_available()
-        ):
-            warnings.warn(
-                "User provided device_type of 'cuda', but CUDA is not available. Disabling",
-                stacklevel=2,
-            )
-            enabled = False
-        if cache_enabled is not None:
-            self._cache_enabled = cache_enabled
+        self._cache_enabled = (
+            torch.is_autocast_cache_enabled()
+            if cache_enabled is None
+            else cache_enabled
+        )

-        if self.device == "cpu":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype and enabled:
-                error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "CPU Autocast only supports dtype of "
-                error_message += (
-                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
-                )
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "mtia":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "maia":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "xpu":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "ipu":
-            supported_dtypes = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtypes:
-                error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "hpu":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == self.custom_backend_name:
-            supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
-            if self.fast_dtype not in supported_dtype:
-                error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. "
-                error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
-                error_message += (
-                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
-                )
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "cuda":
-            if (
-                enabled
-                and self.fast_dtype == torch.bfloat16
-                and not torch.cuda.is_bf16_supported()
-            ):
-                raise RuntimeError(
-                    "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
-                )
-        elif self.device == "mps":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
+        device_name = (
+            self.device
+            if self.device == self.custom_backend_name
+            else self.device.upper()
+        )
+        if enabled:
+            # Special case for CUDA AMP and bfloat16 support
+            if self.device == "cuda":
+                if torch.cuda.amp.common.amp_definitely_not_available():
+                    warnings.warn(
+                        "CUDA is not available or torch_xla is imported. Disabling autocast.",
+                        stacklevel=2,
+                    )
+                    enabled = False
+                elif (
+                    self.fast_dtype == torch.bfloat16
+                    and not torch.cuda.is_bf16_supported()
+                ):
+                    raise RuntimeError(
+                        "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
+                    )
+            elif self.fast_dtype not in device_supported_dtypes:
                 error_message = (
-                    "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
-                    "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
+                    f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
+                    f"{device_name} Autocast only supports dtypes of "
+                    + ", ".join(map(str, device_supported_dtypes))
+                    + " currently."
                 )
                 warnings.warn(error_message, stacklevel=2)
                 enabled = False
-            elif self.fast_dtype == torch.bfloat16:
-                if not torch.backends.mps.is_macos_or_newer(14, 0):
+            # Special case for MPS bfloat16 support on macOS < 14
+            if (
+                self.device == "mps"
+                and self.fast_dtype == torch.bfloat16
+                and not torch.backends.mps.is_macos_or_newer(14, 0)
+            ):
                 error_message = (
                     "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
                     "on macOS versions below 14. Disabling autocast."
                 )
                 warnings.warn(error_message, stacklevel=2)
                 enabled = False
-        elif self.device == "xla":
-            supported_dtype = [torch.float16, torch.bfloat16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += (
-                    "XLA Autocast only supports dtype of torch.bfloat16 currently."
-                )
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
         self._enabled = enabled

     def __enter__(self):