diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py
index 2895c8991c2..7e6895ccde5 100644
--- a/test/dynamo/test_aot_autograd_cache.py
+++ b/test/dynamo/test_aot_autograd_cache.py
@@ -789,7 +789,6 @@ class AOTAutogradCacheTests(InductorTestCase):
         self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)
 
     @requires_cuda_and_triton
-    @requires_triton()
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", True)
     @functorch_config.patch({"enable_autograd_cache": True})
@@ -842,6 +841,214 @@ class AOTAutogradCacheTests(InductorTestCase):
         self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
         self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)
 
+    @requires_cuda_and_triton
+    @inductor_config.patch("fx_graph_remote_cache", False)
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    @functorch_config.patch({"autograd_cache_allow_custom_autograd_functions": True})
+    def test_custom_autograd_function_with_custom_triton_kernel_cache_invalidation(
+        self,
+    ):
+        @triton.jit
+        def my_jit(x):
+            arg_0 = tl.load(x)
+            tl.store(x, arg_0 + 1)
+
+        @torch._library.triton_op("test::my_triton_op", mutates_args=())
+        def my_triton_op(x: torch.Tensor) -> torch.Tensor:
+            y = x.clone().detach_().requires_grad_(True)
+            torch._library.capture_triton(my_jit)[1,](y)
+            return y
+
+        class MyAutogradFunction(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                y = torch.ops.test.my_triton_op(x)
+                ctx.save_for_backward(y)
+                ctx.foo = x.cos()
+                return y
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                result = ctx.saved_tensors[0]
+                return grad_output * result + ctx.foo * grad_output
+
+        def fn(a):
+            return MyAutogradFunction.apply(a)
+
+        a = torch.randn(5, device=GPU_TYPE, requires_grad=True)
+        a2 = a.clone().detach_().requires_grad_(True)
+        a3 = a.clone().detach_().requires_grad_(True)
+        compiled_fn = torch.compile(fn, backend="inductor")
+        result = compiled_fn(a)
+        self.assertEqual(fn(a), result)
+        result.sum().backward()
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+        # Clear dynamo and run again. Should be a cache hit.
+        counters.clear()
+        self._clear_dynamo_and_codecache()
+        result = compiled_fn(a2)
+        self.assertEqual(fn(a2), result)
+        result.sum().backward()
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)
+
+        # Now modify the source code of my_jit by redefining it
+        @triton.jit
+        def my_jit(x):  # noqa: F811
+            arg_0 = tl.load(x)
+            tl.store(x, arg_0 + 2)  # Changed from +1 to +2
+
+        @torch._library.triton_op("test::my_triton_op", mutates_args=())
+        def my_triton_op(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
+            y = x.clone().detach_().requires_grad_(True)
+            torch._library.capture_triton(my_jit)[1,](y)
+            return y
+
+        # Clear dynamo and run again. Should be a cache miss due to the modified source code.
+        counters.clear()
+        self._clear_dynamo_and_codecache()
+        compiled_fn = torch.compile(fn, backend="inductor")
+
+        result = compiled_fn(a3)
+        # Assert that after changing the source code, the cache no longer hits
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(fn(a3), result)
+
+    @requires_cuda_and_triton
+    @inductor_config.patch("fx_graph_remote_cache", False)
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_triton_op_cache_invalidation(self):
+        from torch._library import capture_triton
+
+        @triton.jit
+        def my_jit(x):  # noqa: F811
+            arg_0 = tl.load(x)
+            tl.store(x, arg_0 + 1)
+
+        @torch._library.triton_op("test::my_triton_op", mutates_args=())
+        def my_triton_op(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
+            y = x.clone().detach_().requires_grad_(True)
+            capture_triton(my_jit)[1,](y)
+            return y
+
+        def fn(a):
+            return torch.ops.test.my_triton_op(a)
+
+        a = torch.randn(5, device=GPU_TYPE)
+        a2 = a.clone().detach_()
+        compiled_fn = torch.compile(fn, backend="inductor")
+        result = compiled_fn(a)
+        self.assertEqual(fn(a), result)
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+        self._clear_dynamo_and_codecache()
+
+        # Redefine the triton op
+
+        @triton.jit
+        def my_jit(x):  # noqa: F811
+            arg_0 = tl.load(x)
+            tl.store(x, arg_0 + 2)
+
+        @torch._library.triton_op("test::my_triton_op", mutates_args=())
+        def my_triton_op(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
+            y = x.clone().detach_().requires_grad_(True)
+            torch._library.capture_triton(my_jit)[1,](y)
+            return y
+
+        compiled_fn = torch.compile(fn, backend="inductor")
+        result = compiled_fn(a2)
+
+        # Second run should still miss
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)
+
+        self.assertEqual(fn(a2), result)
+
+    @requires_cuda_and_triton
+    @inductor_config.patch("fx_graph_remote_cache", False)
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    @unittest.expectedFailure  # Currently, ops that call other ops do not properly invalidate the cache
+    def test_triton_op_cache_multiple_ops_invalidation(self):
+        @triton.jit
+        def my_jit(x):
+            arg_0 = tl.load(x)
+            tl.store(x, arg_0 + 1)
+
+        @triton.jit
+        def my_jit2(x):
+            arg_0 = tl.load(x)
+            tl.store(x, arg_0 + 1)
+
+        @torch._library.triton_op("test::my_triton_op", mutates_args=())
+        def my_triton_op(x: torch.Tensor) -> torch.Tensor:
+            y = x.clone().detach_().requires_grad_(True)
+            torch._library.capture_triton(my_jit)[1,](y)
+            torch._library.capture_triton(my_jit2)[1,](y)
+            return y
+
+        @torch._library.triton_op("test::my_triton_op2", mutates_args=())
+        def my_triton_op2(x: torch.Tensor) -> torch.Tensor:
+            y = x.clone().detach_().requires_grad_(True)
+            torch.ops.test.my_triton_op(y)
+            return y
+
+        def fn(a):
+            return torch.ops.test.my_triton_op2(a)
+
+        a = torch.randn(5, device=GPU_TYPE)
+        a2 = a.clone().detach_()
+        compiled_fn = torch.compile(fn, backend="inductor")
+        result = compiled_fn(a)
+        self.assertEqual(fn(a), result)
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+        self._clear_dynamo_and_codecache()
+
+        # Redefine the triton op
+
+        @triton.jit
+        def my_jit(x):  # noqa: F811
+            arg_0 = tl.load(x)
+            tl.store(x, arg_0 + 2)
+
+        @torch._library.triton_op("test::my_triton_op", mutates_args=())
+        def my_triton_op(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
+            y = x.clone().detach_().requires_grad_(True)
+            torch._library.capture_triton(my_jit)[1,](y)
+            torch._library.capture_triton(my_jit2)[1,](y)
+            return y
+
+        @torch._library.triton_op("test::my_triton_op2", mutates_args=())
+        def my_triton_op2(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
+            y = x.clone().detach_().requires_grad_(True)
+            torch.ops.test.my_triton_op(y)
+            return y
+
+        compiled_fn = torch.compile(fn, backend="inductor")
+        result = compiled_fn(a2)
+
+        # Second run should still miss
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)
+
+        self.assertEqual(fn(a2), result)
+
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch({"fx_graph_cache": True})
     @functorch_config.patch({"enable_autograd_cache": True})
diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py
index 7217a9c9b39..248c3a0ae67 100644
--- a/torch/_functorch/_aot_autograd/autograd_cache.py
+++ b/torch/_functorch/_aot_autograd/autograd_cache.py
@@ -302,6 +302,42 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
     a safe and stable cache key for AOTAutograd.
     """
 
+    def get_triton_source_codes_from_gm(
+        self,
+        gm: torch.fx.GraphModule,
+    ):
+        triton_kernels = []
+        for module in gm.modules():
+            if not isinstance(module, torch.fx.GraphModule):
+                continue
+            for node in module.graph.nodes:
+                if isinstance(node.target, torch._ops.OpOverloadPacket):
+                    attrs = node.target._dir
+                    for attr in attrs:
+                        if custom_op := getattr(node.target, attr, None):
+                            kernels = torch._library.triton.get_triton_kernels_for_op(
+                                custom_op._name
+                            )
+                            triton_kernels.extend(kernels)
+                elif isinstance(node.target, torch._ops.OpOverload):
+                    kernels = torch._library.triton.get_triton_kernels_for_op(
+                        node.target._name
+                    )
+                    triton_kernels.extend(kernels)
+
+        triton_kernel_source_codes = []
+        from torch._inductor.codegen.wrapper import (
+            user_defined_triton_kernel_transitive_closure_source_code,
+        )
+
+        for kernel in triton_kernels:
+            source_codes = user_defined_triton_kernel_transitive_closure_source_code(
+                kernel
+            )
+            triton_kernel_source_codes.append(source_codes)
+
+        return triton_kernel_source_codes
+
     def __init__(
         self,
         gm: torch.fx.GraphModule,
@@ -319,6 +355,7 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
             [],
             [],
         )
+        self.triton_kernel_source_codes = self.get_triton_source_codes_from_gm(gm)
 
         if hasattr(gm, "saved_tensors_hooks_pack_0"):
diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py
index bd8acb2789e..251cdefe0f0 100644
--- a/torch/_library/custom_ops.py
+++ b/torch/_library/custom_ops.py
@@ -210,6 +210,7 @@ class CustomOpDef:
         self._lib = get_library_allowing_overwrite(self._namespace, self._name)
         self._register_to_dispatcher(self._tags)
         self._disabled_kernel: set = set()
+        self._used_triton_kernels: list[Any] = list()
         OPDEFS[self._qualname] = self
 
     @property
diff --git a/torch/_library/triton.py b/torch/_library/triton.py
index 17d02a99456..741b341f7e2 100644
--- a/torch/_library/triton.py
+++ b/torch/_library/triton.py
@@ -1,4 +1,6 @@
+import ast
 import contextlib
+import inspect
 import threading
 from collections.abc import Generator, Iterable
 from typing import Any, Callable, Optional, Union
@@ -9,6 +11,79 @@ from .custom_ops import custom_op, CustomOpDef
 from .infer_schema import infer_schema
 
 
+triton_ops_to_kernels: dict[str, list[object]] = {}
+
+
+def get_triton_kernels_for_op(name: str) -> list[object]:
+    return triton_ops_to_kernels.get(name, [])
+
+
+def get_inner_triton_kernels(fn: Callable[..., Any]) -> list[object]:
+    """
+    Inspect the source of an arbitrary callable passed to torch._library.triton_op,
+    and grab all of the triton kernels that are wrapped inside of it.
+
+    TODO: This check is best effort. It does *not* handle the case where the triton
+    kernel is hidden behind recursive function calls.
+    """
+
+    def find_triton_kernels(fn: Callable[..., Any]) -> list[object]:
+        try:
+            source = inspect.getsource(fn)
+        except (OSError, TypeError):
+            return []  # Source code not available
+
+        from torch._inductor.utils import IndentedBuffer
+
+        buffer = IndentedBuffer()
+        buffer.splice(source, strip=True)
+        tree = ast.parse(buffer.getrawvalue())
+
+        # Visitor to collect function calls and triton kernels
+        class Visitor(ast.NodeVisitor):
+            def __init__(self) -> None:
+                self.triton_kernels: list[Any] = []
+
+            def visit_Call(self, node: ast.Call) -> None:
+                triton_func_names = ("capture_triton", "wrap_triton")
+                if isinstance(node.func, ast.Attribute):
+                    attr = node.func
+                    if (
+                        isinstance(attr.value, ast.Attribute)
+                        and isinstance(attr.value.value, ast.Name)
+                        and attr.value.value.id == "torch"
+                        and attr.value.attr == "_library"
+                        and attr.attr in triton_func_names
+                    ):
+                        if node.args and isinstance(node.args[0], ast.Name):
+                            self.triton_kernels.append(node.args[0].id)
+
+                # Catch capture_triton, wrap_triton that have been
+                # imported directly
+                elif isinstance(node.func, ast.Name):
+                    if node.func.id in triton_func_names:
+                        if node.args and isinstance(node.args[0], ast.Name):
+                            self.triton_kernels.append(node.args[0].id)
+
+                self.generic_visit(node)
+
+        collector = Visitor()
+        collector.visit(tree)
+        closure_vars = inspect.getclosurevars(fn)
+        resolved = []
+        # First, resolve triton kernel names
+        for name in collector.triton_kernels:
+            if name in closure_vars.nonlocals:
+                resolved.append(closure_vars.nonlocals[name])
+            elif name in closure_vars.globals:
+                resolved.append(closure_vars.globals[name])
+            elif name in closure_vars.builtins:
+                resolved.append(closure_vars.builtins[name])
+        return resolved
+
+    return find_triton_kernels(fn)
+
+
 @exposed_in("torch.library")
 def triton_op(
     name: str,
@@ -175,6 +250,8 @@ def triton_op(
             with mode:
                 return fn(*args, **kwargs)
 
+        triton_kernels = get_inner_triton_kernels(fn)
+        triton_ops_to_kernels[name] = triton_kernels
         result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
 
         return result
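
Note (not part of the patch): a minimal sketch of how the helpers added above might be exercised once the patch is applied. The op name `test::add_one` and the kernel `add_one_kernel` are hypothetical illustrations; `triton_op`, `capture_triton`, and the new `get_triton_kernels_for_op` registry are the APIs touched by this diff.

```python
# Illustration only -- assumes the patch above is applied and Triton is installed.
import torch
import triton
import triton.language as tl

from torch._library.triton import get_triton_kernels_for_op  # helper added by this patch


@triton.jit
def add_one_kernel(x_ptr):  # hypothetical kernel, analogous to my_jit in the tests
    val = tl.load(x_ptr)
    tl.store(x_ptr, val + 1)


@torch._library.triton_op("test::add_one", mutates_args=())
def add_one(x: torch.Tensor) -> torch.Tensor:
    y = x.clone()
    # The torch._library.capture_triton(...) call pattern is what the AST scan in
    # get_inner_triton_kernels looks for, so add_one_kernel is recorded under
    # "test::add_one" at registration time.
    torch._library.capture_triton(add_one_kernel)[1,](y)
    return y


# AOTAutogradCacheDetails can now look up and hash the kernel source through this
# registry, so editing add_one_kernel's body changes the cache key.
print(get_triton_kernels_for_op("test::add_one"))  # -> [<JITFunction add_one_kernel>]
```

As the TODO in `get_inner_triton_kernels` notes, the scan is best-effort: kernels reached only indirectly, for example through another `triton_op` as in `test_triton_op_cache_multiple_ops_invalidation`, are not yet tracked, which is why that test is marked `expectedFailure`.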