Account for triton kernel source code hidden in custom ops properly in AOTAutogradCache (#160120)
This PR fixes a bug where user-defined triton kernels hidden behind `triton_op` do not register source code changes. If a user changes *only* a triton kernel's source code, dynamo has not traced into it because the kernel is hidden inside the custom op, so at AOTAutograd time we don't know the list of triton kernels defined by custom ops.

This is an initial fix: parse the AST of the custom op and look for the triton kernels it wraps. It won't catch more degenerate cases where the custom op calls other custom ops/functions that in turn call triton kernels, so the top-level compiled graph doesn't know about them; handling that would require tracing through the custom op at dynamo time. This should handle 99% of cases, though. I added an expectedFailure test to show the limitation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160120
Approved by: https://github.com/zou3519
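To illustrate the failure mode, here is a minimal sketch (not code from this PR): the kernel and op names (`add_one`, `demo::add_one_op`) are made up, and it assumes a CUDA device with Triton installed. Because the kernel is only reachable through the custom op, dynamo never traces into it, and before this change AOTAutogradCache could not see edits to the kernel body.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_one(x_ptr):
    val = tl.load(x_ptr)
    tl.store(x_ptr, val + 1)


@torch._library.triton_op("demo::add_one_op", mutates_args=())
def add_one_op(x: torch.Tensor) -> torch.Tensor:
    y = x.clone()
    # The kernel is hidden behind the custom op: dynamo sees only
    # torch.ops.demo.add_one_op, never the @triton.jit function itself.
    torch._library.capture_triton(add_one)[(1,)](y)
    return y


def fn(a):
    return torch.ops.demo.add_one_op(a)


compiled = torch.compile(fn, backend="inductor")
out = compiled(torch.randn(5, device="cuda"))

# Before this PR: editing only the body of `add_one` (e.g. storing `val + 2`)
# and recompiling could still hit the stale AOTAutogradCache entry, because
# the cache key did not account for the kernel's source code.
```

The fix in this commit scans the custom op's source for `capture_triton`/`wrap_triton` calls like the one above, resolves the named kernel objects, and folds their transitive source code into the AOTAutograd cache key.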
This commit is contained in:
parent a288b15ea9
commit 9708fcf92d
@@ -789,7 +789,6 @@ class AOTAutogradCacheTests(InductorTestCase):
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

    @requires_cuda_and_triton
    @requires_triton()
    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
@@ -842,6 +841,214 @@ class AOTAutogradCacheTests(InductorTestCase):
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

    @requires_cuda_and_triton
    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    @functorch_config.patch({"autograd_cache_allow_custom_autograd_functions": True})
    def test_custom_autograd_function_with_custom_triton_kernel_cache_invalidation(
        self,
    ):
        @triton.jit
        def my_jit(x):
            arg_0 = tl.load(x)
            tl.store(x, arg_0 + 1)

        @torch._library.triton_op("test::my_triton_op", mutates_args=())
        def my_triton_op(x: torch.Tensor) -> torch.Tensor:
            y = x.clone().detach_().requires_grad_(True)
            torch._library.capture_triton(my_jit)[1,](y)
            return y

        class MyAutogradFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = torch.ops.test.my_triton_op(x)
                ctx.save_for_backward(y)
                ctx.foo = x.cos()
                return y

            @staticmethod
            def backward(ctx, grad_output):
                result = ctx.saved_tensors[0]
                return grad_output * result + ctx.foo * grad_output

        def fn(a):
            return MyAutogradFunction.apply(a)

        a = torch.randn(5, device=GPU_TYPE, requires_grad=True)
        a2 = a.clone().detach_().requires_grad_(True)
        a3 = a.clone().detach_().requires_grad_(True)
        compiled_fn = torch.compile(fn, backend="inductor")
        result = compiled_fn(a)
        self.assertEqual(fn(a), result)
        result.sum().backward()

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # Clear dynamo and run again. Should be a cache hit.
        counters.clear()
        self._clear_dynamo_and_codecache()
        result = compiled_fn(a2)
        self.assertEqual(fn(a2), result)
        result.sum().backward()
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

        # Now modify the source code of my_jit by redefining it
        @triton.jit
        def my_jit(x):  # noqa: F811
            arg_0 = tl.load(x)
            tl.store(x, arg_0 + 2)  # Changed from +1 to +2

        @torch._library.triton_op("test::my_triton_op", mutates_args=())
        def my_triton_op(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            y = x.clone().detach_().requires_grad_(True)
            torch._library.capture_triton(my_jit)[1,](y)
            return y

        # Clear dynamo and run again. Should be a cache miss due to modified source code.
        counters.clear()
        self._clear_dynamo_and_codecache()
        compiled_fn = torch.compile(fn, backend="inductor")

        result = compiled_fn(a3)
        # Assert that after changing the source code, the cache no longer hits
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(fn(a3), result)

    @requires_cuda_and_triton
    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_triton_op_cache_invalidation(self):
        from torch._library import capture_triton

        @triton.jit
        def my_jit(x):  # noqa: F811
            arg_0 = tl.load(x)
            tl.store(x, arg_0 + 1)

        @torch._library.triton_op("test::my_triton_op", mutates_args=())
        def my_triton_op(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            y = x.clone().detach_().requires_grad_(True)
            capture_triton(my_jit)[1,](y)
            return y

        def fn(a):
            return torch.ops.test.my_triton_op(a)

        a = torch.randn(5, device=GPU_TYPE)
        a2 = a.clone().detach_()
        compiled_fn = torch.compile(fn, backend="inductor")
        result = compiled_fn(a)
        self.assertEqual(fn(a), result)

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        self._clear_dynamo_and_codecache()

        # Redefine the triton op

        @triton.jit
        def my_jit(x):  # noqa: F811
            arg_0 = tl.load(x)
            tl.store(x, arg_0 + 2)

        @torch._library.triton_op("test::my_triton_op", mutates_args=())
        def my_triton_op(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            y = x.clone().detach_().requires_grad_(True)
            torch._library.capture_triton(my_jit)[1,](y)
            return y

        compiled_fn = torch.compile(fn, backend="inductor")
        result = compiled_fn(a2)

        # Second run should still miss
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)

        self.assertEqual(fn(a2), result)

    @requires_cuda_and_triton
    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    @unittest.expectedFailure  # Currently, ops that call other ops do not properly invalidate the cache
    def test_triton_op_cache_multiple_ops_invalidation(self):
        @triton.jit
        def my_jit(x):
            arg_0 = tl.load(x)
            tl.store(x, arg_0 + 1)

        @triton.jit
        def my_jit2(x):
            arg_0 = tl.load(x)
            tl.store(x, arg_0 + 1)

        @torch._library.triton_op("test::my_triton_op", mutates_args=())
        def my_triton_op(x: torch.Tensor) -> torch.Tensor:
            y = x.clone().detach_().requires_grad_(True)
            torch._library.capture_triton(my_jit)[1,](y)
            torch._library.capture_triton(my_jit2)[1,](y)
            return y

        @torch._library.triton_op("test::my_triton_op2", mutates_args=())
        def my_triton_op2(x: torch.Tensor) -> torch.Tensor:
            y = x.clone().detach_().requires_grad_(True)
            torch.ops.test.my_triton_op(y)
            return y

        def fn(a):
            return torch.ops.test.my_triton_op2(a)

        a = torch.randn(5, device=GPU_TYPE)
        a2 = a.clone().detach_()
        compiled_fn = torch.compile(fn, backend="inductor")
        result = compiled_fn(a)
        self.assertEqual(fn(a), result)

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        self._clear_dynamo_and_codecache()

        # Redefine the triton op

        @triton.jit
        def my_jit(x):  # noqa: F811
            arg_0 = tl.load(x)
            tl.store(x, arg_0 + 2)

        @torch._library.triton_op("test::my_triton_op", mutates_args=())
        def my_triton_op(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            y = x.clone().detach_().requires_grad_(True)
            torch._library.capture_triton(my_jit)[1,](y)
            torch._library.capture_triton(my_jit2)[1,](y)
            return y

        @torch._library.triton_op("test::my_triton_op2", mutates_args=())
        def my_triton_op2(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            y = x.clone().detach_().requires_grad_(True)
            torch.ops.test.my_triton_op(y)
            return y

        compiled_fn = torch.compile(fn, backend="inductor")
        result = compiled_fn(a2)

        # Second run should still miss
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)

        self.assertEqual(fn(a2), result)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch({"fx_graph_cache": True})
    @functorch_config.patch({"enable_autograd_cache": True})
@@ -302,6 +302,42 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
    a safe and stable cache key for AOTAutograd.
    """

    def get_triton_source_codes_from_gm(
        self,
        gm: torch.fx.GraphModule,
    ):
        triton_kernels = []
        for module in gm.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if isinstance(node.target, torch._ops.OpOverloadPacket):
                    attrs = node.target._dir
                    for attr in attrs:
                        if custom_op := getattr(node.target, attr, None):
                            kernels = torch._library.triton.get_triton_kernels_for_op(
                                custom_op._name
                            )
                            triton_kernels.extend(kernels)
                elif isinstance(node.target, torch._ops.OpOverload):
                    kernels = torch._library.triton.get_triton_kernels_for_op(
                        node.target._name
                    )
                    triton_kernels.extend(kernels)

        triton_kernel_source_codes = []
        from torch._inductor.codegen.wrapper import (
            user_defined_triton_kernel_transitive_closure_source_code,
        )

        for kernel in triton_kernels:
            source_codes = user_defined_triton_kernel_transitive_closure_source_code(
                kernel
            )
            triton_kernel_source_codes.append(source_codes)

        return triton_kernel_source_codes

    def __init__(
        self,
        gm: torch.fx.GraphModule,
@@ -319,6 +355,7 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
            [],
            [],
        )
        self.triton_kernel_source_codes = self.get_triton_source_codes_from_gm(gm)

        if hasattr(gm, "saved_tensors_hooks_pack_0"):
@@ -210,6 +210,7 @@ class CustomOpDef:
        self._lib = get_library_allowing_overwrite(self._namespace, self._name)
        self._register_to_dispatcher(self._tags)
        self._disabled_kernel: set = set()
        self._used_triton_kernels: list[Any] = list()
        OPDEFS[self._qualname] = self

    @property
@@ -1,4 +1,6 @@
import ast
import contextlib
import inspect
import threading
from collections.abc import Generator, Iterable
from typing import Any, Callable, Optional, Union
@@ -9,6 +11,79 @@ from .custom_ops import custom_op, CustomOpDef
from .infer_schema import infer_schema


triton_ops_to_kernels: dict[str, list[object]] = {}


def get_triton_kernels_for_op(name: str) -> list[object]:
    return triton_ops_to_kernels.get(name, [])


def get_inner_triton_kernels(fn: Callable[..., Any]) -> list[object]:
    """
    Inspect the source of an arbitrary callable passed to torch._library.triton_op,
    and grab all of the triton kernels that are wrapped inside of it.

    TODO: This check is best effort. It does *not* handle the case where the triton
    kernel is hidden behind recursive function calls.
    """

    def find_triton_kernels(fn: Callable[..., Any]) -> list[object]:
        try:
            source = inspect.getsource(fn)
        except (OSError, TypeError):
            return []  # Source code not available

        from torch._inductor.utils import IndentedBuffer

        buffer = IndentedBuffer()
        buffer.splice(source, strip=True)
        tree = ast.parse(buffer.getrawvalue())

        # Visitor to collect function calls and triton kernels
        class Visitor(ast.NodeVisitor):
            def __init__(self) -> None:
                self.triton_kernels: list[Any] = []

            def visit_Call(self, node: ast.Call) -> None:
                triton_func_names = ("capture_triton", "wrap_triton")
                if isinstance(node.func, ast.Attribute):
                    attr = node.func
                    if (
                        isinstance(attr.value, ast.Attribute)
                        and isinstance(attr.value.value, ast.Name)
                        and attr.value.value.id == "torch"
                        and attr.value.attr == "_library"
                        and attr.attr in triton_func_names
                    ):
                        if node.args and isinstance(node.args[0], ast.Name):
                            self.triton_kernels.append(node.args[0].id)

                # Catch capture_triton, wrap_triton that's been
                # imported directly
                elif isinstance(node.func, ast.Name):
                    if node.func.id in triton_func_names:
                        if node.args and isinstance(node.args[0], ast.Name):
                            self.triton_kernels.append(node.args[0].id)

                self.generic_visit(node)

        collector = Visitor()
        collector.visit(tree)
        closure_vars = inspect.getclosurevars(fn)
        resolved = []
        # First, resolve triton kernel names
        for name in collector.triton_kernels:
            if name in closure_vars.nonlocals:
                resolved.append(closure_vars.nonlocals[name])
            elif name in closure_vars.globals:
                resolved.append(closure_vars.globals[name])
            elif name in closure_vars.builtins:
                resolved.append(closure_vars.builtins[name])
        return resolved

    return find_triton_kernels(fn)


@exposed_in("torch.library")
def triton_op(
    name: str,
@@ -175,6 +250,8 @@ def triton_op(
            with mode:
                return fn(*args, **kwargs)

    triton_kernels = get_inner_triton_kernels(fn)
    triton_ops_to_kernels[name] = triton_kernels
    result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
    return result