[Cutlass] Add test verifying number of precompiles (#147477)

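As the title says: adds `test_number_mm_precompiles` to `TestCutlassBackend`, and a `select_algorithm_num_precompiles` counter in `AlgorithmSelectorCache` that the test checks.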

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147477
Approved by: https://github.com/henrylhtsang
This commit is contained in:
Michael Lazos 2025-02-20 04:47:55 +00:00 committed by PyTorch MergeBot
parent 5f5b44f6bf
commit 7185ca8348
2 changed files with 52 additions and 0 deletions

View File

@ -242,6 +242,57 @@ class TestCutlassBackend(TestCase):
2,
).run(codes[0])
@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_number_mm_precompiles(self):
    """Check the precompile count for a CUTLASS-only autotuned matmul.

    A module computing ``a @ b`` is compiled with ``dynamic=True`` under a
    patched inductor config that restricts the GEMM backends to CUTLASS.
    The test then asserts that:

    * the compiled output matches the eager output,
    * the generated code contains ``cuda_fused_0.cuda_fused_0``, and
    * the ``select_algorithm_num_precompiles`` inductor counter equals 1.
    """
    from torch._inductor.utils import run_and_get_code

    # Start from a clean slate so the final counter check only reflects
    # work done by this test.
    counters = torch._dynamo.utils.counters
    counters.clear()

    class MyModel(torch.nn.Module):
        def forward(self, a, b, c):
            # ``c`` is accepted but does not take part in the computation.
            return a @ b

    def half_cuda(*shape):
        # Sample on the host, then move to the GPU and cast to fp16.
        return torch.randn(*shape).cuda().half()

    model = MyModel()
    inputs = (
        half_cuda(128, 16),
        half_cuda(16, 128),
        half_cuda(16, 512),
    )

    autotune_overrides = {
        "max_autotune": True,
        "autotune_in_subproc": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cuda.cutlass_max_profiling_configs": 1,
        "autotune_fallback_to_aten": False,
        # guarantees > 1 choices
        "cuda.cutlass_max_profiling_swizzle_options": [1, 2, 4],
    }

    with config.patch(autotune_overrides):
        compiled = torch.compile(model, dynamic=True)
        expected = model(*inputs)
        actual, codes = run_and_get_code(compiled, *inputs)
        torch.testing.assert_close(actual, expected)

        kernel_check = FileCheck().check_count("cuda_fused_0.cuda_fused_0", 1)
        kernel_check.run(codes[0])

        # Verifies expected number of precompilations
        num_precompiles = counters["inductor"]["select_algorithm_num_precompiles"]
        self.assertEqual(num_precompiles, 1)
# NOTE: right now tuned_mm doesn't support cutlass 2x, which is used by A100
@unittest.skipIf(not SM90OrLater, "need sm_90")
@parametrize("dynamic", (False, True))

View File

@ -1824,6 +1824,7 @@ class AlgorithmSelectorCache(PersistentCache):
"Exception %s for benchmark choice %s", e, futures[future]
)
else:
counters["inductor"]["select_algorithm_num_precompiles"] += 1
log.info(
"Precompiling benchmark choice %s took %.02fs",
futures[future],