From cf0a0dcb0afa5e84b95461cc542f862b51ca96bf Mon Sep 17 00:00:00 2001 From: PaulZhang12 Date: Mon, 11 Aug 2025 04:23:23 -0700 Subject: [PATCH] Make user defined Triton kernels serializable for fx_graph_runnable (#160002) Resolves issue https://github.com/pytorch/pytorch/issues/153475 where `fx_graph_runnable` didn't work with user defined triton kernels. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160002 Approved by: https://github.com/eellison --- test/dynamo/test_fx_graph_runnable.py | 88 +++++++++++++++++++++++++++ torch/_dynamo/repro/after_aot.py | 77 +++++++++++++++++++++++ 2 files changed, 165 insertions(+) diff --git a/test/dynamo/test_fx_graph_runnable.py b/test/dynamo/test_fx_graph_runnable.py index d5ad0c160c4..47e9ee3cb88 100644 --- a/test/dynamo/test_fx_graph_runnable.py +++ b/test/dynamo/test_fx_graph_runnable.py @@ -11,12 +11,65 @@ import torch.distributed as dist from torch._inductor.codecache import WritableTempFile from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE +from torch.utils._triton import has_triton if torch.distributed.is_available(): from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal.distributed.fake_pg import FakeStore +if has_triton(): + import triton + import triton.language as tl + + def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + @triton.jit + def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.atomic_add(output_ptr + offsets, output, mask=mask) + + @triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE": 1024}, + num_warps=4, + num_stages=2, + pre_hook=init_to_zero("output_ptr"), + ) + ], + pre_hook=init_to_zero("output_ptr"), + post_hook=init_to_zero("output_ptr"), + key=["n_elements"], + ) + @triton.jit + def add_kernel_autotune( + x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr + ): + pid = tl.program_id(axis=0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.atomic_add(output_ptr + offsets, output, mask=mask) + + +from torch.testing._internal.inductor_utils import GPU_TYPE +from torch.testing._internal.triton_utils import requires_gpu + class FxGraphRunnableArtifactFilter(logging.Filter): def filter(self, record): @@ -100,6 +153,41 @@ class FxGraphRunnableTest(TestCase): torch.compile(f)(torch.randn(4)) self._exec_and_verify_payload() + @unittest.skipUnless(has_triton(), "Triton not available") + def test_user_defined_triton_kernel_autotune(self): + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.ones(x.shape, device=x.device, dtype=x.dtype) + n_elements = output.numel() + + def grid( + meta, + ): + return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + + add_kernel_autotune[grid](x, y, output, n_elements) + return output + + x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + + torch.compile(add)(x, y) + self._exec_and_verify_payload() + + @unittest.skipUnless(has_triton(), "Triton not available") + @requires_gpu 
+ def test_user_defined_triton_kernel(self): + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.ones(x.shape, device=x.device, dtype=x.dtype) + n_elements = x.numel() + add_kernel[n_elements,](x, y, output, n_elements, BLOCK_SIZE=4) + return output + + x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16) + + torch.compile(add)(x, y) + self._exec_and_verify_payload() + def test_two_inputs_matmul(self): def f(a, b): return (a @ b).relu() diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 71f552a83b4..136d2af1a60 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -34,6 +34,24 @@ from tempfile import TemporaryFile from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union from typing_extensions import Unpack +from torch.utils._triton import has_triton + + +if has_triton(): + from triton.runtime.autotuner import Autotuner, Heuristics + from triton.runtime.jit import JITFunction +else: + + class Autotuner: # type: ignore[no-redef] + pass + + class JITFunction: # type: ignore[no-redef] + pass + + class Heuristics: # type: ignore[no-redef] + pass + + import torch import torch.fx as fx import torch.nn as nn @@ -58,6 +76,7 @@ from torch._dynamo.debug_utils import ( ) from torch._dynamo.utils import clone_inputs, counters, same from torch._environment import is_fbcode +from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.output_code import OutputCode from torch._library.fake_class_registry import FakeScriptObject @@ -302,6 +321,16 @@ from torch.testing._internal.distributed.fake_pg import FakeStore """ ).strip() + triton_imports = "" + + if len(kernel_side_table.id_to_kernel) > 0: + triton_imports = textwrap.dedent( + """ +import triton +import triton.language as tl + """ + ).strip() + model_str = textwrap.dedent( f""" {generate_env_vars_string(stable_output=stable_output)} @@ -312,6 +341,7 @@ from torch._dynamo.testing import rand_strided from math import inf import torch._inductor.inductor_prims {distributed_imports} +{triton_imports} {generate_config_string(stable_output=stable_output)} @@ -330,6 +360,53 @@ isolate_fails_code_str = None model_str += f"# torch git version: {torch.version.git_version}\n\n\n" model_str += _cuda_system_info_comment() + kernel_side_table_prefix = ( + "torch._higher_order_ops.triton_kernel_wrap.kernel_side_table" + ) + # Track which grid entry corresponds to the best config + for id in kernel_side_table.id_to_kernel: + kernel = kernel_side_table.get_kernel(id) + + if isinstance(kernel, Autotuner): + if isinstance(kernel.fn, Heuristics): + model_str += "ERROR: Repro will not work as intended, " + model_str += ( + "triton.runtime.autotuner.Heuristics is not currently supported\n" + ) + break + + config_strs = [] + for kernel_config in kernel.configs: + config_strs.append(f"""triton.Config( + {str(kernel_config.kwargs)}, + num_warps={kernel_config.num_warps}, + num_stages={kernel_config.num_stages}, + )""") + + config_str = ",".join(config_strs) + model_str += textwrap.dedent(f""" + @triton.autotune( + configs=[ + {config_str} + ], + key=[] + ) + """).strip() + + model_str += "\n@triton.jit\n" + src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src + fn_name = ( + kernel._fn_name if isinstance(kernel, JITFunction) else kernel.fn._fn_name + ) + fn_name = 
fn_name.split(".")[-1] + + model_str += src_code + model_str += "\n" + model_str += f"{kernel_side_table_prefix}.add_kernel({fn_name})\n" + + if len(kernel_side_table.constant_args) > 0: + model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n" + model_str += NNModuleToString.convert(gm) writer = InputWriter(save_dir, stable_hash=stable_hash)