diff --git a/test/inductor/test_subgraph_choice.py b/test/inductor/test_subgraph_choice.py
new file mode 100644
index 00000000000..e9ce30d4a9c
--- /dev/null
+++ b/test/inductor/test_subgraph_choice.py
@@ -0,0 +1,118 @@
+# Owner(s): ["module: inductor"]
+import functools
+
+import torch
+from torch._dispatch.python import enable_python_dispatcher
+from torch._inductor.codegen.subgraph import SubgraphTemplate
+from torch._inductor.decomposition import select_decomp_table
+from torch._inductor.ir import Buffer, FixedLayout
+from torch._inductor.lowering import register_lowering
+from torch._inductor.select_algorithm import (
+    AlgorithmSelectorCache,
+    autotune_select_algorithm,
+)
+from torch._inductor.test_case import run_tests, TestCase
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
+
+
+class TestSubgraphChoice(TestCase):
+    def setUp(self):
+        super().setUp()
+
+    def _create_buffer(self, name, shape, dtype):
+        return Buffer(
+            name=name,
+            layout=FixedLayout(torch.device(f"{GPU_TYPE}:0"), dtype=dtype, size=shape),
+        )
+
+    def test_subgraph_decompose_k(self):
+        from torch._inductor.kernel.mm import aten_mm
+        from torch._inductor.kernel.mm_common import mm_args
+
+        @torch.library.custom_op("mylib::matmul_decompose", mutates_args={})
+        def matmul_decompose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            return a @ b
+
+        @matmul_decompose.register_fake
+        def _(a, b):
+            return a @ b
+
+        def decomposeK(a, b, kPartitions):
+            m = a.shape[0]
+            n = b.shape[1]
+            k = a.shape[1]
+
+            B = k // kPartitions
+            a_reshaped = torch.permute(a.reshape(m, B, kPartitions), (1, 0, 2))
+            b_reshaped = b.reshape(B, kPartitions, n)
+            result = torch.bmm(a_reshaped, b_reshaped)
+            result_fp32 = result.to(torch.float32)
+            reduced_buf = torch.sum(result_fp32, 0)
+            return reduced_buf.to(a.dtype)
+
+        mat1_shape, mat2_shape = (32, 4096), (4096, 32)
+
+        @register_lowering(torch.ops.mylib.matmul_decompose)
+        def _(a, b):
+            _, _, _, layout, mat1, mat2 = mm_args(a, b)
+
+            choices = [aten_mm.bind((mat1, mat2), layout)]
+
+            # TODO (PaulZhang12): Once decomposeK lands in Inductor, move this
+            kPartitions = 256
+            with enable_python_dispatcher():
+                decompositions = select_decomp_table()
+
+                decompose_k_subgraph_template = SubgraphTemplate(
+                    name="decompose_k_mm",
+                    make_fx_graph=make_fx(
+                        functools.partial(decomposeK, kPartitions=kPartitions),
+                        decompositions,
+                        tracing_mode="real",
+                    ),
+                )
+
+            mat1_tensor, mat2_tensor = (
+                AlgorithmSelectorCache.benchmark_example_value(mat1),
+                AlgorithmSelectorCache.benchmark_example_value(mat2),
+            )
+            decompose_k_subgraph_template.maybe_append_choice(
+                choices,
+                input_nodes=(mat1, mat2),
+                layout=layout,
+                example_inputs=[mat1_tensor, mat2_tensor],
+            )
+
+            # Test benchmarking against aten
+            autotune_select_algorithm("test_subgraph_choice", choices, [a, b], layout)
+
+            # Only return decomposeK case for codegen
+            choices = [choices[1]]
+            return autotune_select_algorithm(
+                "test_subgraph_choice", choices, [a, b], layout
+            )
+
+        a_in = torch.randn(
+            mat1_shape, dtype=torch.float16, device=torch.device(f"{GPU_TYPE}:0")
+        )
+        b_in = torch.randn(
+            mat2_shape, dtype=torch.float16, device=torch.device(f"{GPU_TYPE}:0")
+        )
+
+        def func(mat1, mat2):
+            return torch.ops.mylib.matmul_decompose(mat1, mat2)
+
+        compiled_func = torch.compile(func, mode="max-autotune", dynamic=False)
+
+        res = compiled_func(a_in, b_in)
+
+        # Check that the compiled result matches regular torch.mm
+        # Relax precision, as decomposeK does the first accumulation in fp16
+        torch.testing.assert_close(res, a_in @ b_in, atol=1e-1, rtol=1e-1)
+
+
+if __name__ == "__main__":
+    # Set env to make it work in CI.
+    if HAS_GPU and HAS_CPU:
+        run_tests()
diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py
new file mode 100644
index 00000000000..d7fb0bd38cd
--- /dev/null
+++ b/torch/_inductor/codegen/subgraph.py
@@ -0,0 +1,157 @@
+import logging
+from typing import Any, Callable
+
+import torch
+from torch._inductor import ir
+from torch._inductor.codegen.common import KernelTemplate
+from torch._inductor.ir import Buffer, Layout
+from torch._inductor.runtime.benchmarking import benchmarker
+from torch._inductor.virtualized import V
+
+
+log = logging.getLogger(__name__)
+
+
+class SubgraphChoiceCaller(ir.ChoiceCaller):
+    """
+    Represents a Subgraph Autotuning choice; the subgraph can be any arbitrary
+    GraphModule. Compiles the Subgraph down to a module for benchmarking.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        input_nodes: list[Buffer],
+        layout: Layout,
+        description: str,
+        gm: torch.fx.GraphModule,
+        example_inputs: list[Any],
+    ) -> None:
+        super().__init__(name, input_nodes, layout, description)
+        self.gm = gm
+        self.example_inputs = example_inputs
+
+    def __str__(self) -> str:
+        return f"SubgraphCaller({self.name})"
+
+    def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
+        # Codegen Subgraph for benchmarking
+        # Need GraphLowering instead of SubgraphLowering to generate
+        # fully callable module
+        import torch._inductor.config as inductor_config
+        from torch._inductor.graph import GraphLowering
+
+        bm_graph_lowering = GraphLowering(
+            gm=self.gm,
+            example_inputs=self.example_inputs,
+            shape_env=V.graph._shape_env,
+            cpp_wrapper=V.graph.cpp_wrapper,
+            aot_mode=V.graph.aot_mode,
+            extern_node_serializer=V.graph.extern_node_serializer,
+            is_inference=V.graph.is_inference,
+            is_backward=V.graph.is_backward,
+            name=f"benchmark_{self.name}",
+        )
+
+        with V.set_graph_handler(bm_graph_lowering):
+            # Don't bother autotuning on Triton here
+            with inductor_config.patch(
+                max_autotune=False,
+                max_autotune_gemm=False,
+                max_autotune_gemm_backends="ATEN",
+            ):
+                bm_graph_lowering.run(*self.example_inputs)
+                mod = bm_graph_lowering.compile_to_module()
+                bm_func = mod.call
+                bm_func([*args])
+
+        return benchmarker.benchmark_gpu(lambda: bm_func([*args]))
+
+    def hash_key(self) -> str:
+        return "-".join(
+            [
+                self.name,
+                *[
+                    str(arg.shape)
+                    for arg in self.example_inputs
+                    if isinstance(arg, torch.Tensor)
+                ],
+                str(self.gm.graph),
+            ]
+        )
+
+    def output_node(self) -> ir.TensorBox:
+        return ir.TensorBox.create(
+            ir.SubgraphBuffer(
+                layout=self.layout,
+                input_nodes=self.input_nodes,
+                gm=self.gm,
+                example_inputs=self.example_inputs,
+                subgraph_name=self.name,
+            )
+        )
+
+    def info_dict(self) -> dict[str, Any]:
+        """Information returned here is logged to the autotune log file when that is enabled."""
+        return {
+            "backend": "subgraph",
+            "kernel_name": self.name,
+        }
+
+    def autoheuristic_id(self) -> str:
+        return f"subgraph_{self.name}"
+
+
+class SubgraphTemplate(KernelTemplate):
+    """
+    A template for subgraph evaluation to be used in autotuning.
+
+    This class allows creating customized subgraphs that can be appended
+    as choices during the autotuning process, enabling the selection of
+    optimal implementations for complex operations.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        make_fx_graph: Callable[..., Any],
+    ):
+        """
+        Initialize a subgraph template.
+
+        Args:
+            name: The name of this template
+            make_fx_graph: Callable that traces the example inputs into the FX GraphModule to lower
+        """
+        self.name = name
+        self.make_fx_graph = make_fx_graph
+
+    def generate(  # type: ignore[override]
+        self,
+        input_nodes: list[Buffer],
+        layout: Layout,
+        example_inputs: list[Any],
+        **kwargs: Any,
+    ) -> SubgraphChoiceCaller:
+        """
+        Generate a SubgraphChoiceCaller instance for autotuning.
+
+        Args:
+            input_nodes: List of input nodes to the subgraph
+            layout: Memory layout information for the output
+            example_inputs: Example tensor inputs used to trace and benchmark the subgraph
+            **kwargs: Additional keyword arguments
+
+        Returns:
+            SubgraphChoiceCaller: A callable object that can be used for autotuning
+        """
+        gm = self.make_fx_graph(*example_inputs)
+
+        return SubgraphChoiceCaller(
+            name=self.name,
+            input_nodes=input_nodes,
+            layout=layout,
+            description="",
+            gm=gm,
+            example_inputs=example_inputs,
+        )
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 2cd5b745b0c..1ae2505e461 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -5999,6 +5999,49 @@ class TMADescriptor(ExternKernel):
         wrapper.generate_tma_descriptor(self)
 
 
+class SubgraphBuffer(ExternKernel):
+    def __init__(
+        self,
+        layout: Layout,
+        input_nodes: list[Buffer],
+        gm: torch.fx.GraphModule,
+        example_inputs: list[Any],
+        subgraph_name: str,
+    ):
+        super().__init__(None, layout, input_nodes)
+        self.gm = gm
+        self.example_inputs = example_inputs
+        self.name = V.graph.register_buffer(self)
+        V.graph.register_operation(self)
+
+        self.subgraph = V.graph.make_subgraph(
+            self.gm, self.example_inputs, subgraph_name
+        )
+
+        import torch._inductor.config as inductor_config
+
+        with V.set_graph_handler(self.subgraph):
+            # Don't bother autotuning on Triton here
+            with inductor_config.patch(  # type: ignore[no-untyped-def]
+                max_autotune=False,
+                max_autotune_gemm=False,
+                max_autotune_gemm_backends="ATEN",
+            ):
+                self.subgraph.run(*self.example_inputs)
+
+    def codegen(self, wrapper) -> None:  # type: ignore[no-untyped-def]
+        class CodegenGraph:
+            def __init__(self, graph: GraphLowering):
+                self.graph = graph
+                self.name = graph.name
+
+        wrapper.codegen_subgraph(
+            CodegenGraph(self.subgraph),
+            [*[buffer.get_name() for buffer in self.inputs]],
+            [self.name],
+        )
+
+
 class UserDefinedTritonKernel(ExternKernel):
     def get_kernel_and_metadata(self):  # type: ignore[no-untyped-def]
         from triton.runtime.autotuner import Autotuner
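
Usage note (illustrative, not part of the diff): the decompose-K test above exercises the general pattern, which applies to any traceable function. Below is a minimal sketch of how a lowering might register a different subgraph as a competing choice, assuming only the APIs already used in the test; the `mm_upcast` function, `tuned_mm_with_subgraph` wrapper, and the tuning name are hypothetical.

```python
# Hypothetical sketch mirroring the calls in test_subgraph_choice.py above.
import torch
from torch._inductor.codegen.subgraph import SubgraphTemplate
from torch._inductor.kernel.mm import aten_mm
from torch._inductor.kernel.mm_common import mm_args
from torch._inductor.select_algorithm import (
    AlgorithmSelectorCache,
    autotune_select_algorithm,
)
from torch.fx.experimental.proxy_tensor import make_fx


def mm_upcast(a, b):
    # Any traceable function can serve as a subgraph choice; here, mm in fp32.
    return (a.to(torch.float32) @ b.to(torch.float32)).to(a.dtype)


def tuned_mm_with_subgraph(a, b):
    # Collect the usual ATen choice, append the traced subgraph as a second
    # choice, and let autotuning pick the faster of the two.
    _, _, _, layout, mat1, mat2 = mm_args(a, b)
    choices = [aten_mm.bind((mat1, mat2), layout)]

    template = SubgraphTemplate(
        name="mm_upcast",
        make_fx_graph=make_fx(mm_upcast, tracing_mode="real"),
    )
    example_inputs = [
        AlgorithmSelectorCache.benchmark_example_value(mat1),
        AlgorithmSelectorCache.benchmark_example_value(mat2),
    ]
    template.maybe_append_choice(
        choices,
        input_nodes=(mat1, mat2),
        layout=layout,
        example_inputs=example_inputs,
    )
    return autotune_select_algorithm("mm_upcast", choices, [mat1, mat2], layout)
```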