[Inductor-FX] Support custom triton kernels (#161474)
# Feature

Add support for custom Triton kernels to the FX backend. This turned out not to require any new features, except for a minor change to handle `tl.constexpr` arguments that are not part of the autotuning config.

# Caveat

This may not cover every possible case. For example, we might need more features for autotuning custom Triton code. This PR entirely skips the [custom codegen](https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/triton_kernel_wrap.py#L1034-L1039) for user-defined grid functions, but there may be edge cases requiring that logic. In practice this PR seems to do a reasonable job, as many of the grids end up being written into Inductor/Triton metadata and don't require special codegen. As a follow-up, I'm planning to test this against all of AOTI's custom Triton kernel tests.

# Test plan

Added a CI test using a custom Triton kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161474
Approved by: https://github.com/angelayi
This commit is contained in: parent dbc903a94a, commit 9de9d25f8d
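For context before the diff: the `tl.constexpr` case mentioned in the description arises when a user-defined kernel takes a compile-time constant that the caller fixes explicitly at the call site instead of the autotuner supplying it. The sketch below is illustrative only (the kernel, names, and values are hypothetical, not taken from this PR); it shows such a kernel invoked through `torch.compile`, which routes it through Inductor.

```python
# Illustrative sketch only -- kernel and names are hypothetical, not part of this PR.
import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(in_ptr, out_ptr, n_elements, SCALE: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # SCALE is a tl.constexpr fixed at the call site; BLOCK_SIZE is the kind of
    # constexpr an autotuning config would normally own.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * SCALE, mask=mask)


def scale(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    scale_kernel[grid](x, out, n_elements, SCALE=2.0, BLOCK_SIZE=64)
    return out


if torch.cuda.is_available():
    compiled = torch.compile(scale)  # lowers the custom kernel through Inductor
    print(compiled(torch.randn(256, device="cuda")))
```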
```diff
@@ -35,6 +35,11 @@ from torch.testing._internal.inductor_utils import (
 )
 
+if HAS_GPU:
+    import triton
+    import triton.language as tl
+
+
 @requires_gpu()
 @config.patch(
     compile_threads=1,
@@ -544,6 +549,37 @@ class FxirTestCase(InductorTestCase):
         if use_dynamic_shapes:
             self.assertEqual(type(shape[0]), torch.fx.Node)
 
+    def test_custom_triton(self):
+        @triton.jit
+        def add_kernel(
+            in_ptr0,
+            in_ptr1,
+            out_ptr,
+            n_elements,
+            BLOCK_SIZE: tl.constexpr,
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask)
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.empty_like(x)
+            n_elements = output.numel()
+
+            def grid(meta):
+                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
+            return output
+
+        args = [torch.randn(32, device=self.device) for _ in range(2)]
+        self._compile_and_check(add, args)
+
     def test_output_slice_view(self):
         """
         Test when the output is a view of the input.
```
```diff
@@ -33,6 +33,7 @@ from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
 from torch.utils._sympy.reference import OptimizedPythonReferenceAnalysis
 
 from .. import config, ir
+from ..runtime.triton_compat import Config
 from ..utils import LineContext
 from .common import (
     CodegenSymbol,
@@ -700,9 +701,40 @@ class FxConverter:
             kernel_name,
         )
 
+        triton_meta = tuner.triton_meta
+        signature = triton_meta["signature"]
+
+        def add_constants_to_call_args(
+            call_args: Sequence[Any], cfg: Config
+        ) -> tuple[Any, ...]:
+            """
+            Add constant kwargs to the arg list.
+            """
+            # Add args from the proper Triton signature.
+            new_call_args = []
+            call_arg_idx = 0
+            constants = triton_meta["constants"]
+            for arg_name in signature:
+                # Config kwargs are tracked separately.
+                if arg_name in cfg.kwargs:
+                    continue
+
+                try:
+                    new_arg = constants[arg_name]
+                except KeyError:
+                    new_arg = call_args[call_arg_idx]
+                    call_arg_idx += 1
+                new_call_args.append(new_arg)
+
+            # Add Inductor's extra call args to the end.
+            new_call_args.extend(call_args[call_arg_idx:])
+
+            return tuple(new_call_args)
+
         kernel_config = tuner.compile_results[0].config
+        call_args = add_constants_to_call_args(call_args, kernel_config)
         call_args, grid = tuner._interpret_args_grid(call_args, kernel_config)
-        call_kwargs = dict(zip(tuner.triton_meta["signature"], call_args))
+        call_kwargs = dict(zip(signature, call_args))
         call_kwargs.update(kernel_config.kwargs)
 
         wrapper_grid = [tuple(self._generate_sym_nodes(grid))]
```
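To make the merge order in `add_constants_to_call_args` concrete, here is a standalone toy restatement (plain Python containers and hypothetical argument names, not Inductor's actual tuner/config objects): walk the kernel signature in order, skip args owned by the autotuning config, pull recorded constants by name, fill the rest from the positional call args, and keep any extra Inductor-appended args at the end.

```python
# Toy restatement with plain containers; names and values are hypothetical.
from typing import Any, Sequence


def merge_constants_into_call_args(
    call_args: Sequence[Any],
    signature: Sequence[str],       # arg names in kernel order
    constants: dict[str, Any],      # tl.constexpr values fixed at the call site
    config_kwargs: dict[str, Any],  # constexprs owned by the autotuning config
) -> tuple[Any, ...]:
    new_call_args = []
    call_arg_idx = 0
    for arg_name in signature:
        if arg_name in config_kwargs:  # tracked separately, e.g. BLOCK_SIZE
            continue
        if arg_name in constants:      # call-site constant, e.g. SCALE
            new_call_args.append(constants[arg_name])
        else:                          # regular runtime arg (pointer, size, ...)
            new_call_args.append(call_args[call_arg_idx])
            call_arg_idx += 1
    # Any extra args appended after the signature args stay at the end.
    new_call_args.extend(call_args[call_arg_idx:])
    return tuple(new_call_args)


# Hypothetical kernel signature matching the sketch near the top of this page.
signature = ["in_ptr", "out_ptr", "n_elements", "SCALE", "BLOCK_SIZE"]
constants = {"SCALE": 2.0}
config_kwargs = {"BLOCK_SIZE": 64}
print(merge_constants_into_call_args(["buf0", "buf1", 256], signature, constants, config_kwargs))
# -> ('buf0', 'buf1', 256, 2.0); BLOCK_SIZE is re-added later from the config kwargs.
```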