pytorch/torch/_inductor/codegen/subgraph.py
PaulZhang12 83ae61fd8e [Inductor] Add Subgraph as an Autotuning Choice (#150653)
Add the option of providing a Subgraph as an autotuning choice in Inductor. This is crucial for implementing the split-K optimization for GEMMs, which decomposes an mm into a bmm. https://github.com/pytorch/pytorch/pull/150654 uses these changes to add decomposeK as a default autotuning choice for aten.mm in Inductor.

Using https://github.com/pytorch/pytorch/pull/150654 and a simple script:

```
import torch

def f(a, b):
    return torch.matmul(a, b)

def decompose_func(a_in, b_in):
    M, K = a_in.shape
    K, N = b_in.shape

    # TODO: Ideally we want to autotune over this parameter
    kPartitions = 256
    assert K % kPartitions == 0, "K must be divisible by kPartitions"
    B = K // kPartitions

    a_reshaped = a_in.reshape(M, B, kPartitions).transpose(
        0, 1
    )  # Shape: (B, M, kPartitions)
    b_reshaped = b_in.reshape(B, kPartitions, N)  # Shape: (B, kPartitions, N)
    result = torch.bmm(a_reshaped, b_reshaped)  # Shape: (B, M, N)
    return result.sum(dim=0).to(torch.float16)  # Sum over B dimension, Shape: (M, N)

for k in [4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768]:
    a = torch.randn(32, k, dtype=torch.float16, device="cuda", requires_grad=True)
    b = torch.randn(k, 32, dtype=torch.float16, device="cuda", requires_grad=True)

    compiled_res = torch.compile(f, dynamic=False)(a, b)
    decompose_res = decompose_func(a, b)

    print(f"Compiled mm result close to aten: {torch.allclose(f(a, b), compiled_res, atol=1e-5, rtol=0.5)}")
    print(f"Compiled mm result close to decompose: {torch.allclose(decompose_res, compiled_res, atol=1e-5, rtol=0.5)}")
```

we are able to autotune the decomposeK optimization against aten and the traditional Triton templates in Inductor. On an H100, decomposeK is ~10% faster than aten on average and more than 4x faster than the best Triton template, e.g.:

```
AUTOTUNE mm(32x28672, 28672x32)
  decompose_k_mm 0.0126 ms 100.0%
  mm 0.0144 ms 87.5%
  triton_mm_69 0.0579 ms 21.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_75 0.0677 ms 18.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_76 0.0850 ms 14.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_68 0.1444 ms 8.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_72 0.1546 ms 8.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_74 0.1819 ms 6.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_67 0.1917 ms 6.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_73 0.2766 ms 4.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
```

https://pastebin.com/g3FMaauT is the generated code from Inductor containing the subgraph decomposition for aten.mm.
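
As a hedged sketch of how such a decomposition might be wired in as a choice (the names `decompose_k_template`, `mat1`, `mat2`, and the `choices` list are schematic; the real aten.mm integration lives in https://github.com/pytorch/pytorch/pull/150654), the template is constructed from a tracing callable and its `generate` output is appended alongside the other autotuning choices:

```
from torch.fx.experimental.proxy_tensor import make_fx

from torch._inductor.codegen.subgraph import SubgraphTemplate

# Trace the mm -> bmm decomposition (decompose_func from the script above)
# into an FX GraphModule lazily, once example inputs are available.
decompose_k_template = SubgraphTemplate(
    name="decompose_k_mm",
    make_fx_graph=lambda *inputs: make_fx(decompose_func)(*inputs),
)

# Schematic: inside the mm lowering, the subgraph choice is appended
# next to the ATen and Triton template choices before autotuning runs.
# choices.append(
#     decompose_k_template.generate(
#         input_nodes=[mat1, mat2],
#         layout=layout,
#         example_inputs=[example_a, example_b],
#     )
# )
```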

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150653
Approved by: https://github.com/eellison
2025-04-11 19:08:43 +00:00


import logging
from typing import Any, Callable

import torch
from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import Buffer, Layout
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.virtualized import V

log = logging.getLogger(__name__)
class SubgraphChoiceCaller(ir.ChoiceCaller):
    """
    Represents a Subgraph autotuning choice; the subgraph can be an arbitrary
    GraphModule. Compiles the subgraph down to a module for benchmarking.
    """

    def __init__(
        self,
        name: str,
        input_nodes: list[Buffer],
        layout: Layout,
        description: str,
        gm: torch.fx.GraphModule,
        example_inputs: list[Any],
    ) -> None:
        super().__init__(name, input_nodes, layout, description)
        self.gm = gm
        self.example_inputs = example_inputs

    def __str__(self) -> str:
        return f"SubgraphCaller({self.name})"

    def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
        # Codegen the subgraph for benchmarking. We need a GraphLowering
        # instead of a SubgraphLowering to generate a fully callable module.
        import torch._inductor.config as inductor_config
        from torch._inductor.graph import GraphLowering

        bm_graph_lowering = GraphLowering(
            gm=self.gm,
            example_inputs=self.example_inputs,
            shape_env=V.graph._shape_env,
            cpp_wrapper=V.graph.cpp_wrapper,
            aot_mode=V.graph.aot_mode,
            extern_node_serializer=V.graph.extern_node_serializer,
            is_inference=V.graph.is_inference,
            is_backward=V.graph.is_backward,
            name=f"benchmark_{self.name}",
        )

        with V.set_graph_handler(bm_graph_lowering):
            # Don't bother autotuning on Triton here
            with inductor_config.patch(
                max_autotune=False,
                max_autotune_gemm=False,
                max_autotune_gemm_backends="ATEN",
            ):
                bm_graph_lowering.run(*self.example_inputs)
                mod = bm_graph_lowering.compile_to_module()
                bm_func = mod.call

                # Warm up once before timing
                bm_func([*args])

        return benchmarker.benchmark_gpu(lambda: bm_func([*args]))

    def hash_key(self) -> str:
        return "-".join(
            [
                self.name,
                *[
                    str(arg.shape)
                    for arg in self.example_inputs
                    if isinstance(arg, torch.Tensor)
                ],
                str(self.gm.graph),
            ]
        )

    def output_node(self) -> ir.TensorBox:
        return ir.TensorBox.create(
            ir.SubgraphBuffer(
                layout=self.layout,
                input_nodes=self.input_nodes,
                gm=self.gm,
                example_inputs=self.example_inputs,
                subgraph_name=self.name,
            )
        )

    def info_dict(self) -> dict[str, Any]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return {
            "backend": "subgraph",
            "kernel_name": self.name,
        }

    def autoheuristic_id(self) -> str:
        return f"subgraph_{self.name}"
class SubgraphTemplate(KernelTemplate):
    """
    A template for subgraph evaluation to be used in autotuning.

    This class allows creating customized subgraphs that can be appended
    as choices during the autotuning process, enabling the selection of
    optimal implementations for complex operations.
    """

    def __init__(
        self,
        name: str,
        make_fx_graph: Callable[..., Any],
    ):
        """
        Initialize a subgraph template.

        Args:
            name: The name of this template
            make_fx_graph: Callable that traces the subgraph into an FX
                GraphModule when invoked with example inputs
        """
        self.name = name
        self.make_fx_graph = make_fx_graph

    def generate(  # type: ignore[override]
        self,
        input_nodes: list[Buffer],
        layout: Layout,
        example_inputs: list[Any],
        **kwargs: Any,
    ) -> SubgraphChoiceCaller:
        """
        Generate a SubgraphChoiceCaller instance for autotuning.

        Args:
            input_nodes: List of input nodes to the subgraph
            layout: Memory layout information for the output
            example_inputs: Example tensor inputs used to trace and benchmark the subgraph
            **kwargs: Additional keyword arguments

        Returns:
            SubgraphChoiceCaller: A callable object that can be used for autotuning
        """
        gm = self.make_fx_graph(*example_inputs)

        return SubgraphChoiceCaller(
            name=self.name,
            input_nodes=input_nodes,
            layout=layout,
            description="",
            gm=gm,
            example_inputs=example_inputs,
        )
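
To make the `make_fx_graph` contract concrete, here is a minimal, self-contained sketch (toy decomposition, hypothetical names) of producing the `torch.fx.GraphModule` that `SubgraphChoiceCaller` later compiles and benchmarks:

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def toy_decomposition(a, b):
    # Stand-in for a real decomposition such as mm -> bmm + sum
    return torch.mm(a, b)

def make_fx_graph(*example_inputs):
    # SubgraphTemplate expects a callable that, given example inputs,
    # returns a torch.fx.GraphModule for the decomposition.
    return make_fx(toy_decomposition)(*example_inputs)

a = torch.randn(32, 4096)
b = torch.randn(4096, 32)
gm = make_fx_graph(a, b)
print(gm.graph)  # the FX graph that the choice will compile and benchmark
```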