Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
1 Commit

| Author | SHA1 | Message | Date |
|---|---|---|---|
| | 83ae61fd8e | [Inductor] Add Subgraph as a Autotuning Choice (#150653) | |
Add the option to provide a Subgraph as an autotuning choice in Inductor. This is crucial for implementing the split-K optimization for GEMMs by decomposing an mm into a bmm. https://github.com/pytorch/pytorch/pull/150654 uses these changes to add decomposeK as a default autotuning choice for aten.mm in Inductor.

Using https://github.com/pytorch/pytorch/pull/150654 and a simple script:

```
import torch


def f(a, b):
    return torch.matmul(a, b)


def decompose_func(a_in, b_in):
    M, K = a_in.shape
    K, N = b_in.shape

    # TODO: Ideally we want to autotune over this parameter
    kPartitions = 256
    assert K % kPartitions == 0, "K must be divisible by kPartitions"

    B = K // kPartitions
    a_reshaped = a_in.reshape(M, B, kPartitions).transpose(0, 1)  # Shape: (B, M, kPartitions)
    b_reshaped = b_in.reshape(B, kPartitions, N)  # Shape: (B, kPartitions, N)
    result = torch.bmm(a_reshaped, b_reshaped)  # Shape: (B, M, N)
    return result.sum(dim=0).to(torch.float16)  # Sum over B dimension, Shape: (M, N)


for k in [4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768]:
    a = torch.randn(32, k, dtype=torch.float16, device="cuda", requires_grad=True)
    b = torch.randn(k, 32, dtype=torch.float16, device="cuda", requires_grad=True)

    compiled_res = torch.compile(f, dynamic=False)(a, b)
    decompose_res = decompose_func(a, b)

    print(f"Compiled mm result close to aten: {torch.allclose(f(a, b), compiled_res, atol=1e-5, rtol=0.5)}")
    print(f"Compiled mm result close to decompose: {torch.allclose(decompose_res, compiled_res, atol=1e-5, rtol=0.5)}")
```

we are able to autotune the decomposeK optimization against aten and the traditional Triton templates in Inductor. DecomposeK is faster than aten by about 10% on average and gives a >4x speedup over the best Triton templates on an H100 machine, e.g.:

```
AUTOTUNE mm(32x28672, 28672x32)
  decompose_k_mm 0.0126 ms 100.0%
  mm 0.0144 ms 87.5%
  triton_mm_69 0.0579 ms 21.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_75 0.0677 ms 18.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_76 0.0850 ms 14.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_68 0.1444 ms 8.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_72 0.1546 ms 8.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_74 0.1819 ms 6.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_67 0.1917 ms 6.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_73 0.2766 ms 4.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
```

https://pastebin.com/g3FMaauT is the generated code from Inductor containing the subgraph decomposition for aten.mm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150653
Approved by: https://github.com/eellison
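
To make the shape bookkeeping of the decomposition concrete, here is a minimal sketch (not part of the PR) that replays the split-K reshapes from the benchmark script for the mm(32x28672, 28672x32) case in the autotune log above; the CPU device and default float32 dtype are illustrative choices only.

```
import torch

# Minimal sketch, not from the PR: replay the decomposeK shape bookkeeping on CPU
# for the mm(32x28672, 28672x32) case shown in the autotune log above.
M, K, N = 32, 28672, 32
kPartitions = 256            # fixed split size used by the benchmark script
B = K // kPartitions         # 28672 // 256 = 112 sub-GEMMs along the K dimension

a = torch.randn(M, K)
b = torch.randn(K, N)

a_r = a.reshape(M, B, kPartitions).transpose(0, 1)  # (112, 32, 256): one K-slice of a per batch
b_r = b.reshape(B, kPartitions, N)                  # (112, 256, 32): matching K-slice of b
partial = torch.bmm(a_r, b_r)                       # (112, 32, 32): one partial product per slice
out = partial.sum(dim=0)                            # (32, 32): reduce over the split-K batch

print(a_r.shape, b_r.shape, partial.shape, out.shape)
print(torch.allclose(out, a @ b, rtol=1e-3, atol=1e-3))  # matches a plain mm up to float32 rounding
```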
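
The commit message does not show how autotuning was enabled for the benchmark. As an assumption (not taken from the PR), the standard public knob for exercising Inductor's GEMM autotuning, which is what produces AUTOTUNE tables like the one above, is torch.compile's max-autotune mode; whether decompose_k_mm appears as a candidate depends on https://github.com/pytorch/pytorch/pull/150654 being in the build.

```
import torch

def f(a, b):
    return torch.matmul(a, b)

# Assumption, not shown in the PR: mode="max-autotune" turns on Inductor's GEMM autotuning,
# which benchmarks the candidate choices (aten, Triton templates, and, with #150654, the
# decomposeK subgraph) and typically logs an AUTOTUNE summary like the one above.
compiled_f = torch.compile(f, mode="max-autotune", dynamic=False)

a = torch.randn(32, 28672, dtype=torch.float16, device="cuda")
b = torch.randn(28672, 32, dtype=torch.float16, device="cuda")
out = compiled_f(a, b)  # first call compiles and runs the autotuning
```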