[cutlass backend] cache_clear algorithm select cache on fresh inductor cache (#147590)
Differential Revision: [D69959917](https://our.internmc.facebook.com/intern/diff/D69959917/)

AlgorithmSelectorCache is a cache. The expectation is that when we force-disable caches and clear the inductor caches, it would be cleared as well. However, that is not the case. The problem can be seen by running the repro below, which produces:

```
SingleProcess AUTOTUNE benchmarking takes 6.2202 seconds and 46.0568 seconds precompiling for 36 choices
SingleProcess AUTOTUNE benchmarking takes 492.3141 seconds and 0.0010 seconds precompiling for 36 choices
```

The root cause: on the second iteration, precompilation is skipped because its results are still held in this cache, but autotuning is not skipped since caches are force-disabled, so it reruns from scratch.

Repro:

```python
import logging
import os

os.environ["TORCH_LOGS"] = "+output_code,+benchmarking,+inductor"

import torch
import torch._inductor.config
from torch._inductor.utils import clear_inductor_caches

torch._inductor.config.max_autotune = True
torch._inductor.config.force_disable_caches = True
torch._inductor.config.autotune_num_choices_displayed = None
torch._inductor.config.max_autotune_gemm_backends = "CUTLASS"
torch._inductor.config.autotune_fallback_to_aten = False
torch._inductor.config.cuda.cutlass_instantiation_level = "0001"


def main():
    M, N, K = 2048, 2048, 2048
    dtype = torch.bfloat16
    A = torch.randn(M, K, device="cuda", dtype=dtype)
    B = torch.randn(K, N, device="cuda", dtype=dtype)

    for _ in range(2):
        torch._dynamo.reset()
        clear_inductor_caches()
        compiled_model = torch.compile(torch.mm, fullgraph=True)
        _ = compiled_model(A, B)

    print("done")


if __name__ == "__main__":
    main()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147590
Approved by: https://github.com/eellison, https://github.com/chenyang78
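For context, a minimal sketch of the registration pattern the fix relies on, assuming a module-level registry like the one in `torch/_inductor/utils.py` (the registry name and exact signatures here are illustrative, not the actual internals):

```python
from typing import Any

# Hypothetical registry mirroring the pattern in torch/_inductor/utils.py;
# the real module keeps its own internal list of registered caches.
_registered_caches: list[Any] = []


def clear_on_fresh_inductor_cache(obj: Any) -> Any:
    """Register an object whose cache_clear() should run whenever
    the inductor caches are reset."""
    if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
        raise AttributeError(f"{obj!r} does not define a callable cache_clear")
    _registered_caches.append(obj)
    return obj


def clear_inductor_caches() -> None:
    """Invoke cache_clear() on every registered cache."""
    for obj in _registered_caches:
        obj.cache_clear()
```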
This commit is contained in:
parent
97ebccaa91
commit
84c89a4527
```diff
@@ -27,6 +27,7 @@ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncComp
 from torch._dynamo.device_interface import get_interface_for_device
 from torch._dynamo.testing import rand_strided
 from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state
+from torch._inductor.utils import clear_on_fresh_inductor_cache
 from torch.utils._filelock import FileLock
 from torch.utils._ordered_set import OrderedSet
 
@@ -1640,6 +1641,12 @@ class AlgorithmSelectorCache(PersistentCache):
             ]
         ] = []
 
+        clear_on_fresh_inductor_cache(self)
+
+    def cache_clear(self) -> None:
+        self.precompile_cache.clear()
+        self.feedback_saver_fns.clear()
+
     def __call__(
         self,
         name,
```
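To illustrate the effect of the change, continuing the sketch above (`FakeAlgorithmSelectorCache` is a stand-in for illustration, not the real class):

```python
from typing import Any

# Stand-in for AlgorithmSelectorCache, showing how the added cache_clear()
# plus the registration call make the cache participate in cache clearing.
class FakeAlgorithmSelectorCache:
    def __init__(self) -> None:
        self.precompile_cache: dict[str, Any] = {}
        self.feedback_saver_fns: list[Any] = []
        clear_on_fresh_inductor_cache(self)  # register, as the diff now does

    def cache_clear(self) -> None:
        self.precompile_cache.clear()
        self.feedback_saver_fns.clear()


cache = FakeAlgorithmSelectorCache()
cache.precompile_cache["mm_2048x2048x2048"] = object()
clear_inductor_caches()
assert not cache.precompile_cache  # precompilation reruns on the next compile
```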