mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This PR refactors the autocast context manager in `autocast_mode.py` to simplify and centralize the logic for checking the supported dtypes of each device. The previous implementation repeated similar checks for multiple device types. Now a single mapping, `device_supported_dtypes`, associates device types with their supported dtypes, and the validation logic is unified.

**The former PR #163446 was merged but reverted because of CI failures in `openreg`-related tests.** This PR additionally adjusts some test assertions so that the CI tests pass. CI failed because the tests asserted on the exact text of the error message. For example:

```
  File "/var/lib/jenkins/workspace/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_autocast.py", line 9, in test_autocast_with_unsupported_type
    with self.assertWarnsRegex(
AssertionError: "In openreg autocast, but the target dtype torch.float32 is not supported." does not match "In openreg autocast, but the target dtype is not supported. Disabling autocast."
```

Sorry for the inconvenience again.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165221
Approved by: https://github.com/albanD
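The refactor described in the commit message can be pictured with a minimal sketch. Only the name `device_supported_dtypes` comes from the commit message. The helper `resolve_autocast_enabled`, the dtype entries in the mapping, and the use of plain strings instead of `torch.dtype` objects are assumptions made to keep the example self-contained, and the real code in `autocast_mode.py` may differ.

```python
import warnings

# Illustrative mapping only: the real table lives in autocast_mode.py and
# holds torch.dtype objects. The entries below are assumptions for this sketch.
device_supported_dtypes = {
    "cpu": ("bfloat16", "float16"),
    "cuda": ("bfloat16", "float16"),
    "openreg": ("float16",),
}


def resolve_autocast_enabled(device_type, dtype, enabled=True):
    """Return whether autocast stays enabled for this device/dtype pair.

    Hypothetical helper: a single lookup replaces a per-device chain of checks.
    """
    supported = device_supported_dtypes.get(device_type)
    if supported is None:
        raise RuntimeError(f"unsupported autocast device_type {device_type!r}")
    if enabled and dtype not in supported:
        # One shared message template for every device type.
        warnings.warn(
            f"In {device_type} autocast, but the target dtype {dtype} is not "
            f"supported. Disabling autocast. "
            f"Supported dtypes: {', '.join(supported)}"
        )
        return False
    return enabled
```

Because every device now shares one message template, a test that compares the whole message breaks whenever the template changes. Matching a stable fragment, as `assertWarnsRegex` allows, is less brittle.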
| | | |
|---|---|---|
| .. | ||
| libtorch_agnostic_extension | ||
| no_python_abi_suffix_test | ||
| open_registration_extension/torch_openreg | ||
| python_agnostic_extension | ||
| self_compiler_include_dirs_test | ||
| torch_stable_test_extension | ||
| torch_test_cpp_extension | ||
| cpp_c10d_extension.cpp | ||
| cpp_c10d_extension.hpp | ||
| cpp_frontend_extension.cpp | ||
| cublas_extension.cpp | ||
| cuda_dlink_extension_add.cu | ||
| cuda_dlink_extension_add.cuh | ||
| cuda_dlink_extension_kernel.cu | ||
| cuda_dlink_extension.cpp | ||
| cuda_extension_kernel.cu | ||
| cuda_extension_kernel2.cu | ||
| cuda_extension.cpp | ||
| cuda_extension.cu | ||
| cudnn_extension.cpp | ||
| cusolver_extension.cpp | ||
| dangling_impl_extension.cpp | ||
| doubler.h | ||
| extension.cpp | ||
| identity.cpp | ||
| jit_extension.cpp | ||
| jit_extension2.cpp | ||
| maia_extension.cpp | ||
| mps_extension.mm | ||
| mtia_extension.cpp | ||
| open_registration_extension.cpp | ||
| rng_extension.cpp | ||
| setup.py | ||
| torch_library.cu | ||
| xpu_extension.sycl | ||