From c839fa4dd2c51a5795fc5f8730bbcd68a1b40d95 Mon Sep 17 00:00:00 2001
From: eellison
Date: Tue, 25 Feb 2025 16:15:06 -0800
Subject: [PATCH] [Resubmit] Record input strides at time of tracing, constrain to them for triton fn (#147861)

Resubmit of https://github.com/pytorch/pytorch/pull/145448. It lost its changes on rebase.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147861
Approved by: https://github.com/zou3519
---
 test/inductor/test_triton_kernels.py  | 75 +++++++++++++++++++++++++++
 torch/_inductor/graph.py              | 21 +++++++-
 torch/_inductor/lowering.py           | 25 +++++++++
 torch/fx/experimental/proxy_tensor.py | 42 ++++++++++++++-
 torch/fx/proxy.py                     |  1 +
 5 files changed, 160 insertions(+), 4 deletions(-)

diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py
index b95a65a1b6d..60df2bcc369 100644
--- a/test/inductor/test_triton_kernels.py
+++ b/test/inductor/test_triton_kernels.py
@@ -8,6 +8,7 @@ import logging
 import torch
 import torch._dynamo.testing
 import torch._inductor.test_case
+import torch.utils._pytree as pytree
 from torch._dynamo import config as dynamo_config
 from torch._higher_order_ops.triton_kernel_wrap import (
     generate_ttir,
@@ -15,6 +16,11 @@ from torch._higher_order_ops.triton_kernel_wrap import (
     triton_kernel_wrapper_mutation,
 )
 from torch._inductor import config as inductor_config, metrics
+from torch._inductor.pattern_matcher import (
+    CallFunctionVarArgs,
+    PatternMatcherPass,
+    register_graph_pattern,
+)
 from torch._inductor.utils import run_and_get_code, triton_version_uses_attrs_dict
 from torch._library import capture_triton
 from torch.testing import FileCheck
@@ -3374,6 +3380,75 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
         self.assertEqual(status[-1], False)
         self.assertEqual(z, (x + y) * 2)
 
+    @requires_gpu
+    def test_preserves_strides(self):
+        import triton
+        import triton.language as tl
+
+        @triton.jit
+        def add_kernel(
+            in_ptr0,
+            in_ptr1,
+            out_ptr,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask)
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        x = torch.randn(4, 4, 2, 2, device="cuda")
+        other = torch.randn(4, 4, 2, 2, device="cuda")
+
+        def f(x, other):
+            y = x.transpose(2, 3).contiguous().transpose(2, 3)
+            z = y.sin().transpose(2, 3)
+            grid = (z.numel(),)
+            out = torch.empty_like(other)
+            add_kernel[grid](z, other, out, z.numel(), BLOCK_SIZE=16)
+            return out
+
+        class _CustomPass(PatternMatcherPass):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def __call__(self, g: torch.fx.Graph):
+                self.apply(g)
+
+        g = _CustomPass()
+        called = False
+
+        @register_graph_pattern(
+            CallFunctionVarArgs(torch.ops.aten.permute),
+            pass_dict=g,
+        )
+        def _(match, *args, **kwargs):
+            flat_args, spec = pytree.tree_flatten((args, kwargs))
+
+            def decomp(*flat_args):
+                args, kwargs = pytree.tree_unflatten(flat_args, spec)
+                return torch.ops.aten.permute(*args, **kwargs).clone(
+                    memory_format=torch.channels_last
+                )
+
+            nonlocal called
+            called = True
+            match.replace_by_example(decomp, flat_args)
+
+        from torch._inductor import config
+
+        with config.patch(
+            post_grad_custom_post_pass=g,
+        ):
+            f_compile = torch.compile(f)
+            self.assertEqual(f(x, other), f_compile(x, other))
+        self.assertTrue(called)
+
     @requires_gpu
     @common_utils.parametrize("dynamic", [False, True])
     @common_utils.parametrize("autotune", [False, True])
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 44da68f81b5..62ff4bc14d7 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -75,6 +75,7 @@ from .ir import (
     TorchBindObject,
 )
 from .lowering import (
+    constrain_to_fake_tensors,
     constrain_to_fx_strides,
     FALLBACK_ALLOW_LIST,
     fallback_handler,
@@ -232,6 +233,13 @@ def mark_nodes_dislike_padding(
     )
 
     for cur in reversed(g.nodes):
+        if isinstance(
+            cur.target,
+            torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation,
+        ):
+            cur.meta["dislike_padding"] = True
+            continue
+
         op = _get_overload_packet(cur)
         if not op:
             continue
@@ -1354,8 +1362,9 @@ class GraphLowering(torch.fx.Interpreter):
             for name in mutated:
                 old_arg = old_kwargs["kwargs"][name]
                 new_arg = new_kwargs["kwargs"][name]
-                if old_arg is new_args:
+                if old_arg is new_arg:
                     continue
+
                 self.call_function(torch.ops.aten.copy_.default, (old_arg, new_arg), {})
             return
 
@@ -1438,7 +1447,15 @@
             ):
                 old_args = args  # type: ignore[possibly-undefined]
                 old_kwargs = kwargs  # type: ignore[possibly-undefined]
-                args, kwargs = constrain_to_fx_strides(n, *args, **kwargs)  # type: ignore[index]
+
+                if arg_kwarg_vals := n.meta.get("arg_kwarg_vals"):
+                    inp_args = arg_kwarg_vals[0]
+                    inp_kwargs = arg_kwarg_vals[1]
+                    args, kwargs = constrain_to_fake_tensors(
+                        args, kwargs, inp_args, inp_kwargs
+                    )
+                else:
+                    args, kwargs = constrain_to_fx_strides(n, *args, **kwargs)  # type: ignore[index]
                 result = self.call_function(n.target, args, kwargs)  # type: ignore[arg-type]
                 self.propagate_mutation(n, old_args, old_kwargs, args, kwargs)  # type: ignore[possibly-undefined]
             else:
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index fb7e7c7a49c..40780a18b69 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2413,6 +2413,31 @@ def require_channels_last(_, *args, **kwargs):
     return args, kwargs
 
 
+def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs):
+    def apply_constraint(arg, fake_arg):
+        if isinstance(arg, ir.IRNode):
+            meta_stride_expr = [
+                s.node.expr if isinstance(s, torch.SymInt) else s
+                for s in fake_arg.stride()
+            ]
+            return ir.ExternKernel.require_exact_strides(arg, meta_stride_expr)
+        if isinstance(arg, dict):
+            return {
+                key: apply_constraint(arg[key], fake_arg[key]) for key in arg.keys()
+            }
+        elif isinstance(arg, (tuple, list)):
+            return type(arg)(
+                apply_constraint(a, f_a) for (a, f_a) in zip(arg, fake_arg)
+            )
+        return arg
+
+    args = tuple(
+        apply_constraint(arg, fake_arg) for arg, fake_arg in zip(args, fake_args)
+    )
+    kwargs = {k: apply_constraint(v, fake_kwargs[k]) for k, v in kwargs.items()}
+    return args, kwargs
+
+
 def constrain_to_fx_strides(fx_node, *args, **kwargs):
     def apply_constraint(arg, fx_arg):
         if isinstance(arg, ir.IRNode):
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 3cf0edb3402..ab5dbc516c1 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -52,7 +52,11 @@ from torch._subclasses.fake_tensor import (
 from torch._subclasses.meta_utils import is_sparse_any
 from torch.fx import GraphModule, Proxy, Tracer
 from torch.fx.graph_module import _assign_attr
-from torch.fx.node import _side_effectful_need_to_be_preserved_pre_dispatch
+from torch.fx.node import (
+    _side_effectful_need_to_be_preserved_pre_dispatch,
+    Argument,
+    Target,
+)
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from torch.nn import Module
 from torch.overrides import TorchFunctionMode
@@ -1087,6 +1091,40 @@ class PythonKeyTracer(Tracer):
         else:
             return e
 
+    def create_node(
+        self,
+        kind: str,
+        target: Target,
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Argument],
+        name: Optional[str] = None,
+        type_expr: Optional[Any] = None,
+    ) -> torch.fx.Node:
+        node = super().create_node(kind, target, args, kwargs, name, type_expr)  # type: ignore[arg-type]
+
+        def map_fn(v: Any) -> Optional[_ExtractValType]:
+            if not isinstance(v, torch.fx.Node) or "val" not in v.meta:
+                return None
+            val = v.meta["val"]
+            # other subclasses like FunctionalTensor error on `extract_val`
+            # "Attempting to use FunctionalTensor on its own." just store FakeTensors for now
+            if isinstance(val, torch.Tensor) and not isinstance(val, FakeTensor):
+                return None
+            return extract_val(v.meta["val"])
+
+        # TODO: opt-in mechanism ?
+        if isinstance(
+            target,
+            (
+                torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional,
+                torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation,
+            ),
+        ):
+            arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn)  # type: ignore[misc, arg-type]
+            node.meta["arg_kwarg_vals"] = (arg_inp, kwarg_inp)
+
+        return node
+
 
 def _make_temp_remove_mode_context_manager(
     mode_ty: type[TorchFunctionMode],
@@ -1186,7 +1224,7 @@ def wrap_key(
         track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
 
         def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
-            return get_proxy_slot(t, tracer, t, lambda x: x.proxy)
+            return get_proxy_slot(t, tracer, t, lambda x: x.proxy)  # type: ignore[attr-defined]
 
         out = f(*tensors)  # type:ignore[call-arg]
         out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out)
diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py
index 5da72dca8ac..16124e98ebd 100644
--- a/torch/fx/proxy.py
+++ b/torch/fx/proxy.py
@@ -114,6 +114,7 @@ _COPY_META_FIELDS = [
     "_numeric_debug_handle",  # TODO deprecated
     "custom",
     "partitioner_tag",
+    "arg_kwarg_vals",
 ]
 
 
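
A minimal usage sketch of the behavior this patch targets, adapted from the test_preserves_strides test added above. It assumes a CUDA device and a Triton install, and it omits the custom post-grad pass that the test installs to force a layout change on the kernel input; the kernel and shapes below mirror the test.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
        # Same toy elementwise add as in the test above: flat-offset loads/stores.
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    def f(x, other):
        # z is non-contiguous; with this change the strides seen at trace time are
        # recorded in node.meta["arg_kwarg_vals"] and enforced on the kernel's
        # inputs during lowering (constrain_to_fake_tensors), so later graph passes
        # cannot silently change the layout the kernel was traced with.
        z = x.transpose(2, 3).contiguous().transpose(2, 3).sin().transpose(2, 3)
        out = torch.empty_like(other)
        add_kernel[(z.numel(),)](z, other, out, z.numel(), BLOCK_SIZE=16)
        return out

    x = torch.randn(4, 4, 2, 2, device="cuda")
    other = torch.randn(4, 4, 2, 2, device="cuda")
    torch.testing.assert_close(f(x, other), torch.compile(f)(x, other))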