mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] Add Subgraph as a Autotuning Choice (#150653)
Add the option for providing a Subgraph as an autotuning choice in Inductor. This is crucial for implementing the split-k optimization for GEMMs by decomposing a mm -> bmm. https://github.com/pytorch/pytorch/pull/150654 uses these changes to add decomposeK as a default autotuning choice for aten.mm in Inductor. Using https://github.com/pytorch/pytorch/pull/150654 and a simple script: ``` import torch def f(a, b): return torch.matmul(a, b) def decompose_func(a_in, b_in): M, K = a_in.shape K, N = b_in.shape # TODO: Ideally we want to autotune over this parameter kPartitions = 256 assert K % kPartitions == 0, "K must be divisible by Kmini" B = K // kPartitions a_reshaped = a_in.reshape(M, B, kPartitions).transpose( 0, 1 ) # Shape: (B, M, kPartitions) b_reshaped = b_in.reshape(B, kPartitions, N) # Shape: (B, kPartitions, N) result = torch.bmm(a_reshaped, b_reshaped) # Shape: (B, M, N) return result.sum(dim=0).to(torch.float16) # Sum over B dimension, Shape: (M, N) for k in [4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768]: a = torch.randn(32, k, dtype=torch.float16, device="cuda", requires_grad=True) b = torch.randn(k, 32, dtype=torch.float16, device="cuda", requires_grad=True) compiled_res = torch.compile(f, dynamic=False)(a, b) decompose_res = decompose_func(a, b) print(f"Compiled mm result close to aten: {torch.allclose(f(a, b), compiled_res, atol=1e-5, rtol=0.5)}") print(f"Compiled mm result close to decompose: {torch.allclose(decompose_res, compiled_res, atol=1e-5, rtol=0.5)}") ``` we are able to autotune the decomposeK optimization to aten and the traditional Triton templates in Inductor. DecomposeK is faster than aten by about ~10% on average and > 4x speedup over the best Triton templates on an H100 machine, e.g.: ``` AUTOTUNE mm(32x28672, 28672x32) decompose_k_mm 0.0126 ms 100.0% mm 0.0144 ms 87.5% triton_mm_69 0.0579 ms 21.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4 triton_mm_75 0.0677 ms 18.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4 triton_mm_76 0.0850 ms 14.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4 triton_mm_68 0.1444 ms 8.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4 triton_mm_72 0.1546 ms 8.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_74 0.1819 ms 6.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4 triton_mm_67 0.1917 ms 6.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4 triton_mm_73 0.2766 ms 4.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 ``` https://pastebin.com/g3FMaauT is the generated code from Inductor containing the subgraph decomposition for aten.mm. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150653 Approved by: https://github.com/eellison
This commit is contained in:
parent
ad5e9065ac
commit
83ae61fd8e
118
test/inductor/test_subgraph_choice.py
Normal file
118
test/inductor/test_subgraph_choice.py
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
# Owner(s): ["module: inductor"]
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch._dispatch.python import enable_python_dispatcher
|
||||||
|
from torch._inductor.codegen.subgraph import SubgraphTemplate
|
||||||
|
from torch._inductor.decomposition import select_decomp_table
|
||||||
|
from torch._inductor.ir import Buffer, FixedLayout
|
||||||
|
from torch._inductor.lowering import register_lowering
|
||||||
|
from torch._inductor.select_algorithm import (
|
||||||
|
AlgorithmSelectorCache,
|
||||||
|
autotune_select_algorithm,
|
||||||
|
)
|
||||||
|
from torch._inductor.test_case import run_tests, TestCase
|
||||||
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
|
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubgraphChoice(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
def _create_buffer(self, name, shape, dtype):
|
||||||
|
return Buffer(
|
||||||
|
name=name,
|
||||||
|
layout=FixedLayout(torch.device(f"{GPU_TYPE}:0"), dtype=dtype, size=shape),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_subgraph_decompose_k(self):
|
||||||
|
from torch._inductor.kernel.mm import aten_mm
|
||||||
|
from torch._inductor.kernel.mm_common import mm_args
|
||||||
|
|
||||||
|
@torch.library.custom_op("mylib::matmul_decompose", mutates_args={})
|
||||||
|
def matmul_decompose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||||
|
return a @ b
|
||||||
|
|
||||||
|
@matmul_decompose.register_fake
|
||||||
|
def _(a, b):
|
||||||
|
return a @ b
|
||||||
|
|
||||||
|
def decomposeK(a, b, kPartitions):
|
||||||
|
m = a.shape[0]
|
||||||
|
n = b.shape[1]
|
||||||
|
k = a.shape[1]
|
||||||
|
|
||||||
|
B = k // kPartitions
|
||||||
|
a_reshaped = torch.permute(a.reshape(m, B, kPartitions), (1, 0, 2))
|
||||||
|
b_reshaped = b.reshape(B, kPartitions, n)
|
||||||
|
result = torch.bmm(a_reshaped, b_reshaped)
|
||||||
|
result_fp32 = result.to(torch.float32)
|
||||||
|
reduced_buf = torch.sum(result_fp32, 0)
|
||||||
|
return reduced_buf.to(a.dtype)
|
||||||
|
|
||||||
|
mat1_shape, mat2_shape = (32, 4096), (4096, 32)
|
||||||
|
|
||||||
|
@register_lowering(torch.ops.mylib.matmul_decompose)
|
||||||
|
def _(a, b):
|
||||||
|
_, _, _, layout, mat1, mat2 = mm_args(a, b)
|
||||||
|
|
||||||
|
choices = [aten_mm.bind((mat1, mat2), layout)]
|
||||||
|
|
||||||
|
# TODO (PaulZhang12): Once decomposeK lands in Inductor, move this
|
||||||
|
kPartitions = 256
|
||||||
|
with enable_python_dispatcher():
|
||||||
|
decompositions = select_decomp_table()
|
||||||
|
|
||||||
|
decompose_k_subgraph_template = SubgraphTemplate(
|
||||||
|
name="decompose_k_mm",
|
||||||
|
make_fx_graph=make_fx(
|
||||||
|
functools.partial(decomposeK, kPartitions=kPartitions),
|
||||||
|
decompositions,
|
||||||
|
tracing_mode="real",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
mat1_tensor, mat2_tensor = (
|
||||||
|
AlgorithmSelectorCache.benchmark_example_value(mat1),
|
||||||
|
AlgorithmSelectorCache.benchmark_example_value(mat2),
|
||||||
|
)
|
||||||
|
decompose_k_subgraph_template.maybe_append_choice(
|
||||||
|
choices,
|
||||||
|
input_nodes=(mat1, mat2),
|
||||||
|
layout=layout,
|
||||||
|
example_inputs=[mat1_tensor, mat2_tensor],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test benchmarking against aten
|
||||||
|
autotune_select_algorithm("test_subgraph_choice", choices, [a, b], layout)
|
||||||
|
|
||||||
|
# Only return decomposeK case for codegen
|
||||||
|
choices = [choices[1]]
|
||||||
|
return autotune_select_algorithm(
|
||||||
|
"test_subgraph_choice", choices, [a, b], layout
|
||||||
|
)
|
||||||
|
|
||||||
|
a_in = torch.randn(
|
||||||
|
mat1_shape, dtype=torch.float16, device=torch.device(f"{GPU_TYPE}:0")
|
||||||
|
)
|
||||||
|
b_in = torch.randn(
|
||||||
|
mat2_shape, dtype=torch.float16, device=torch.device(f"{GPU_TYPE}:0")
|
||||||
|
)
|
||||||
|
|
||||||
|
def func(mat1, mat2):
|
||||||
|
return torch.ops.mylib.matmul_decompose(mat1, mat2)
|
||||||
|
|
||||||
|
compiled_func = torch.compile(func, mode="max-autotune", dynamic=False)
|
||||||
|
|
||||||
|
res = compiled_func(a_in, b_in)
|
||||||
|
|
||||||
|
# Check same results of compiled result and regular torch.mm
|
||||||
|
# Relax precision as decomposeK does first accumulation in fp16
|
||||||
|
torch.testing.assert_close(res, a_in @ b_in, atol=1e-1, rtol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Set env to make it work in CI.
|
||||||
|
if HAS_GPU and HAS_CPU:
|
||||||
|
run_tests()
|
||||||
157
torch/_inductor/codegen/subgraph.py
Normal file
157
torch/_inductor/codegen/subgraph.py
Normal file
|
|
@ -0,0 +1,157 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch._inductor import ir
|
||||||
|
from torch._inductor.codegen.common import KernelTemplate
|
||||||
|
from torch._inductor.ir import Buffer, Layout
|
||||||
|
from torch._inductor.runtime.benchmarking import benchmarker
|
||||||
|
from torch._inductor.virtualized import V
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SubgraphChoiceCaller(ir.ChoiceCaller):
|
||||||
|
"""
|
||||||
|
Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary
|
||||||
|
GraphModule. Compiles the Subgraph down to a module for benchmarking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
input_nodes: list[Buffer],
|
||||||
|
layout: Layout,
|
||||||
|
description: str,
|
||||||
|
gm: torch.fx.GraphModule,
|
||||||
|
example_inputs: list[Any],
|
||||||
|
) -> None:
|
||||||
|
super().__init__(name, input_nodes, layout, description)
|
||||||
|
self.gm = gm
|
||||||
|
self.example_inputs = example_inputs
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"SubgraphCaller({self.name})"
|
||||||
|
|
||||||
|
def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
|
||||||
|
# Codegen Subgraph for benchmarking
|
||||||
|
# Need GraphLowering instead of SubgraphLowering to generate
|
||||||
|
# fully callable module
|
||||||
|
import torch._inductor.config as inductor_config
|
||||||
|
from torch._inductor.graph import GraphLowering
|
||||||
|
|
||||||
|
bm_graph_lowering = GraphLowering(
|
||||||
|
gm=self.gm,
|
||||||
|
example_inputs=self.example_inputs,
|
||||||
|
shape_env=V.graph._shape_env,
|
||||||
|
cpp_wrapper=V.graph.cpp_wrapper,
|
||||||
|
aot_mode=V.graph.aot_mode,
|
||||||
|
extern_node_serializer=V.graph.extern_node_serializer,
|
||||||
|
is_inference=V.graph.is_inference,
|
||||||
|
is_backward=V.graph.is_backward,
|
||||||
|
name=f"benchmark_{self.name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
with V.set_graph_handler(bm_graph_lowering):
|
||||||
|
# Don't bother autotuning on Triton here
|
||||||
|
with inductor_config.patch(
|
||||||
|
max_autotune=False,
|
||||||
|
max_autotune_gemm=False,
|
||||||
|
max_autotune_gemm_backends="ATEN",
|
||||||
|
):
|
||||||
|
bm_graph_lowering.run(*self.example_inputs)
|
||||||
|
mod = bm_graph_lowering.compile_to_module()
|
||||||
|
bm_func = mod.call
|
||||||
|
bm_func([*args])
|
||||||
|
|
||||||
|
return benchmarker.benchmark_gpu(lambda: bm_func([*args]))
|
||||||
|
|
||||||
|
def hash_key(self) -> str:
|
||||||
|
return "-".join(
|
||||||
|
[
|
||||||
|
self.name,
|
||||||
|
*[
|
||||||
|
str(arg.shape)
|
||||||
|
for arg in self.example_inputs
|
||||||
|
if isinstance(arg, torch.Tensor)
|
||||||
|
],
|
||||||
|
str(self.gm.graph),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def output_node(self) -> ir.TensorBox:
|
||||||
|
return ir.TensorBox.create(
|
||||||
|
ir.SubgraphBuffer(
|
||||||
|
layout=self.layout,
|
||||||
|
input_nodes=self.input_nodes,
|
||||||
|
gm=self.gm,
|
||||||
|
example_inputs=self.example_inputs,
|
||||||
|
subgraph_name=self.name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def info_dict(self) -> dict[str, Any]:
|
||||||
|
"""Information returned here is logged to the autotune log file when that is enabled."""
|
||||||
|
return {
|
||||||
|
"backend": "subgraph",
|
||||||
|
"kernel_name": self.name,
|
||||||
|
}
|
||||||
|
|
||||||
|
def autoheuristic_id(self) -> str:
|
||||||
|
return f"subgraph_{self.name}"
|
||||||
|
|
||||||
|
|
||||||
|
class SubgraphTemplate(KernelTemplate):
|
||||||
|
"""
|
||||||
|
A template for subgraph evaluation to be used in autotuning.
|
||||||
|
|
||||||
|
This class allows creating customized subgraphs that can be appended
|
||||||
|
as choices during the autotuning process, enabling the selection of
|
||||||
|
optimal implementations for complex operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
make_fx_graph: Callable[..., Any],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize a subgraph template.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of this template
|
||||||
|
graph: The FX graph
|
||||||
|
"""
|
||||||
|
self.name = name
|
||||||
|
self.make_fx_graph = make_fx_graph
|
||||||
|
|
||||||
|
def generate( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
input_nodes: list[Buffer],
|
||||||
|
layout: Layout,
|
||||||
|
example_inputs: list[Any],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> SubgraphChoiceCaller:
|
||||||
|
"""
|
||||||
|
Generate a SubgraphChoiceCaller instance for autotuning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_nodes: List of input nodes to the subgraph
|
||||||
|
layout: Memory layout information for the output
|
||||||
|
example_inputs: Example tensor inputs used to trace and benchmark the subgraph
|
||||||
|
**kwargs: Additional keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SubgraphChoiceCaller: A callable object that can be used for autotuning
|
||||||
|
"""
|
||||||
|
gm = self.make_fx_graph(*example_inputs)
|
||||||
|
|
||||||
|
return SubgraphChoiceCaller(
|
||||||
|
name=self.name,
|
||||||
|
input_nodes=input_nodes,
|
||||||
|
layout=layout,
|
||||||
|
description="",
|
||||||
|
gm=gm,
|
||||||
|
example_inputs=example_inputs,
|
||||||
|
)
|
||||||
|
|
@ -5999,6 +5999,49 @@ class TMADescriptor(ExternKernel):
|
||||||
wrapper.generate_tma_descriptor(self)
|
wrapper.generate_tma_descriptor(self)
|
||||||
|
|
||||||
|
|
||||||
|
class SubgraphBuffer(ExternKernel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layout: Layout,
|
||||||
|
input_nodes: list[Buffer],
|
||||||
|
gm: torch.fx.GraphModule,
|
||||||
|
example_inputs: list[Any],
|
||||||
|
subgraph_name: str,
|
||||||
|
):
|
||||||
|
super().__init__(None, layout, input_nodes)
|
||||||
|
self.gm = gm
|
||||||
|
self.example_inputs = example_inputs
|
||||||
|
self.name = V.graph.register_buffer(self)
|
||||||
|
V.graph.register_operation(self)
|
||||||
|
|
||||||
|
self.subgraph = V.graph.make_subgraph(
|
||||||
|
self.gm, self.example_inputs, subgraph_name
|
||||||
|
)
|
||||||
|
|
||||||
|
import torch._inductor.config as inductor_config
|
||||||
|
|
||||||
|
with V.set_graph_handler(self.subgraph):
|
||||||
|
# Don't bother autotuning on Triton here
|
||||||
|
with inductor_config.patch( # type: ignore[no-untyped-def]
|
||||||
|
max_autotune=False,
|
||||||
|
max_autotune_gemm=False,
|
||||||
|
max_autotune_gemm_backends="ATEN",
|
||||||
|
):
|
||||||
|
self.subgraph.run(*self.example_inputs)
|
||||||
|
|
||||||
|
def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def]
|
||||||
|
class CodegenGraph:
|
||||||
|
def __init__(self, graph: GraphLowering):
|
||||||
|
self.graph = graph
|
||||||
|
self.name = graph.name
|
||||||
|
|
||||||
|
wrapper.codegen_subgraph(
|
||||||
|
CodegenGraph(self.subgraph),
|
||||||
|
[*[buffer.get_name() for buffer in self.inputs]],
|
||||||
|
[self.name],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedTritonKernel(ExternKernel):
|
class UserDefinedTritonKernel(ExternKernel):
|
||||||
def get_kernel_and_metadata(self): # type: ignore[no-untyped-def]
|
def get_kernel_and_metadata(self): # type: ignore[no-untyped-def]
|
||||||
from triton.runtime.autotuner import Autotuner
|
from triton.runtime.autotuner import Autotuner
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user