diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 10701d0d8b2..d22d67cecff 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2178,7 +2178,10 @@ def get_device_tflops(dtype: torch.dtype) -> float: from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops - from torch.testing._internal.common_cuda import SM80OrLater + SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( + 8, + 0, + ) assert dtype in (torch.float16, torch.bfloat16, torch.float32)