[cutlass backend] cache_clear algorithm select cache on fresh inductor cache (#147590)

Differential Revision: [D69959917](https://our.internmc.facebook.com/intern/diff/D69959917/)

AlgorithmSelectorCache is a cache. The expectation is that when we force-disable caches and clear the inductor caches, it would be cleared as well. However, that is not the case.

Why this is a problem can be seen by running the repro below, which produces:
```
SingleProcess AUTOTUNE benchmarking takes 6.2202 seconds and 46.0568 seconds precompiling for 36 choices
SingleProcess AUTOTUNE benchmarking takes 492.3141 seconds and 0.0010 seconds precompiling for 36 choices
```

The root cause is that precompilation is skipped on the second compile (its results live in AlgorithmSelectorCache, which is never cleared), while autotuning is not skipped, since its persistent cache is force-disabled. The second compile therefore pays the full autotuning cost again.

repro:
```
import logging
import os

os.environ["TORCH_LOGS"] = "+output_code,+benchmarking,+inductor"

import torch

import torch._inductor.config
from torch._inductor.utils import clear_inductor_caches

torch._inductor.config.max_autotune = True
torch._inductor.config.force_disable_caches = True
torch._inductor.config.autotune_num_choices_displayed = None
torch._inductor.config.max_autotune_gemm_backends = "CUTLASS"
torch._inductor.config.autotune_fallback_to_aten = False
torch._inductor.config.cuda.cutlass_instantiation_level = "0001"

def main():
    M, N, K = 2048, 2048, 2048
    dtype = torch.bfloat16
    A = torch.randn(M, K, device="cuda", dtype=dtype)
    B = torch.randn(K, N, device="cuda", dtype=dtype)
    for _ in range(2):
        torch._dynamo.reset()
        clear_inductor_caches()
        compiled_model = torch.compile(torch.mm, fullgraph=True)
        _ = compiled_model(A, B)

    print("done")

if __name__ == "__main__":
    main()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147590
Approved by: https://github.com/eellison, https://github.com/chenyang78
This commit is contained in:
henrylhtsang 2025-02-26 11:34:42 -08:00 committed by PyTorch MergeBot
parent 97ebccaa91
commit 84c89a4527


```diff
@@ -27,6 +27,7 @@ import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncComp
 from torch._dynamo.device_interface import get_interface_for_device
 from torch._dynamo.testing import rand_strided
 from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state
+from torch._inductor.utils import clear_on_fresh_inductor_cache
 from torch.utils._filelock import FileLock
 from torch.utils._ordered_set import OrderedSet
@@ -1640,6 +1641,12 @@ class AlgorithmSelectorCache(PersistentCache):
             ]
         ] = []
+        clear_on_fresh_inductor_cache(self)
+
+    def cache_clear(self) -> None:
+        self.precompile_cache.clear()
+        self.feedback_saver_fns.clear()
+
     def __call__(
         self,
         name,
```