mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] Add Subgraph as a Autotuning Choice (#150653)
Add the option for providing a Subgraph as an autotuning choice in Inductor. This is crucial for implementing the split-k optimization for GEMMs by decomposing a mm -> bmm. https://github.com/pytorch/pytorch/pull/150654 uses these changes to add decomposeK as a default autotuning choice for aten.mm in Inductor. Using https://github.com/pytorch/pytorch/pull/150654 and a simple script: ``` import torch def f(a, b): return torch.matmul(a, b) def decompose_func(a_in, b_in): M, K = a_in.shape K, N = b_in.shape # TODO: Ideally we want to autotune over this parameter kPartitions = 256 assert K % kPartitions == 0, "K must be divisible by Kmini" B = K // kPartitions a_reshaped = a_in.reshape(M, B, kPartitions).transpose( 0, 1 ) # Shape: (B, M, kPartitions) b_reshaped = b_in.reshape(B, kPartitions, N) # Shape: (B, kPartitions, N) result = torch.bmm(a_reshaped, b_reshaped) # Shape: (B, M, N) return result.sum(dim=0).to(torch.float16) # Sum over B dimension, Shape: (M, N) for k in [4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768]: a = torch.randn(32, k, dtype=torch.float16, device="cuda", requires_grad=True) b = torch.randn(k, 32, dtype=torch.float16, device="cuda", requires_grad=True) compiled_res = torch.compile(f, dynamic=False)(a, b) decompose_res = decompose_func(a, b) print(f"Compiled mm result close to aten: {torch.allclose(f(a, b), compiled_res, atol=1e-5, rtol=0.5)}") print(f"Compiled mm result close to decompose: {torch.allclose(decompose_res, compiled_res, atol=1e-5, rtol=0.5)}") ``` we are able to autotune the decomposeK optimization to aten and the traditional Triton templates in Inductor. DecomposeK is faster than aten by about ~10% on average and > 4x speedup over the best Triton templates on an H100 machine, e.g.: ``` AUTOTUNE mm(32x28672, 28672x32) decompose_k_mm 0.0126 ms 100.0% mm 0.0144 ms 87.5% triton_mm_69 0.0579 ms 21.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4 triton_mm_75 0.0677 ms 18.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4 triton_mm_76 0.0850 ms 14.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4 triton_mm_68 0.1444 ms 8.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4 triton_mm_72 0.1546 ms 8.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_74 0.1819 ms 6.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4 triton_mm_67 0.1917 ms 6.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4 triton_mm_73 0.2766 ms 4.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 ``` https://pastebin.com/g3FMaauT is the generated code from Inductor containing the subgraph decomposition for aten.mm. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150653 Approved by: https://github.com/eellison
This commit is contained in:
parent
ad5e9065ac
commit
83ae61fd8e
118
test/inductor/test_subgraph_choice.py
Normal file
118
test/inductor/test_subgraph_choice.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
# Owner(s): ["module: inductor"]
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._inductor.codegen.subgraph import SubgraphTemplate
|
||||
from torch._inductor.decomposition import select_decomp_table
|
||||
from torch._inductor.ir import Buffer, FixedLayout
|
||||
from torch._inductor.lowering import register_lowering
|
||||
from torch._inductor.select_algorithm import (
|
||||
AlgorithmSelectorCache,
|
||||
autotune_select_algorithm,
|
||||
)
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
|
||||
|
||||
|
||||
class TestSubgraphChoice(TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
def _create_buffer(self, name, shape, dtype):
|
||||
return Buffer(
|
||||
name=name,
|
||||
layout=FixedLayout(torch.device(f"{GPU_TYPE}:0"), dtype=dtype, size=shape),
|
||||
)
|
||||
|
||||
def test_subgraph_decompose_k(self):
|
||||
from torch._inductor.kernel.mm import aten_mm
|
||||
from torch._inductor.kernel.mm_common import mm_args
|
||||
|
||||
@torch.library.custom_op("mylib::matmul_decompose", mutates_args={})
|
||||
def matmul_decompose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
return a @ b
|
||||
|
||||
@matmul_decompose.register_fake
|
||||
def _(a, b):
|
||||
return a @ b
|
||||
|
||||
def decomposeK(a, b, kPartitions):
|
||||
m = a.shape[0]
|
||||
n = b.shape[1]
|
||||
k = a.shape[1]
|
||||
|
||||
B = k // kPartitions
|
||||
a_reshaped = torch.permute(a.reshape(m, B, kPartitions), (1, 0, 2))
|
||||
b_reshaped = b.reshape(B, kPartitions, n)
|
||||
result = torch.bmm(a_reshaped, b_reshaped)
|
||||
result_fp32 = result.to(torch.float32)
|
||||
reduced_buf = torch.sum(result_fp32, 0)
|
||||
return reduced_buf.to(a.dtype)
|
||||
|
||||
mat1_shape, mat2_shape = (32, 4096), (4096, 32)
|
||||
|
||||
@register_lowering(torch.ops.mylib.matmul_decompose)
|
||||
def _(a, b):
|
||||
_, _, _, layout, mat1, mat2 = mm_args(a, b)
|
||||
|
||||
choices = [aten_mm.bind((mat1, mat2), layout)]
|
||||
|
||||
# TODO (PaulZhang12): Once decomposeK lands in Inductor, move this
|
||||
kPartitions = 256
|
||||
with enable_python_dispatcher():
|
||||
decompositions = select_decomp_table()
|
||||
|
||||
decompose_k_subgraph_template = SubgraphTemplate(
|
||||
name="decompose_k_mm",
|
||||
make_fx_graph=make_fx(
|
||||
functools.partial(decomposeK, kPartitions=kPartitions),
|
||||
decompositions,
|
||||
tracing_mode="real",
|
||||
),
|
||||
)
|
||||
|
||||
mat1_tensor, mat2_tensor = (
|
||||
AlgorithmSelectorCache.benchmark_example_value(mat1),
|
||||
AlgorithmSelectorCache.benchmark_example_value(mat2),
|
||||
)
|
||||
decompose_k_subgraph_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=(mat1, mat2),
|
||||
layout=layout,
|
||||
example_inputs=[mat1_tensor, mat2_tensor],
|
||||
)
|
||||
|
||||
# Test benchmarking against aten
|
||||
autotune_select_algorithm("test_subgraph_choice", choices, [a, b], layout)
|
||||
|
||||
# Only return decomposeK case for codegen
|
||||
choices = [choices[1]]
|
||||
return autotune_select_algorithm(
|
||||
"test_subgraph_choice", choices, [a, b], layout
|
||||
)
|
||||
|
||||
a_in = torch.randn(
|
||||
mat1_shape, dtype=torch.float16, device=torch.device(f"{GPU_TYPE}:0")
|
||||
)
|
||||
b_in = torch.randn(
|
||||
mat2_shape, dtype=torch.float16, device=torch.device(f"{GPU_TYPE}:0")
|
||||
)
|
||||
|
||||
def func(mat1, mat2):
|
||||
return torch.ops.mylib.matmul_decompose(mat1, mat2)
|
||||
|
||||
compiled_func = torch.compile(func, mode="max-autotune", dynamic=False)
|
||||
|
||||
res = compiled_func(a_in, b_in)
|
||||
|
||||
# Check same results of compiled result and regular torch.mm
|
||||
# Relax precision as decomposeK does first accumulation in fp16
|
||||
torch.testing.assert_close(res, a_in @ b_in, atol=1e-1, rtol=1e-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set env to make it work in CI.
|
||||
if HAS_GPU and HAS_CPU:
|
||||
run_tests()
|
||||
157
torch/_inductor/codegen/subgraph.py
Normal file
157
torch/_inductor/codegen/subgraph.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
import logging
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
from torch._inductor import ir
|
||||
from torch._inductor.codegen.common import KernelTemplate
|
||||
from torch._inductor.ir import Buffer, Layout
|
||||
from torch._inductor.runtime.benchmarking import benchmarker
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubgraphChoiceCaller(ir.ChoiceCaller):
|
||||
"""
|
||||
Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary
|
||||
GraphModule. Compiles the Subgraph down to a module for benchmarking.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
input_nodes: list[Buffer],
|
||||
layout: Layout,
|
||||
description: str,
|
||||
gm: torch.fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
) -> None:
|
||||
super().__init__(name, input_nodes, layout, description)
|
||||
self.gm = gm
|
||||
self.example_inputs = example_inputs
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SubgraphCaller({self.name})"
|
||||
|
||||
def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
|
||||
# Codegen Subgraph for benchmarking
|
||||
# Need GraphLowering instead of SubgraphLowering to generate
|
||||
# fully callable module
|
||||
import torch._inductor.config as inductor_config
|
||||
from torch._inductor.graph import GraphLowering
|
||||
|
||||
bm_graph_lowering = GraphLowering(
|
||||
gm=self.gm,
|
||||
example_inputs=self.example_inputs,
|
||||
shape_env=V.graph._shape_env,
|
||||
cpp_wrapper=V.graph.cpp_wrapper,
|
||||
aot_mode=V.graph.aot_mode,
|
||||
extern_node_serializer=V.graph.extern_node_serializer,
|
||||
is_inference=V.graph.is_inference,
|
||||
is_backward=V.graph.is_backward,
|
||||
name=f"benchmark_{self.name}",
|
||||
)
|
||||
|
||||
with V.set_graph_handler(bm_graph_lowering):
|
||||
# Don't bother autotuning on Triton here
|
||||
with inductor_config.patch(
|
||||
max_autotune=False,
|
||||
max_autotune_gemm=False,
|
||||
max_autotune_gemm_backends="ATEN",
|
||||
):
|
||||
bm_graph_lowering.run(*self.example_inputs)
|
||||
mod = bm_graph_lowering.compile_to_module()
|
||||
bm_func = mod.call
|
||||
bm_func([*args])
|
||||
|
||||
return benchmarker.benchmark_gpu(lambda: bm_func([*args]))
|
||||
|
||||
def hash_key(self) -> str:
|
||||
return "-".join(
|
||||
[
|
||||
self.name,
|
||||
*[
|
||||
str(arg.shape)
|
||||
for arg in self.example_inputs
|
||||
if isinstance(arg, torch.Tensor)
|
||||
],
|
||||
str(self.gm.graph),
|
||||
]
|
||||
)
|
||||
|
||||
def output_node(self) -> ir.TensorBox:
|
||||
return ir.TensorBox.create(
|
||||
ir.SubgraphBuffer(
|
||||
layout=self.layout,
|
||||
input_nodes=self.input_nodes,
|
||||
gm=self.gm,
|
||||
example_inputs=self.example_inputs,
|
||||
subgraph_name=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
def info_dict(self) -> dict[str, Any]:
|
||||
"""Information returned here is logged to the autotune log file when that is enabled."""
|
||||
return {
|
||||
"backend": "subgraph",
|
||||
"kernel_name": self.name,
|
||||
}
|
||||
|
||||
def autoheuristic_id(self) -> str:
|
||||
return f"subgraph_{self.name}"
|
||||
|
||||
|
||||
class SubgraphTemplate(KernelTemplate):
|
||||
"""
|
||||
A template for subgraph evaluation to be used in autotuning.
|
||||
|
||||
This class allows creating customized subgraphs that can be appended
|
||||
as choices during the autotuning process, enabling the selection of
|
||||
optimal implementations for complex operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
make_fx_graph: Callable[..., Any],
|
||||
):
|
||||
"""
|
||||
Initialize a subgraph template.
|
||||
|
||||
Args:
|
||||
name: The name of this template
|
||||
graph: The FX graph
|
||||
"""
|
||||
self.name = name
|
||||
self.make_fx_graph = make_fx_graph
|
||||
|
||||
def generate( # type: ignore[override]
|
||||
self,
|
||||
input_nodes: list[Buffer],
|
||||
layout: Layout,
|
||||
example_inputs: list[Any],
|
||||
**kwargs: Any,
|
||||
) -> SubgraphChoiceCaller:
|
||||
"""
|
||||
Generate a SubgraphChoiceCaller instance for autotuning.
|
||||
|
||||
Args:
|
||||
input_nodes: List of input nodes to the subgraph
|
||||
layout: Memory layout information for the output
|
||||
example_inputs: Example tensor inputs used to trace and benchmark the subgraph
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
SubgraphChoiceCaller: A callable object that can be used for autotuning
|
||||
"""
|
||||
gm = self.make_fx_graph(*example_inputs)
|
||||
|
||||
return SubgraphChoiceCaller(
|
||||
name=self.name,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
description="",
|
||||
gm=gm,
|
||||
example_inputs=example_inputs,
|
||||
)
|
||||
|
|
@ -5999,6 +5999,49 @@ class TMADescriptor(ExternKernel):
|
|||
wrapper.generate_tma_descriptor(self)
|
||||
|
||||
|
||||
class SubgraphBuffer(ExternKernel):
|
||||
def __init__(
|
||||
self,
|
||||
layout: Layout,
|
||||
input_nodes: list[Buffer],
|
||||
gm: torch.fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
subgraph_name: str,
|
||||
):
|
||||
super().__init__(None, layout, input_nodes)
|
||||
self.gm = gm
|
||||
self.example_inputs = example_inputs
|
||||
self.name = V.graph.register_buffer(self)
|
||||
V.graph.register_operation(self)
|
||||
|
||||
self.subgraph = V.graph.make_subgraph(
|
||||
self.gm, self.example_inputs, subgraph_name
|
||||
)
|
||||
|
||||
import torch._inductor.config as inductor_config
|
||||
|
||||
with V.set_graph_handler(self.subgraph):
|
||||
# Don't bother autotuning on Triton here
|
||||
with inductor_config.patch( # type: ignore[no-untyped-def]
|
||||
max_autotune=False,
|
||||
max_autotune_gemm=False,
|
||||
max_autotune_gemm_backends="ATEN",
|
||||
):
|
||||
self.subgraph.run(*self.example_inputs)
|
||||
|
||||
def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def]
|
||||
class CodegenGraph:
|
||||
def __init__(self, graph: GraphLowering):
|
||||
self.graph = graph
|
||||
self.name = graph.name
|
||||
|
||||
wrapper.codegen_subgraph(
|
||||
CodegenGraph(self.subgraph),
|
||||
[*[buffer.get_name() for buffer in self.inputs]],
|
||||
[self.name],
|
||||
)
|
||||
|
||||
|
||||
class UserDefinedTritonKernel(ExternKernel):
|
||||
def get_kernel_and_metadata(self): # type: ignore[no-untyped-def]
|
||||
from triton.runtime.autotuner import Autotuner
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user