[inductor] Use max sm clock when calculating device tflops (#116754)
See openai/triton#2801. Current SM clocks may fluctuate at runtime and change the result of `get_device_tflops`; querying the maximum SM clock instead keeps the estimate stable across calls.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116754
Approved by: https://github.com/lezcano
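For context, a minimal sketch of the two queries involved, assuming an NVIDIA GPU at device index 0 and a Triton version that still ships triton.testing.nvsmi (a thin wrapper around nvidia-smi --query-gpu):

# Sketch: the current SM clock drifts with load and thermals, while the
# max SM clock is a fixed hardware limit, so it gives a stable TFLOPS basis.
from triton.testing import nvsmi

cur_sm_clock = nvsmi(["clocks.current.sm"])[0]  # MHz, fluctuates between calls
max_sm_clock = nvsmi(["clocks.max.sm"])[0]      # MHz, constant for the device

print(f"current SM clock: {cur_sm_clock} MHz")
print(f"max SM clock:     {max_sm_clock} MHz")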
This commit is contained in:
parent 6793b99107
commit 39f8853313
@@ -1123,14 +1123,14 @@ def get_device_tflops(dtype):
         # Triton API change in https://github.com/openai/triton/pull/2293
         from triton.testing import nvsmi
 
-        cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
+        sm_clock = nvsmi(["clocks.max.sm"])[0]
         if dtype in (torch.float16, torch.bfloat16):
-            return get_max_tensorcore_tflops(dtype, cur_sm_clock)
+            return get_max_tensorcore_tflops(dtype, sm_clock)
 
         if torch.backends.cuda.matmul.allow_tf32:
-            return get_max_tensorcore_tflops(torch.float32, cur_sm_clock)
+            return get_max_tensorcore_tflops(torch.float32, sm_clock)
         else:
-            return get_max_simd_tflops(torch.float32, cur_sm_clock)
+            return get_max_simd_tflops(torch.float32, sm_clock)
     else:
         if dtype in (torch.float16, torch.bfloat16):
             return get_max_tensorcore_tflops(dtype)
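The clock enters the peak-throughput estimate linearly, so any jitter in the current clock shows up one-for-one in the reported TFLOPS; with clocks.max.sm the query returns the same value on every call and get_device_tflops becomes deterministic. A back-of-the-envelope sketch of that relationship (the SM count and per-SM throughput below are hypothetical illustration values, not Triton's internal tables):

# Illustrative only: shows how a sagging current clock would lower the estimate.
def approx_tensorcore_tflops(num_sms, flops_per_sm_per_cycle, sm_clock_mhz):
    # peak FLOPS = SMs * FLOPs/SM/cycle * cycles/second, scaled to TFLOPS
    return num_sms * flops_per_sm_per_cycle * sm_clock_mhz * 1e6 / 1e12

print(approx_tensorcore_tflops(108, 512, 1410))  # max clock: ~78 TFLOPS, stable
print(approx_tensorcore_tflops(108, 512, 1290))  # throttled clock: ~71 TFLOPS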