Revert "[AMP][Refactor] Autocast dtype handling to simplify device-specific c… (#165221)"

This reverts commit 4be1e3bf92.

Reverted https://github.com/pytorch/pytorch/pull/165221 on behalf of https://github.com/clee2000 due to I think this broke test_openreg [GH job link](https://github.com/pytorch/pytorch/actions/runs/18698271058/job/53322459496) [HUD commit link](4be1e3bf92) note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/165221#issuecomment-3430012693))
PyTorch MergeBot 2025-10-22 00:26:55 +00:00
parent 7cb467a169
commit 7773a22cdb
3 changed files with 98 additions and 47 deletions

@@ -8,8 +8,7 @@ class TestAutocast(TestCase):
     def test_autocast_with_unsupported_type(self):
         with self.assertWarnsRegex(
             UserWarning,
-            "In openreg autocast, but the target dtype is not supported."
-            "openreg Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
+            "In openreg autocast, but the target dtype torch.float32 is not supported.",
         ):
             with torch.autocast(device_type="openreg", dtype=torch.float32):
                 _ = torch.ones(10)
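
Both test hunks pin the autocast warning text with assertWarnsRegex, which is why rewording the message surfaced as a test_openreg failure. A minimal, self-contained sketch of that pattern, using "cpu" as a stand-in for the openreg backend (the deliberately loose regex is my assumption, not the suite's exact pattern):

import unittest

import torch


class AutocastWarningSketch(unittest.TestCase):
    # Requesting an unsupported dtype does not raise: autocast warns and
    # disables itself, and the test asserts on the warning text.
    def test_unsupported_dtype_warns(self):
        with self.assertWarnsRegex(UserWarning, "target dtype.*not supported"):
            with torch.autocast(device_type="cpu", dtype=torch.float64):
                _ = torch.ones(10)


if __name__ == "__main__":
    unittest.main()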

@@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
     def test_mps_autocast_error_message(self):
         with self.assertWarnsRegex(
             UserWarning,
-            "MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
+            "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
         ):
             with torch.autocast(device_type="mps", dtype=torch.float32):
                 _ = torch.ones(10)
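
The restored MPS branch further down in autocast_mode.py accepts torch.float16 and torch.bfloat16, and additionally gates bfloat16 on macOS 14+. A small usage sketch of picking a dtype those checks will accept; pick_mps_amp_dtype is a hypothetical helper, not a PyTorch API:

import torch


def pick_mps_amp_dtype() -> torch.dtype:
    # bfloat16 passes the autocast check only on macOS 14 or newer;
    # float16 is the conservative fallback.
    if torch.backends.mps.is_macos_or_newer(14, 0):
        return torch.bfloat16
    return torch.float16


if torch.backends.mps.is_available():
    with torch.autocast(device_type="mps", dtype=pick_mps_amp_dtype()):
        a = torch.ones(4, 4, device="mps")
        print((a @ a).dtype)  # matmul should run in the selected reduced precision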

@@ -230,9 +230,9 @@ class autocast:
             raise ValueError(
                 f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
             )
-        self.fast_dtype = (
-            torch.get_autocast_dtype(device_type) if dtype is None else dtype
-        )
+        if dtype is None:
+            dtype = torch.get_autocast_dtype(device_type)
+        self.fast_dtype = dtype
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type
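
Both sides of this hunk resolve dtype=None through torch.get_autocast_dtype, which returns the per-device default. A quick sketch of querying those defaults; the printed values are the usual ones, not a guarantee across versions:

import torch

# Defaults consulted when torch.autocast(..., dtype=None).
for device_type in ("cpu", "cuda"):
    print(device_type, torch.get_autocast_dtype(device_type))
# Typically: cpu -> torch.bfloat16, cuda -> torch.float16
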
@@ -243,9 +243,6 @@ class autocast:
             raise RuntimeError(
                 f"User specified an unsupported autocast device_type '{self.device}'"
             )
-        device_supported_dtypes = [torch.bfloat16, torch.float16]
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
         if self.device == self.custom_backend_name:
             necessary_funcs = [
@@ -262,55 +259,110 @@ class autocast:
                 assert hasattr(self.custom_device_mod, func), (
                     message + f"But the func `{func}` is missing. \n"
                 )
-            device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()
-        self._cache_enabled = (
-            torch.is_autocast_cache_enabled()
-            if cache_enabled is None
-            else cache_enabled
-        )
+        self._cache_enabled = torch.is_autocast_cache_enabled()
+        if (
+            enabled
+            and self.device == "cuda"
+            and torch.cuda.amp.common.amp_definitely_not_available()
+        ):
+            warnings.warn(
+                "User provided device_type of 'cuda', but CUDA is not available. Disabling"
+            )
+            enabled = False
+        if cache_enabled is not None:
+            self._cache_enabled = cache_enabled
-        device_name = (
-            self.device
-            if self.device == self.custom_backend_name
-            else self.device.upper()
-        )
-        if enabled:
-            # Special case for CUDA AMP and bfloat16 support
-            if self.device == "cuda":
-                if torch.cuda.amp.common.amp_definitely_not_available():
-                    warnings.warn(
-                        "CUDA is not available or torch_xla is imported. Disabling autocast."
-                    )
-                    enabled = False
-                elif (
-                    self.fast_dtype == torch.bfloat16
-                    and not torch.cuda.is_bf16_supported()
-                ):
-                    raise RuntimeError(
-                        "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
-                    )
-            elif self.fast_dtype not in device_supported_dtypes:
-                error_message = (
-                    f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
-                    f"{device_name} Autocast only supports dtypes of "
-                    + ", ".join(map(str, device_supported_dtypes))
-                    + " currently."
-                )
-                warnings.warn(error_message)
-                enabled = False
+        if self.device == "cpu":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype and enabled:
+                error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "CPU Autocast only supports dtype of "
+                error_message += (
+                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
+                )
+                warnings.warn(error_message)
+                enabled = False
-            # Special case for MPS bfloat16 support on macOS < 14
-            if (
-                self.device == "mps"
-                and self.fast_dtype == torch.bfloat16
-                and not torch.backends.mps.is_macos_or_newer(14, 0)
-            ):
+        elif self.device == "mtia":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "maia":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "xpu":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "ipu":
+            supported_dtypes = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtypes:
+                error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "hpu":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == self.custom_backend_name:
+            supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
+            if self.fast_dtype not in supported_dtype:
+                error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. "
+                error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
+                error_message += (
+                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
+                )
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "cuda":
+            if (
+                enabled
+                and self.fast_dtype == torch.bfloat16
+                and not torch.cuda.is_bf16_supported()
+            ):
+                raise RuntimeError(
+                    "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
+                )
+        elif self.device == "mps":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = (
+                    "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
+                    "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
+                )
+                warnings.warn(error_message)
+                enabled = False
+            elif self.fast_dtype == torch.bfloat16:
+                if not torch.backends.mps.is_macos_or_newer(14, 0):
+                    error_message = (
+                        "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
+                        "on macOS versions below 14. Disabling autocast."
+                    )
+                    warnings.warn(error_message)
+                    enabled = False
+        elif self.device == "xla":
+            supported_dtype = [torch.float16, torch.bfloat16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += (
+                    "XLA Autocast only supports dtype of torch.bfloat16 currently."
+                )
+                warnings.warn(error_message)
+                enabled = False
         self._enabled = enabled

     def __enter__(self):
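
For third-party backends, the restored branch above delegates to the backend's device module: it must expose get_amp_supported_dtype(), and autocast warns and disables itself when the requested dtype is not in that list. A self-contained sketch of that contract; _FakeBackendModule and validate_amp_dtype are illustrative stand-ins, not PyTorch APIs:

import warnings

import torch


class _FakeBackendModule:
    # Stand-in for a PrivateUse1 backend's device module.
    @staticmethod
    def get_amp_supported_dtype() -> list[torch.dtype]:
        return [torch.bfloat16, torch.float16]


def validate_amp_dtype(mod, requested: torch.dtype) -> bool:
    # Mirrors the custom-backend check: warn and report "disabled"
    # instead of raising when the dtype is unsupported.
    supported = mod.get_amp_supported_dtype()
    if requested not in supported:
        warnings.warn(
            f"target dtype {requested} is not supported; disabling autocast. "
            "Supported: " + ", ".join(str(d) for d in supported)
        )
        return False
    return True


print(validate_amp_dtype(_FakeBackendModule, torch.float32))  # False, with a warning
print(validate_amp_dtype(_FakeBackendModule, torch.float16))  # True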