[Inductor] Add Subgraph as an Autotuning Choice (#150653)

Add the option to provide a Subgraph as an autotuning choice in Inductor. This is crucial for implementing the split-K optimization for GEMMs, which decomposes an mm into a bmm over partitions of the K dimension followed by a reduction. https://github.com/pytorch/pytorch/pull/150654 uses these changes to add decomposeK as a default autotuning choice for aten.mm in Inductor.

Using https://github.com/pytorch/pytorch/pull/150654 and a simple script:

```
import torch

def f(a, b):
    return torch.matmul(a, b)

def decompose_func(a_in, b_in):
    M, K = a_in.shape
    K, N = b_in.shape

    # TODO: Ideally we want to autotune over this parameter
    kPartitions = 256
    assert K % kPartitions == 0, "K must be divisible by kPartitions"
    B = K // kPartitions

    a_reshaped = a_in.reshape(M, B, kPartitions).transpose(
        0, 1
    )  # Shape: (B, M, kPartitions)
    b_reshaped = b_in.reshape(B, kPartitions, N)  # Shape: (B, kPartitions, N)
    result = torch.bmm(a_reshaped, b_reshaped)  # Shape: (B, M, N)
    return result.sum(dim=0).to(torch.float16)  # Sum over B dimension, Shape: (M, N)

for k in [4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768]:
    a = torch.randn(32, k, dtype=torch.float16, device="cuda", requires_grad=True)
    b = torch.randn(k, 32, dtype=torch.float16, device="cuda", requires_grad=True)

    compiled_res = torch.compile(f, dynamic=False)(a, b)
    decompose_res = decompose_func(a, b)

    print(f"Compiled mm result close to aten: {torch.allclose(f(a, b), compiled_res, atol=1e-5, rtol=0.5)}")
    print(f"Compiled mm result close to decompose: {torch.allclose(decompose_res, compiled_res, atol=1e-5, rtol=0.5)}")
```
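Why this decomposition is exact: batch `i` of the bmm computes `a[:, i*kPartitions:(i+1)*kPartitions] @ b[i*kPartitions:(i+1)*kPartitions, :]`, so summing the `B` partial products over the batch dimension recovers the full `a @ b`. A minimal numeric check of that shape algebra (small illustrative sizes, not part of this PR):

```
import torch

# Small illustrative sizes: M=4, K=8, N=4, kPartitions=2 -> B=4
M, K, N, kPartitions = 4, 8, 4, 2
B = K // kPartitions
a = torch.randn(M, K, dtype=torch.float64)
b = torch.randn(K, N, dtype=torch.float64)

# (M, K) -> (B, M, kPartitions): batch i holds columns [i*kP, (i+1)*kP) of a
a_split = a.reshape(M, B, kPartitions).transpose(0, 1)
# (K, N) -> (B, kPartitions, N): batch i holds the matching rows of b
b_split = b.reshape(B, kPartitions, N)

# Summing the B partial GEMMs reproduces the full matmul exactly (fp64)
assert torch.allclose(torch.bmm(a_split, b_split).sum(dim=0), a @ b)
```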

we can autotune the decomposeK optimization against aten and the traditional Triton templates in Inductor. On an H100 machine, decomposeK is about 10% faster than aten on average and more than 4x faster than the best Triton template; e.g., for the K=28672 case below, decompose_k_mm takes 0.0126 ms vs. 0.0144 ms for aten mm and 0.0579 ms for the best Triton config:

```
AUTOTUNE mm(32x28672, 28672x32)
  decompose_k_mm 0.0126 ms 100.0%
  mm 0.0144 ms 87.5%
  triton_mm_69 0.0579 ms 21.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_75 0.0677 ms 18.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_76 0.0850 ms 14.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_68 0.1444 ms 8.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_72 0.1546 ms 8.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_74 0.1819 ms 6.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_67 0.1917 ms 6.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_73 0.2766 ms 4.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
```

https://pastebin.com/g3FMaauT is the generated code from Inductor containing the subgraph decomposition for aten.mm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150653
Approved by: https://github.com/eellison
Commit 83ae61fd8e (parent ad5e9065ac), authored by PaulZhang12 on 2025-04-11 07:43:35 -07:00 and committed by PyTorch MergeBot. 3 changed files with 318 additions and 0 deletions.

New test file:
@@ -0,0 +1,118 @@
# Owner(s): ["module: inductor"]
import functools

import torch
from torch._dispatch.python import enable_python_dispatcher
from torch._inductor.codegen.subgraph import SubgraphTemplate
from torch._inductor.decomposition import select_decomp_table
from torch._inductor.ir import Buffer, FixedLayout
from torch._inductor.lowering import register_lowering
from torch._inductor.select_algorithm import (
    AlgorithmSelectorCache,
    autotune_select_algorithm,
)
from torch._inductor.test_case import run_tests, TestCase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU


class TestSubgraphChoice(TestCase):
    def setUp(self):
        super().setUp()

    def _create_buffer(self, name, shape, dtype):
        return Buffer(
            name=name,
            layout=FixedLayout(torch.device(f"{GPU_TYPE}:0"), dtype=dtype, size=shape),
        )

    def test_subgraph_decompose_k(self):
        from torch._inductor.kernel.mm import aten_mm
        from torch._inductor.kernel.mm_common import mm_args

        @torch.library.custom_op("mylib::matmul_decompose", mutates_args={})
        def matmul_decompose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return a @ b

        @matmul_decompose.register_fake
        def _(a, b):
            return a @ b

        def decomposeK(a, b, kPartitions):
            m = a.shape[0]
            n = b.shape[1]
            k = a.shape[1]

            B = k // kPartitions
            a_reshaped = torch.permute(a.reshape(m, B, kPartitions), (1, 0, 2))
            b_reshaped = b.reshape(B, kPartitions, n)
            result = torch.bmm(a_reshaped, b_reshaped)
            result_fp32 = result.to(torch.float32)
            reduced_buf = torch.sum(result_fp32, 0)
            return reduced_buf.to(a.dtype)

        mat1_shape, mat2_shape = (32, 4096), (4096, 32)

        @register_lowering(torch.ops.mylib.matmul_decompose)
        def _(a, b):
            _, _, _, layout, mat1, mat2 = mm_args(a, b)

            choices = [aten_mm.bind((mat1, mat2), layout)]

            # TODO (PaulZhang12): Once decomposeK lands in Inductor, move this
            kPartitions = 256
            with enable_python_dispatcher():
                decompositions = select_decomp_table()

                decompose_k_subgraph_template = SubgraphTemplate(
                    name="decompose_k_mm",
                    make_fx_graph=make_fx(
                        functools.partial(decomposeK, kPartitions=kPartitions),
                        decompositions,
                        tracing_mode="real",
                    ),
                )

            mat1_tensor, mat2_tensor = (
                AlgorithmSelectorCache.benchmark_example_value(mat1),
                AlgorithmSelectorCache.benchmark_example_value(mat2),
            )
            decompose_k_subgraph_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                example_inputs=[mat1_tensor, mat2_tensor],
            )

            # Test benchmarking against aten
            autotune_select_algorithm("test_subgraph_choice", choices, [a, b], layout)

            # Only return decomposeK case for codegen
            choices = [choices[1]]
            return autotune_select_algorithm(
                "test_subgraph_choice", choices, [a, b], layout
            )

        a_in = torch.randn(
            mat1_shape, dtype=torch.float16, device=torch.device(f"{GPU_TYPE}:0")
        )
        b_in = torch.randn(
            mat2_shape, dtype=torch.float16, device=torch.device(f"{GPU_TYPE}:0")
        )

        def func(mat1, mat2):
            return torch.ops.mylib.matmul_decompose(mat1, mat2)

        compiled_func = torch.compile(func, mode="max-autotune", dynamic=False)
        res = compiled_func(a_in, b_in)

        # Check same results of compiled result and regular torch.mm
        # Relax precision as decomposeK does first accumulation in fp16
        torch.testing.assert_close(res, a_in @ b_in, atol=1e-1, rtol=1e-1)


if __name__ == "__main__":
    # Set env to make it work in CI.
    if HAS_GPU and HAS_CPU:
        run_tests()

torch/_inductor/codegen/subgraph.py (new file):
@@ -0,0 +1,157 @@
import logging
from typing import Any, Callable

import torch
from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import Buffer, Layout
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.virtualized import V

log = logging.getLogger(__name__)


class SubgraphChoiceCaller(ir.ChoiceCaller):
    """
    Represents a subgraph autotuning choice; the subgraph can be any arbitrary
    GraphModule. Compiles the subgraph down to a module for benchmarking.
    """

    def __init__(
        self,
        name: str,
        input_nodes: list[Buffer],
        layout: Layout,
        description: str,
        gm: torch.fx.GraphModule,
        example_inputs: list[Any],
    ) -> None:
        super().__init__(name, input_nodes, layout, description)

        self.gm = gm
        self.example_inputs = example_inputs

    def __str__(self) -> str:
        return f"SubgraphCaller({self.name})"

    def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
        # Codegen the subgraph for benchmarking. We need GraphLowering
        # instead of SubgraphLowering to generate a fully callable module.
        import torch._inductor.config as inductor_config
        from torch._inductor.graph import GraphLowering

        bm_graph_lowering = GraphLowering(
            gm=self.gm,
            example_inputs=self.example_inputs,
            shape_env=V.graph._shape_env,
            cpp_wrapper=V.graph.cpp_wrapper,
            aot_mode=V.graph.aot_mode,
            extern_node_serializer=V.graph.extern_node_serializer,
            is_inference=V.graph.is_inference,
            is_backward=V.graph.is_backward,
            name=f"benchmark_{self.name}",
        )

        with V.set_graph_handler(bm_graph_lowering):
            # Don't bother autotuning on Triton here
            with inductor_config.patch(
                max_autotune=False,
                max_autotune_gemm=False,
                max_autotune_gemm_backends="ATEN",
            ):
                bm_graph_lowering.run(*self.example_inputs)
                mod = bm_graph_lowering.compile_to_module()
                bm_func = mod.call

                bm_func([*args])
                return benchmarker.benchmark_gpu(lambda: bm_func([*args]))

    def hash_key(self) -> str:
        return "-".join(
            [
                self.name,
                *[
                    str(arg.shape)
                    for arg in self.example_inputs
                    if isinstance(arg, torch.Tensor)
                ],
                str(self.gm.graph),
            ]
        )

    def output_node(self) -> ir.TensorBox:
        return ir.TensorBox.create(
            ir.SubgraphBuffer(
                layout=self.layout,
                input_nodes=self.input_nodes,
                gm=self.gm,
                example_inputs=self.example_inputs,
                subgraph_name=self.name,
            )
        )

    def info_dict(self) -> dict[str, Any]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return {
            "backend": "subgraph",
            "kernel_name": self.name,
        }

    def autoheuristic_id(self) -> str:
        return f"subgraph_{self.name}"


class SubgraphTemplate(KernelTemplate):
    """
    A template for subgraph evaluation to be used in autotuning.

    This class allows creating customized subgraphs that can be appended
    as choices during the autotuning process, enabling the selection of
    optimal implementations for complex operations.
    """

    def __init__(
        self,
        name: str,
        make_fx_graph: Callable[..., Any],
    ):
        """
        Initialize a subgraph template.

        Args:
            name: The name of this template
            make_fx_graph: A callable that traces the subgraph into an FX graph
        """
        self.name = name
        self.make_fx_graph = make_fx_graph

    def generate(  # type: ignore[override]
        self,
        input_nodes: list[Buffer],
        layout: Layout,
        example_inputs: list[Any],
        **kwargs: Any,
    ) -> SubgraphChoiceCaller:
        """
        Generate a SubgraphChoiceCaller instance for autotuning.

        Args:
            input_nodes: List of input nodes to the subgraph
            layout: Memory layout information for the output
            example_inputs: Example tensor inputs used to trace and benchmark the subgraph
            **kwargs: Additional keyword arguments

        Returns:
            SubgraphChoiceCaller: A callable object that can be used for autotuning
        """
        gm = self.make_fx_graph(*example_inputs)

        return SubgraphChoiceCaller(
            name=self.name,
            input_nodes=input_nodes,
            layout=layout,
            description="",
            gm=gm,
            example_inputs=example_inputs,
        )
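
For context, a condensed sketch of how the new template is meant to be used during a lowering. The helper name `add_subgraph_choice` is illustrative (not part of this commit); `choices`, `mat1`, `mat2`, and `layout` are assumed to come from a lowering such as the `mm_args`-based test above:

```
import functools

from torch._inductor.codegen.subgraph import SubgraphTemplate
from torch._inductor.select_algorithm import (
    AlgorithmSelectorCache,
    autotune_select_algorithm,
)
from torch.fx.experimental.proxy_tensor import make_fx


# Hypothetical helper, condensed from the test in this commit
def add_subgraph_choice(choices, mat1, mat2, layout, decompose_fn, kPartitions):
    # Trace the decomposition into an FX graph and wrap it as a template
    template = SubgraphTemplate(
        name="decompose_k_mm",
        make_fx_graph=make_fx(
            functools.partial(decompose_fn, kPartitions=kPartitions)
        ),
    )
    # Example tensors are needed both to trace and to benchmark the subgraph
    example_inputs = [
        AlgorithmSelectorCache.benchmark_example_value(mat1),
        AlgorithmSelectorCache.benchmark_example_value(mat2),
    ]
    template.maybe_append_choice(
        choices,
        input_nodes=(mat1, mat2),
        layout=layout,
        example_inputs=example_inputs,
    )
    # Benchmark all choices (aten, Triton, subgraph) and pick the fastest
    return autotune_select_algorithm("mm", choices, [mat1, mat2], layout)
```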

torch/_inductor/ir.py:
@@ -5999,6 +5999,49 @@ class TMADescriptor(ExternKernel)
        wrapper.generate_tma_descriptor(self)
class SubgraphBuffer(ExternKernel):
    def __init__(
        self,
        layout: Layout,
        input_nodes: list[Buffer],
        gm: torch.fx.GraphModule,
        example_inputs: list[Any],
        subgraph_name: str,
    ):
        super().__init__(None, layout, input_nodes)
        self.gm = gm
        self.example_inputs = example_inputs
        self.name = V.graph.register_buffer(self)
        V.graph.register_operation(self)

        self.subgraph = V.graph.make_subgraph(
            self.gm, self.example_inputs, subgraph_name
        )

        import torch._inductor.config as inductor_config

        with V.set_graph_handler(self.subgraph):
            # Don't bother autotuning on Triton here
            with inductor_config.patch(  # type: ignore[no-untyped-def]
                max_autotune=False,
                max_autotune_gemm=False,
                max_autotune_gemm_backends="ATEN",
            ):
                self.subgraph.run(*self.example_inputs)

    def codegen(self, wrapper) -> None:  # type: ignore[no-untyped-def]
        class CodegenGraph:
            def __init__(self, graph: GraphLowering):
                self.graph = graph
                self.name = graph.name

        wrapper.codegen_subgraph(
            CodegenGraph(self.subgraph),
            [*[buffer.get_name() for buffer in self.inputs]],
            [self.name],
        )
class UserDefinedTritonKernel(ExternKernel):
    def get_kernel_and_metadata(self):  # type: ignore[no-untyped-def]
        from triton.runtime.autotuner import Autotuner