mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Cutlass] Add test verifying number of precompiles (#147477)
As title.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147477
Approved by: https://github.com/henrylhtsang
This commit is contained in:
parent
5f5b44f6bf
commit
7185ca8348
|
|
@ -242,6 +242,57 @@ class TestCutlassBackend(TestCase):
|
|||
2,
|
||||
).run(codes[0])
|
||||
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_number_mm_precompiles(self):
    """
    Compile a single matmul with max-autotune restricted to the CUTLASS
    backend and check the ``select_algorithm_num_precompiles`` counter.

    The test clears the dynamo counters, compiles a module whose forward is
    ``a @ b``, compares the compiled result against eager, runs a FileCheck
    count check for ``cuda_fused_0.cuda_fused_0`` on the first generated
    code string, and finally asserts that the precompile counter equals 1.
    """
    from torch._inductor.utils import run_and_get_code

    # Start from a clean slate so the final counter value reflects only
    # this test's compilation.
    counters = torch._dynamo.utils.counters
    counters.clear()

    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, a, b, c):
            # `c` is accepted but does not take part in the computation.
            ab = a @ b
            return ab

    model = MyModel()
    lhs = torch.randn(128, 16).cuda().half()
    rhs = torch.randn(16, 128).cuda().half()
    unused = torch.randn(16, 512).cuda().half()

    inductor_overrides = {
        "max_autotune": True,
        "autotune_in_subproc": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cuda.cutlass_max_profiling_configs": 1,
        "autotune_fallback_to_aten": False,
        # guarantees > 1 choices
        "cuda.cutlass_max_profiling_swizzle_options": [1, 2, 4],
    }

    with config.patch(inductor_overrides):
        compiled = torch.compile(model, dynamic=True)
        expected = model(lhs, rhs, unused)
        actual, codes = run_and_get_code(compiled, lhs, rhs, unused)
        torch.testing.assert_close(actual, expected)

        kernel_check = FileCheck().check_count("cuda_fused_0.cuda_fused_0", 1)
        kernel_check.run(codes[0])

        # Verifies expected number of precompilations
        num_precompiles = counters["inductor"]["select_algorithm_num_precompiles"]
        self.assertEqual(num_precompiles, 1)
|
||||
|
||||
# NOTE: right now tuned_mm doesn't support cutlass 2x, which is used by A100
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@parametrize("dynamic", (False, True))
|
||||
|
|
|
|||
|
|
@ -1824,6 +1824,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
|||
"Exception %s for benchmark choice %s", e, futures[future]
|
||||
)
|
||||
else:
|
||||
counters["inductor"]["select_algorithm_num_precompiles"] += 1
|
||||
log.info(
|
||||
"Precompiling benchmark choice %s took %.02fs",
|
||||
futures[future],
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user