diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index ceb6468da73..188c1998693 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -1123,14 +1123,14 @@ def get_device_tflops(dtype):
         # Triton API change in https://github.com/openai/triton/pull/2293
         from triton.testing import nvsmi
 
-        cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
+        sm_clock = nvsmi(["clocks.max.sm"])[0]
         if dtype in (torch.float16, torch.bfloat16):
-            return get_max_tensorcore_tflops(dtype, cur_sm_clock)
+            return get_max_tensorcore_tflops(dtype, sm_clock)
 
         if torch.backends.cuda.matmul.allow_tf32:
-            return get_max_tensorcore_tflops(torch.float32, cur_sm_clock)
+            return get_max_tensorcore_tflops(torch.float32, sm_clock)
         else:
-            return get_max_simd_tflops(torch.float32, cur_sm_clock)
+            return get_max_simd_tflops(torch.float32, sm_clock)
     else:
         if dtype in (torch.float16, torch.bfloat16):
             return get_max_tensorcore_tflops(dtype)
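
Below the hunk, a minimal sketch (not part of the change) of why switching the query from clocks.current.sm to clocks.max.sm matters for the TFLOPS estimate: peak throughput scales linearly with the SM clock, so reading the current clock on an idle, down-clocked GPU understates the achievable peak and makes get_device_tflops vary between runs. The SM count and tensor-core FLOPs-per-SM-per-cycle figures are illustrative assumptions (roughly an A100 with fp16 tensor cores), not values taken from the diff or from triton.testing.

# Illustrative sketch only: constants below are assumptions (A100-like), not from the diff.
def estimated_tensorcore_tflops(
    sm_clock_mhz: float,
    num_sms: int = 108,                  # assumed SM count (A100)
    flops_per_sm_per_cycle: int = 2048,  # assumed fp16 tensor-core FLOPs per SM per cycle
) -> float:
    """Peak TFLOPS ~= SMs * FLOPs/SM/cycle * clock (Hz) / 1e12."""
    return num_sms * flops_per_sm_per_cycle * sm_clock_mhz * 1e6 / 1e12

print(estimated_tensorcore_tflops(1410.0))  # max SM clock -> ~312 TFLOPS
print(estimated_tensorcore_tflops(210.0))   # idle "current" clock -> ~46 TFLOPS

With the hypothetical numbers above, a max SM clock of 1410 MHz gives roughly 312 TFLOPS for fp16, while an idle clock of around 210 MHz would report only about 46 TFLOPS, which is why the max clock is the more stable basis for the estimate.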