Properly account for triton kernel source code hidden in custom ops in AOTAutogradCache (#160120)

This PR fixes a bug where user-defined triton kernels hidden behind `triton_op` do not register source code changes in the cache key. If a user changes *only* a triton kernel's source code, the change goes unnoticed because the kernel is hidden inside the custom op, which dynamo has not yet traced into.

This means that at AOTAutograd time we don't know the list of triton kernels defined by custom ops. This PR is an initial fix: it parses the AST of the custom op's body looking for triton kernels. It won't catch more degenerate cases where the custom op calls other custom ops or functions that in turn call triton kernels, since the top-level compiled graph never learns about them. Handling that would require tracing through the custom op at dynamo time.

This should handle 99% of cases, though. I added an expectedFailure test to show the limitation.
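
For illustration, here is a minimal standalone sketch of the AST-based detection idea. This is not the PR's implementation; `find_wrapped_triton_kernels`, `my_kernel`, `op_body`, and the stand-in `capture_triton` below are hypothetical names used only for this example.

```python
import ast
import inspect
import textwrap
from typing import Any, Callable


def find_wrapped_triton_kernels(fn: Callable[..., Any]) -> list[object]:
    """Hypothetical, simplified helper: statically scan fn's source for
    capture_triton(...)/wrap_triton(...) calls and resolve the wrapped kernels."""
    try:
        source = textwrap.dedent(inspect.getsource(fn))
    except (OSError, TypeError):
        return []  # source not available (e.g. defined in a REPL)

    names: list[str] = []

    class Visitor(ast.NodeVisitor):
        def visit_Call(self, node: ast.Call) -> None:
            # Matches both `capture_triton(k)` and `torch._library.capture_triton(k)`
            # (looser than the PR's check, which verifies the full attribute chain).
            func = node.func
            called = func.id if isinstance(func, ast.Name) else getattr(func, "attr", None)
            if called in ("capture_triton", "wrap_triton"):
                if node.args and isinstance(node.args[0], ast.Name):
                    names.append(node.args[0].id)
            self.generic_visit(node)

    Visitor().visit(ast.parse(source))

    # Resolve the collected names through the op body's closure, globals, then builtins.
    cv = inspect.getclosurevars(fn)
    resolved: list[object] = []
    for name in names:
        for scope in (cv.nonlocals, cv.globals, cv.builtins):
            if name in scope:
                resolved.append(scope[name])
                break
    return resolved


# Usage sketch with stand-ins (no torch/triton required):
def capture_triton(kernel):  # stand-in for torch._library.capture_triton
    return kernel

def my_kernel(x):  # stand-in for a @triton.jit kernel
    ...

def op_body(x):
    capture_triton(my_kernel)(x)  # the AST walk picks up `my_kernel` here
    return x

print(find_wrapped_triton_kernels(op_body))  # [<function my_kernel at ...>]
```

The actual change (in the diff below) additionally splices the source through Inductor's `IndentedBuffer` before parsing and records the resolved kernels per op in `triton_ops_to_kernels`, so that `AOTAutogradCacheDetails` can fold their source code (via `user_defined_triton_kernel_transitive_closure_source_code`) into the cache key.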

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160120
Approved by: https://github.com/zou3519
James Wu, 2025-08-10 15:38:35 -07:00, committed by PyTorch MergeBot
commit 9708fcf92d (parent a288b15ea9)
4 changed files with 323 additions and 1 deletion

View File

@@ -789,7 +789,6 @@ class AOTAutogradCacheTests(InductorTestCase):
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)
@requires_cuda_and_triton
@requires_triton()
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
@@ -842,6 +841,214 @@ class AOTAutogradCacheTests(InductorTestCase):
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)
@requires_cuda_and_triton
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch({"autograd_cache_allow_custom_autograd_functions": True})
def test_custom_autograd_function_with_custom_triton_kernel_cache_invalidation(
self,
):
@triton.jit
def my_jit(x):
arg_0 = tl.load(x)
tl.store(x, arg_0 + 1)
@torch._library.triton_op("test::my_triton_op", mutates_args=())
def my_triton_op(x: torch.Tensor) -> torch.Tensor:
y = x.clone().detach_().requires_grad_(True)
torch._library.capture_triton(my_jit)[1,](y)
return y
class MyAutogradFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
y = torch.ops.test.my_triton_op(x)
ctx.save_for_backward(y)
ctx.foo = x.cos()
return y
@staticmethod
def backward(ctx, grad_output):
result = ctx.saved_tensors[0]
return grad_output * result + ctx.foo * grad_output
def fn(a):
return MyAutogradFunction.apply(a)
a = torch.randn(5, device=GPU_TYPE, requires_grad=True)
a2 = a.clone().detach_().requires_grad_(True)
a3 = a.clone().detach_().requires_grad_(True)
compiled_fn = torch.compile(fn, backend="inductor")
result = compiled_fn(a)
self.assertEqual(fn(a), result)
result.sum().backward()
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
# Clear dynamo and run again. Should be a cache hit.
counters.clear()
self._clear_dynamo_and_codecache()
result = compiled_fn(a2)
self.assertEqual(fn(a2), result)
result.sum().backward()
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)
# Now modify the source code of my_jit by redefining it
@triton.jit
def my_jit(x): # noqa: F811
arg_0 = tl.load(x)
tl.store(x, arg_0 + 2) # Changed from +1 to +2
@torch._library.triton_op("test::my_triton_op", mutates_args=())
def my_triton_op(x: torch.Tensor) -> torch.Tensor: # noqa: F811
y = x.clone().detach_().requires_grad_(True)
torch._library.capture_triton(my_jit)[1,](y)
return y
# Clear dynamo and run again. Should be a cache miss due to modified source code.
counters.clear()
self._clear_dynamo_and_codecache()
compiled_fn = torch.compile(fn, backend="inductor")
result = compiled_fn(a3)
# Assert that after changing the source code, the cache no longer hits
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(fn(a3), result)
@requires_cuda_and_triton
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
def test_triton_op_cache_invalidation(self):
from torch._library import capture_triton
@triton.jit
def my_jit(x): # noqa: F811
arg_0 = tl.load(x)
tl.store(x, arg_0 + 1)
@torch._library.triton_op("test::my_triton_op", mutates_args=())
def my_triton_op(x: torch.Tensor) -> torch.Tensor: # noqa: F811
y = x.clone().detach_().requires_grad_(True)
capture_triton(my_jit)[1,](y)
return y
def fn(a):
return torch.ops.test.my_triton_op(a)
a = torch.randn(5, device=GPU_TYPE)
a2 = a.clone().detach_()
compiled_fn = torch.compile(fn, backend="inductor")
result = compiled_fn(a)
self.assertEqual(fn(a), result)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
self._clear_dynamo_and_codecache()
# Redefine the triton op
@triton.jit
def my_jit(x): # noqa: F811
arg_0 = tl.load(x)
tl.store(x, arg_0 + 2)
@torch._library.triton_op("test::my_triton_op", mutates_args=())
def my_triton_op(x: torch.Tensor) -> torch.Tensor: # noqa: F811
y = x.clone().detach_().requires_grad_(True)
torch._library.capture_triton(my_jit)[1,](y)
return y
compiled_fn = torch.compile(fn, backend="inductor")
result = compiled_fn(a2)
# Second run should still miss
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)
self.assertEqual(fn(a2), result)
@requires_cuda_and_triton
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch({"enable_autograd_cache": True})
@unittest.expectedFailure # Currently ops that call other ops do not properly invalidate the cache
def test_triton_op_cache_multiple_ops_invalidation(self):
@triton.jit
def my_jit(x):
arg_0 = tl.load(x)
tl.store(x, arg_0 + 1)
@triton.jit
def my_jit2(x):
arg_0 = tl.load(x)
tl.store(x, arg_0 + 1)
@torch._library.triton_op("test::my_triton_op", mutates_args=())
def my_triton_op(x: torch.Tensor) -> torch.Tensor:
y = x.clone().detach_().requires_grad_(True)
torch._library.capture_triton(my_jit)[1,](y)
torch._library.capture_triton(my_jit2)[1,](y)
return y
@torch._library.triton_op("test::my_triton_op2", mutates_args=())
def my_triton_op2(x: torch.Tensor) -> torch.Tensor:
y = x.clone().detach_().requires_grad_(True)
torch.ops.test.my_triton_op(y)
return y
def fn(a):
return torch.ops.test.my_triton_op2(a)
a = torch.randn(5, device=GPU_TYPE)
a2 = a.clone().detach_()
compiled_fn = torch.compile(fn, backend="inductor")
result = compiled_fn(a)
self.assertEqual(fn(a), result)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
self._clear_dynamo_and_codecache()
# Redefine the triton op
@triton.jit
def my_jit(x): # noqa: F811
arg_0 = tl.load(x)
tl.store(x, arg_0 + 2)
@torch._library.triton_op("test::my_triton_op", mutates_args=())
def my_triton_op(x: torch.Tensor) -> torch.Tensor: # noqa: F811
y = x.clone().detach_().requires_grad_(True)
torch._library.capture_triton(my_jit)[1,](y)
torch._library.capture_triton(my_jit2)[1,](y)
return y
@torch._library.triton_op("test::my_triton_op2", mutates_args=())
def my_triton_op2(x: torch.Tensor) -> torch.Tensor: # noqa: F811
y = x.clone().detach_().requires_grad_(True)
torch.ops.test.my_triton_op(y)
return y
compiled_fn = torch.compile(fn, backend="inductor")
result = compiled_fn(a2)
# Second run should still miss
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)
self.assertEqual(fn(a2), result)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch({"fx_graph_cache": True})
@functorch_config.patch({"enable_autograd_cache": True})

View File

@@ -302,6 +302,42 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
a safe and stable cache key for AOTAutograd.
"""
def get_triton_source_codes_from_gm(
self,
gm: torch.fx.GraphModule,
):
triton_kernels = []
for module in gm.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if isinstance(node.target, torch._ops.OpOverloadPacket):
attrs = node.target._dir
for attr in attrs:
if custom_op := getattr(node.target, attr, None):
kernels = torch._library.triton.get_triton_kernels_for_op(
custom_op._name
)
triton_kernels.extend(kernels)
elif isinstance(node.target, torch._ops.OpOverload):
kernels = torch._library.triton.get_triton_kernels_for_op(
node.target._name
)
triton_kernels.extend(kernels)
triton_kernel_source_codes = []
from torch._inductor.codegen.wrapper import (
user_defined_triton_kernel_transitive_closure_source_code,
)
for kernel in triton_kernels:
source_codes = user_defined_triton_kernel_transitive_closure_source_code(
kernel
)
triton_kernel_source_codes.append(source_codes)
return triton_kernel_source_codes
def __init__(
self,
gm: torch.fx.GraphModule,
@@ -319,6 +355,7 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
[],
[],
)
self.triton_kernel_source_codes = self.get_triton_source_codes_from_gm(gm)
if hasattr(gm, "saved_tensors_hooks_pack_0"):

View File

@@ -210,6 +210,7 @@ class CustomOpDef:
self._lib = get_library_allowing_overwrite(self._namespace, self._name)
self._register_to_dispatcher(self._tags)
self._disabled_kernel: set = set()
self._used_triton_kernels: list[Any] = list()
OPDEFS[self._qualname] = self
@property

View File

@@ -1,4 +1,6 @@
import ast
import contextlib
import inspect
import threading
from collections.abc import Generator, Iterable
from typing import Any, Callable, Optional, Union
@@ -9,6 +11,79 @@ from .custom_ops import custom_op, CustomOpDef
from .infer_schema import infer_schema
triton_ops_to_kernels: dict[str, list[object]] = {}
def get_triton_kernels_for_op(name: str) -> list[object]:
return triton_ops_to_kernels.get(name, [])
def get_inner_triton_kernels(fn: Callable[..., Any]) -> list[object]:
"""
Inspect the source of an arbitrary callable passed to torch._library.triton_op,
and grab all of the triton kernels that are wrapped inside of it.
TODO: This check is best effort. It does *not* handle the case where the triton
kernel is hidden behind recursive function calls.
"""
def find_triton_kernels(fn: Callable[..., Any]) -> list[object]:
try:
source = inspect.getsource(fn)
except (OSError, TypeError):
return [] # Source code not available
from torch._inductor.utils import IndentedBuffer
buffer = IndentedBuffer()
buffer.splice(source, strip=True)
tree = ast.parse(buffer.getrawvalue())
# Visitor to collect function calls and triton kernels
class Visitor(ast.NodeVisitor):
def __init__(self) -> None:
self.triton_kernels: list[Any] = []
def visit_Call(self, node: ast.Call) -> None:
triton_func_names = ("capture_triton", "wrap_triton")
if isinstance(node.func, ast.Attribute):
attr = node.func
if (
isinstance(attr.value, ast.Attribute)
and isinstance(attr.value.value, ast.Name)
and attr.value.value.id == "torch"
and attr.value.attr == "_library"
and attr.attr in triton_func_names
):
if node.args and isinstance(node.args[0], ast.Name):
self.triton_kernels.append(node.args[0].id)
# Catch capture_triton, wrap_triton that's been
# imported directly
elif isinstance(node.func, ast.Name):
if node.func.id in triton_func_names:
if node.args and isinstance(node.args[0], ast.Name):
self.triton_kernels.append(node.args[0].id)
self.generic_visit(node)
collector = Visitor()
collector.visit(tree)
closure_vars = inspect.getclosurevars(fn)
resolved = []
# First, resolve triton kernel names
for name in collector.triton_kernels:
if name in closure_vars.nonlocals:
resolved.append(closure_vars.nonlocals[name])
elif name in closure_vars.globals:
resolved.append(closure_vars.globals[name])
elif name in closure_vars.builtins:
resolved.append(closure_vars.builtins[name])
return resolved
return find_triton_kernels(fn)
@exposed_in("torch.library")
def triton_op(
name: str,
@@ -175,6 +250,8 @@ def triton_op(
with mode:
return fn(*args, **kwargs)
triton_kernels = get_inner_triton_kernels(fn)
triton_ops_to_kernels[name] = triton_kernels
result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
return result