[AMP][Refactor] Autocast dtype handling to simplify device-specific c… (#165221)
This PR refactors the autocast context manager in autocast_mode.py to simplify and centralize the logic for checking supported dtypes for each device. The previous implementation repeated similar checks for multiple device types. Now a single mapping, device_supported_dtypes, associates device types with their supported dtypes, and the validation logic is unified.

**The former PR #163446 was merged but reverted due to failed CI tests on `openreg`-related tests.** This PR additionally makes small changes to some test assertions so that the CI tests pass. CI failed because the assertions expected the exact previous error message. For example:

```
  File "/var/lib/jenkins/workspace/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py", line 9, in test_autocast_with_unsupported_type
    with self.assertWarnsRegex(
AssertionError: "In openreg autocast, but the target dtype torch.float32 is not supported." does not match "In openreg autocast, but the target dtype is not supported. Disabling autocast."
```

Sorry for the inconvenience again.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165221
Approved by: https://github.com/albanD
This commit is contained in:
parent 5d0b3e28dc
commit 13413b3b07
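As a rough orientation before reading the diff, here is a minimal sketch of the centralized-mapping idea described above. It is a simplified illustration only: the `_SUPPORTED_DTYPES` table and `_validate_fast_dtype` helper are hypothetical names introduced for this sketch, and the actual implementation in `autocast_mode.py` (shown in the hunks below) keeps the check inline in `autocast.__init__`.

```python
import warnings

import torch

# Hypothetical sketch: a single table of device type -> supported autocast dtypes
# replaces the repeated per-device if/elif blocks that all did the same check.
_SUPPORTED_DTYPES = {
    "cpu": [torch.bfloat16, torch.float16],
    "xpu": [torch.bfloat16, torch.float16],
    "mps": [torch.bfloat16, torch.float16],
}


def _validate_fast_dtype(device_type: str, fast_dtype: torch.dtype) -> bool:
    """Return True if autocast can stay enabled, otherwise warn and return False."""
    supported = _SUPPORTED_DTYPES.get(device_type, [torch.bfloat16, torch.float16])
    if fast_dtype in supported:
        return True
    warnings.warn(
        f"In {device_type.upper()} autocast, but the target dtype is not supported. "
        "Disabling autocast.\n"
        f"{device_type.upper()} Autocast only supports dtypes of "
        + ", ".join(map(str, supported))
        + " currently.",
        stacklevel=2,
    )
    return False


# Example: an unsupported dtype produces one unified warning instead of going
# through a device-specific branch.
_validate_fast_dtype("cpu", torch.float64)  # warns and returns False
```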
@@ -8,7 +8,8 @@ class TestAutocast(TestCase):
     def test_autocast_with_unsupported_type(self):
         with self.assertWarnsRegex(
             UserWarning,
-            "In openreg autocast, but the target dtype torch.float32 is not supported.",
+            "In openreg autocast, but the target dtype is not supported. Disabling autocast.\n"
+            "openreg Autocast only supports dtypes of torch.float16, torch.bfloat16 currently.",
         ):
             with torch.autocast(device_type="openreg", dtype=torch.float32):
                 _ = torch.ones(10)
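The assertion update above is the test-side change the commit message refers to: `unittest.TestCase.assertWarnsRegex` treats its second argument as a regular expression that must be found in the emitted warning text, so once the refactor switched to the unified wording, the old exact-message pattern stopped matching. A minimal, self-contained illustration (the warning text here is invented for the example):

```python
import unittest
import warnings


class AssertWarnsRegexDemo(unittest.TestCase):
    def test_pattern_tracks_message_wording(self):
        # The pattern is searched (re.search) in the warning message, so it
        # must be kept in sync with whatever wording the code actually emits.
        with self.assertWarnsRegex(UserWarning, "target dtype is not supported"):
            warnings.warn(
                "In demo autocast, but the target dtype is not supported. "
                "Disabling autocast.",
                UserWarning,
            )


if __name__ == "__main__":
    unittest.main()
```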
@@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
     def test_mps_autocast_error_message(self):
         with self.assertWarnsRegex(
             UserWarning,
-            "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
+            "MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
         ):
             with torch.autocast(device_type="mps", dtype=torch.float32):
                 _ = torch.ones(10)
@@ -230,9 +230,9 @@ class autocast:
             raise ValueError(
                 f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
             )
-        if dtype is None:
-            dtype = torch.get_autocast_dtype(device_type)
-        self.fast_dtype = dtype
+        self.fast_dtype = (
+            torch.get_autocast_dtype(device_type) if dtype is None else dtype
+        )
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type
@@ -243,6 +243,9 @@ class autocast:
             raise RuntimeError(
                 f"User specified an unsupported autocast device_type '{self.device}'"
             )
+
+        device_supported_dtypes = [torch.bfloat16, torch.float16]
+
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
         if self.device == self.custom_backend_name:
             necessary_funcs = [
@@ -259,111 +262,56 @@ class autocast:
                 assert hasattr(self.custom_device_mod, func), (
                     message + f"But the func `{func}` is missing. \n"
                 )
-
-        self._cache_enabled = torch.is_autocast_cache_enabled()
-        if (
-            enabled
-            and self.device == "cuda"
-            and torch.cuda.amp.common.amp_definitely_not_available()
-        ):
-            warnings.warn(
-                "User provided device_type of 'cuda', but CUDA is not available. Disabling",
-                stacklevel=2,
-            )
-            enabled = False
-        if cache_enabled is not None:
-            self._cache_enabled = cache_enabled
-
-        if self.device == "cpu":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype and enabled:
-                error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "CPU Autocast only supports dtype of "
-                error_message += (
-                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
-                )
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "mtia":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "maia":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "xpu":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "ipu":
-            supported_dtypes = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtypes:
-                error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "hpu":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == self.custom_backend_name:
-            supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
-            if self.fast_dtype not in supported_dtype:
-                error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. "
-                error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
-                error_message += (
-                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
-                )
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-        elif self.device == "cuda":
-            if (
-                enabled
-                and self.fast_dtype == torch.bfloat16
-                and not torch.cuda.is_bf16_supported()
-            ):
-                raise RuntimeError(
-                    "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
-                )
-        elif self.device == "mps":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = (
-                    "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
-                    "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
-                )
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
-            elif self.fast_dtype == torch.bfloat16:
-                if not torch.backends.mps.is_macos_or_newer(14, 0):
-                    error_message = (
-                        "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
-                        "on macOS versions below 14. Disabling autocast."
-                    )
-                    warnings.warn(error_message, stacklevel=2)
-                    enabled = False
-        elif self.device == "xla":
-            supported_dtype = [torch.float16, torch.bfloat16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += (
-                    "XLA Autocast only supports dtype of torch.bfloat16 currently."
-                )
-                warnings.warn(error_message, stacklevel=2)
-                enabled = False
+            device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()
+
+        self._cache_enabled = (
+            torch.is_autocast_cache_enabled()
+            if cache_enabled is None
+            else cache_enabled
+        )
+
+        device_name = (
+            self.device
+            if self.device == self.custom_backend_name
+            else self.device.upper()
+        )
+        if enabled:
+            # Special case for CUDA AMP and bfloat16 support
+            if self.device == "cuda":
+                if torch.cuda.amp.common.amp_definitely_not_available():
+                    warnings.warn(
+                        "CUDA is not available or torch_xla is imported. Disabling autocast.",
+                        stacklevel=2,
+                    )
+                    enabled = False
+                elif (
+                    self.fast_dtype == torch.bfloat16
+                    and not torch.cuda.is_bf16_supported()
+                ):
+                    raise RuntimeError(
+                        "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
+                    )
+            elif self.fast_dtype not in device_supported_dtypes:
+                error_message = (
+                    f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
+                    f"{device_name} Autocast only supports dtypes of "
+                    + ", ".join(map(str, device_supported_dtypes))
+                    + " currently."
+                )
+                warnings.warn(error_message, stacklevel=2)
+                enabled = False
+            # Special case for MPS bfloat16 support on macOS < 14
+            if (
+                self.device == "mps"
+                and self.fast_dtype == torch.bfloat16
+                and not torch.backends.mps.is_macos_or_newer(14, 0)
+            ):
+                error_message = (
+                    "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
+                    "on macOS versions below 14. Disabling autocast."
+                )
+                warnings.warn(error_message, stacklevel=2)
+                enabled = False
         self._enabled = enabled

     def __enter__(self):
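Assuming the refactored `__init__` shown above, an unsupported dtype on any non-CUDA backend now goes through the single unified branch. A quick way to observe this on CPU (a usage sketch only; the exact warning wording should be read from the diff rather than from this snippet):

```python
import warnings

import torch

# Requesting a dtype outside [torch.bfloat16, torch.float16] should hit the
# unified check: a single UserWarning is emitted and autocast is disabled.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.autocast(device_type="cpu", dtype=torch.float64):
        _ = torch.ones(4) + torch.ones(4)

for w in caught:
    print(w.category.__name__, ":", str(w.message))
# Expected: a UserWarning stating that the target dtype is not supported
# and that autocast is being disabled for CPU.
```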