mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add the option of providing a Subgraph as an autotuning choice in Inductor. This is crucial for implementing the split-K optimization for GEMMs by decomposing an mm into a bmm. https://github.com/pytorch/pytorch/pull/150654 uses these changes to add decomposeK as a default autotuning choice for aten.mm in Inductor.

Using https://github.com/pytorch/pytorch/pull/150654 and a simple script:

```
import torch

def f(a, b):
    return torch.matmul(a, b)

def decompose_func(a_in, b_in):
    M, K = a_in.shape
    K, N = b_in.shape

    # TODO: Ideally we want to autotune over this parameter
    kPartitions = 256
    assert K % kPartitions == 0, "K must be divisible by kPartitions"
    B = K // kPartitions

    a_reshaped = a_in.reshape(M, B, kPartitions).transpose(0, 1)  # Shape: (B, M, kPartitions)
    b_reshaped = b_in.reshape(B, kPartitions, N)  # Shape: (B, kPartitions, N)
    result = torch.bmm(a_reshaped, b_reshaped)  # Shape: (B, M, N)
    return result.sum(dim=0).to(torch.float16)  # Sum over B dimension, Shape: (M, N)

for k in [4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768]:
    a = torch.randn(32, k, dtype=torch.float16, device="cuda", requires_grad=True)
    b = torch.randn(k, 32, dtype=torch.float16, device="cuda", requires_grad=True)

    compiled_res = torch.compile(f, dynamic=False)(a, b)
    decompose_res = decompose_func(a, b)

    print(f"Compiled mm result close to aten: {torch.allclose(f(a, b), compiled_res, atol=1e-5, rtol=0.5)}")
    print(f"Compiled mm result close to decompose: {torch.allclose(decompose_res, compiled_res, atol=1e-5, rtol=0.5)}")
```

we are able to autotune the decomposeK optimization against aten and the traditional Triton templates in Inductor. DecomposeK is roughly 10% faster than aten on average and more than 4x faster than the best Triton templates on an H100 machine, e.g.:

```
AUTOTUNE mm(32x28672, 28672x32)
  decompose_k_mm 0.0126 ms 100.0%
  mm 0.0144 ms 87.5%
  triton_mm_69 0.0579 ms 21.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_75 0.0677 ms 18.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_76 0.0850 ms 14.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_68 0.1444 ms 8.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_72 0.1546 ms 8.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_74 0.1819 ms 6.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_67 0.1917 ms 6.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_73 0.2766 ms 4.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
```

https://pastebin.com/g3FMaauT is the generated code from Inductor containing the subgraph decomposition for aten.mm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150653
Approved by: https://github.com/eellison
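For context, here is a rough sketch of how the pieces fit together (illustrative only, not the exact integration from https://github.com/pytorch/pytorch/pull/150654): a lowering traces the decomposition into an FX GraphModule, wraps it in a `SubgraphTemplate`, and appends the generated `SubgraphChoiceCaller` to the list of choices handed to Inductor's algorithm selector. The import path for `SubgraphTemplate`, the fixed `kPartitions` value, and the surrounding lowering variables (`choices`, `input_nodes`, `layout`, `example_inputs`) are assumptions here.

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

from torch._inductor.codegen.subgraph import SubgraphTemplate  # assumed import path
from torch._inductor.select_algorithm import autotune_select_algorithm


def make_decompose_k_graph(a, b):
    # Trace the mm -> bmm split-K decomposition into a GraphModule.
    def decompose_k(a_in, b_in):
        M, K = a_in.shape
        _, N = b_in.shape
        kPartitions = 256  # assumed fixed split size, as in the script above
        B = K // kPartitions
        a_r = a_in.reshape(M, B, kPartitions).transpose(0, 1)  # (B, M, kPartitions)
        b_r = b_in.reshape(B, kPartitions, N)  # (B, kPartitions, N)
        return torch.bmm(a_r, b_r).sum(dim=0).to(a_in.dtype)  # (M, N)

    return make_fx(decompose_k)(a, b)


def add_decompose_k_choice(choices, input_nodes, layout, example_inputs):
    # Append the subgraph as one more autotuning candidate alongside the
    # existing aten / Triton mm choices, then let the selector benchmark them.
    template = SubgraphTemplate(
        name="decompose_k_mm",
        make_fx_graph=make_decompose_k_graph,
    )
    choices.append(
        template.generate(
            input_nodes=input_nodes,
            layout=layout,
            example_inputs=example_inputs,
        )
    )
    return autotune_select_algorithm("mm", choices, input_nodes, layout)
```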
158 lines
4.8 KiB
Python
import logging
from typing import Any, Callable

import torch
from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import Buffer, Layout
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.virtualized import V


log = logging.getLogger(__name__)


class SubgraphChoiceCaller(ir.ChoiceCaller):
    """
    Represents a subgraph autotuning choice; the subgraph can be an arbitrary
    GraphModule. Compiles the subgraph down to a module for benchmarking.
    """

    def __init__(
        self,
        name: str,
        input_nodes: list[Buffer],
        layout: Layout,
        description: str,
        gm: torch.fx.GraphModule,
        example_inputs: list[Any],
    ) -> None:
        super().__init__(name, input_nodes, layout, description)
        self.gm = gm
        self.example_inputs = example_inputs

    def __str__(self) -> str:
        return f"SubgraphCaller({self.name})"

    def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
        # Codegen Subgraph for benchmarking
        # Need GraphLowering instead of SubgraphLowering to generate
        # fully callable module
        import torch._inductor.config as inductor_config
        from torch._inductor.graph import GraphLowering

        bm_graph_lowering = GraphLowering(
            gm=self.gm,
            example_inputs=self.example_inputs,
            shape_env=V.graph._shape_env,
            cpp_wrapper=V.graph.cpp_wrapper,
            aot_mode=V.graph.aot_mode,
            extern_node_serializer=V.graph.extern_node_serializer,
            is_inference=V.graph.is_inference,
            is_backward=V.graph.is_backward,
            name=f"benchmark_{self.name}",
        )

        with V.set_graph_handler(bm_graph_lowering):
            # Don't bother autotuning on Triton here
            with inductor_config.patch(
                max_autotune=False,
                max_autotune_gemm=False,
                max_autotune_gemm_backends="ATEN",
            ):
                bm_graph_lowering.run(*self.example_inputs)
                mod = bm_graph_lowering.compile_to_module()
                bm_func = mod.call

                bm_func([*args])

        return benchmarker.benchmark_gpu(lambda: bm_func([*args]))

    def hash_key(self) -> str:
        return "-".join(
            [
                self.name,
                *[
                    str(arg.shape)
                    for arg in self.example_inputs
                    if isinstance(arg, torch.Tensor)
                ],
                str(self.gm.graph),
            ]
        )

    def output_node(self) -> ir.TensorBox:
        return ir.TensorBox.create(
            ir.SubgraphBuffer(
                layout=self.layout,
                input_nodes=self.input_nodes,
                gm=self.gm,
                example_inputs=self.example_inputs,
                subgraph_name=self.name,
            )
        )

    def info_dict(self) -> dict[str, Any]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return {
            "backend": "subgraph",
            "kernel_name": self.name,
        }

    def autoheuristic_id(self) -> str:
        return f"subgraph_{self.name}"


class SubgraphTemplate(KernelTemplate):
    """
    A template for subgraph evaluation to be used in autotuning.

    This class allows creating customized subgraphs that can be appended
    as choices during the autotuning process, enabling the selection of
    optimal implementations for complex operations.
    """

    def __init__(
        self,
        name: str,
        make_fx_graph: Callable[..., Any],
    ):
        """
        Initialize a subgraph template.

        Args:
            name: The name of this template
            make_fx_graph: Callable that traces the example inputs into the
                FX GraphModule for this subgraph
        """
        self.name = name
        self.make_fx_graph = make_fx_graph

    def generate(  # type: ignore[override]
        self,
        input_nodes: list[Buffer],
        layout: Layout,
        example_inputs: list[Any],
        **kwargs: Any,
    ) -> SubgraphChoiceCaller:
        """
        Generate a SubgraphChoiceCaller instance for autotuning.

        Args:
            input_nodes: List of input nodes to the subgraph
            layout: Memory layout information for the output
            example_inputs: Example tensor inputs used to trace and benchmark the subgraph
            **kwargs: Additional keyword arguments

        Returns:
            SubgraphChoiceCaller: A callable object that can be used for autotuning
        """
        gm = self.make_fx_graph(*example_inputs)

        return SubgraphChoiceCaller(
            name=self.name,
            input_nodes=input_nodes,
            layout=layout,
            description="",
            gm=gm,
            example_inputs=example_inputs,
        )
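

# Illustrative usage sketch (comments only, not part of this module's API): the
# make_fx_graph callable handed to SubgraphTemplate is expected to turn the example
# inputs into a torch.fx.GraphModule, e.g. via make_fx. The names input_nodes,
# layout, and example_inputs below are assumed to come from the calling lowering.
#
#     from torch.fx.experimental.proxy_tensor import make_fx
#
#     def make_relu_mm_graph(a, b):
#         return make_fx(lambda x, y: torch.relu(x @ y))(a, b)
#
#     template = SubgraphTemplate(name="relu_mm", make_fx_graph=make_relu_mm_graph)
#     choice = template.generate(
#         input_nodes=input_nodes,
#         layout=layout,
#         example_inputs=[a, b],
#     )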