[Inductor-FX] Support custom triton kernels (#161474)

# Feature
Add support for custom Triton kernels to the FX backend. This turned out not to require any new features, except for a minor change to handle `tl.constexpr` arguments that are not part of the autotuning config.
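
To make the `tl.constexpr` case concrete, here is a minimal sketch of the kind of user code this enables. The kernel, its `SCALE` parameter, and the use of plain `torch.compile` are illustrative assumptions, not part of this PR (the new test below exercises the FX backend through the test harness instead):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, SCALE: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # Elementwise multiply by a compile-time constant.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * SCALE, mask=mask)


def scale(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # SCALE is a tl.constexpr that no autotuning config supplies; it is
    # passed explicitly at the call site, which is the case handled here.
    scale_kernel[grid](x, out, n_elements, SCALE=2, BLOCK_SIZE=16)
    return out


compiled = torch.compile(scale)
y = compiled(torch.randn(32, device="cuda"))
```

The constexpr handling matters because `SCALE` appears in the kernel's Triton signature as a constant rather than a positional call arg, which is what the new `add_constants_to_call_args` helper in the diff below accounts for.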

# Caveat

This may not cover every possible case. For example, we might need more features for autotuning custom Triton code. This PR entirely skips the [custom codegen](https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/triton_kernel_wrap.py#L1034-L1039) for user-defined grid functions (sketched below), but there may be edge cases that require this logic. However, this PR seems to do a reasonable job: many of the grids end up being written into Inductor/Triton metadata and don't require special codegen.
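
For reference, a launch grid can be given either as a static tuple or as a user-defined grid function that receives the kernel's meta-parameters; the latter is the form whose dedicated codegen is skipped here. A minimal, hypothetical illustration (the `copy_kernel` name and sizes are made up):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)


src = torch.randn(64, device="cuda")
dst = torch.empty_like(src)
n_elements = src.numel()

# Static grid: a plain tuple, fully determined before launch.
copy_kernel[(triton.cdiv(n_elements, 16),)](src, dst, n_elements, BLOCK_SIZE=16)

# User-defined grid function: resolved from the kernel's meta-parameters
# (e.g. the autotuned BLOCK_SIZE) at launch time.
def grid(meta):
    return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=16)
```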

As a follow-up, I'm planning to test this against all of AOTI's custom Triton kernel tests.

# Test plan
Added a CI test using a custom Triton kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161474
Approved by: https://github.com/angelayi
Blaine Burton Rister, 2025-08-27 00:15:17 +00:00, committed by PyTorch MergeBot
parent dbc903a94a
commit 9de9d25f8d
2 changed files with 69 additions and 1 deletion


@@ -35,6 +35,11 @@ from torch.testing._internal.inductor_utils import (
)


if HAS_GPU:
    import triton
    import triton.language as tl


@requires_gpu()
@config.patch(
    compile_threads=1,
@@ -544,6 +549,37 @@ class FxirTestCase(InductorTestCase):
        if use_dynamic_shapes:
            self.assertEqual(type(shape[0]), torch.fx.Node)

    def test_custom_triton(self):
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: tl.constexpr,
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            return output

        args = [torch.randn(32, device=self.device) for _ in range(2)]
        self._compile_and_check(add, args)

    def test_output_slice_view(self):
        """
        Test when the output is a view of the input.


@@ -33,6 +33,7 @@ from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
from torch.utils._sympy.reference import OptimizedPythonReferenceAnalysis

from .. import config, ir
from ..runtime.triton_compat import Config
from ..utils import LineContext
from .common import (
    CodegenSymbol,
@@ -700,9 +701,40 @@ class FxConverter:
            kernel_name,
        )

        triton_meta = tuner.triton_meta
        signature = triton_meta["signature"]

        def add_constants_to_call_args(
            call_args: Sequence[Any], cfg: Config
        ) -> tuple[Any, ...]:
            """
            Add constant kwargs to the arg list.
            """
            # Add args from the proper Triton signature.
            new_call_args = []
            call_arg_idx = 0
            constants = triton_meta["constants"]
            for arg_name in signature:
                # Config kwargs are tracked separately.
                if arg_name in cfg.kwargs:
                    continue

                try:
                    new_arg = constants[arg_name]
                except KeyError:
                    new_arg = call_args[call_arg_idx]
                    call_arg_idx += 1

                new_call_args.append(new_arg)

            # Add Inductor's extra call args to the end.
            new_call_args.extend(call_args[call_arg_idx:])
            return tuple(new_call_args)

        kernel_config = tuner.compile_results[0].config
        call_args = add_constants_to_call_args(call_args, kernel_config)
        call_args, grid = tuner._interpret_args_grid(call_args, kernel_config)
-       call_kwargs = dict(zip(tuner.triton_meta["signature"], call_args))
+       call_kwargs = dict(zip(signature, call_args))
        call_kwargs.update(kernel_config.kwargs)
        wrapper_grid = [tuple(self._generate_sym_nodes(grid))]