[Resubmit] Record input strides at time of tracing, constrain to them for triton fn (#147861)

Resubmit of https://github.com/pytorch/pytorch/pull/145448; the original PR lost its changes on a rebase.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147861
Approved by: https://github.com/zou3519
eellison 2025-02-25 16:15:06 -08:00 committed by PyTorch MergeBot
parent ba25e26baa
commit c839fa4dd2
5 changed files with 160 additions and 4 deletions
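
For context on why the traced strides matter: the CPU-only sketch below (illustrative, not part of the PR) shows that two tensors can be element-wise equal yet laid out differently in memory. A Triton kernel that indexes with flat offsets, like the `add_kernel` in the new test, only computes the right thing if its inputs keep the strides they had when the kernel call was traced.

```python
# Illustrative sketch only: same logical values, different physical layout.
import torch

a = torch.arange(6).reshape(2, 3)   # contiguous, strides (3, 1)
b = a.t().contiguous().t()          # same values, strides (1, 2)

assert torch.equal(a, b)            # logically identical
print(a.stride(), b.stride())       # (3, 1) (1, 2)

# Reading the underlying storage in physical order, the way a flat-offset
# kernel would, gives a different element order for each layout.
print(torch.as_strided(a, (6,), (1,)))  # tensor([0, 1, 2, 3, 4, 5])
print(torch.as_strided(b, (6,), (1,)))  # tensor([0, 3, 1, 4, 2, 5])
```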

View File

@@ -8,6 +8,7 @@ import logging
import torch
import torch._dynamo.testing
import torch._inductor.test_case
import torch.utils._pytree as pytree
from torch._dynamo import config as dynamo_config
from torch._higher_order_ops.triton_kernel_wrap import (
generate_ttir,
@@ -15,6 +16,11 @@ from torch._higher_order_ops.triton_kernel_wrap import (
triton_kernel_wrapper_mutation,
)
from torch._inductor import config as inductor_config, metrics
from torch._inductor.pattern_matcher import (
CallFunctionVarArgs,
PatternMatcherPass,
register_graph_pattern,
)
from torch._inductor.utils import run_and_get_code, triton_version_uses_attrs_dict
from torch._library import capture_triton
from torch.testing import FileCheck
@@ -3374,6 +3380,75 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
self.assertEqual(status[-1], False)
self.assertEqual(z, (x + y) * 2)
@requires_gpu
def test_preserves_strides(self):
import triton
import triton.language as tl
@triton.jit
def add_kernel(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
x = torch.randn(4, 4, 2, 2, device="cuda")
other = torch.randn(4, 4, 2, 2, device="cuda")
def f(x, other):
y = x.transpose(2, 3).contiguous().transpose(2, 3)
z = y.sin().transpose(2, 3)
grid = (z.numel(),)
out = torch.empty_like(other)
add_kernel[grid](z, other, out, z.numel(), BLOCK_SIZE=16)
return out
class _CustomPass(PatternMatcherPass):
def __init__(self) -> None:
super().__init__()
def __call__(self, g: torch.fx.Graph):
self.apply(g)
g = _CustomPass()
called = False
@register_graph_pattern(
CallFunctionVarArgs(torch.ops.aten.permute),
pass_dict=g,
)
def _(match, *args, **kwargs):
flat_args, spec = pytree.tree_flatten((args, kwargs))
def decomp(*flat_args):
args, kwargs = pytree.tree_unflatten(flat_args, spec)
return torch.ops.aten.permute(*args, **kwargs).clone(
memory_format=torch.channels_last
)
nonlocal called
called = True
match.replace_by_example(decomp, flat_args)
from torch._inductor import config
with config.patch(
post_grad_custom_post_pass=g,
):
f_compile = torch.compile(f)
self.assertEqual(f(x, other), f_compile(x, other))
self.assertTrue(called)
@requires_gpu
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("autotune", [False, True])

View File

@@ -75,6 +75,7 @@ from .ir import (
TorchBindObject,
)
from .lowering import (
constrain_to_fake_tensors,
constrain_to_fx_strides,
FALLBACK_ALLOW_LIST,
fallback_handler,
@@ -232,6 +233,13 @@ def mark_nodes_dislike_padding(
)
for cur in reversed(g.nodes):
if isinstance(
cur.target,
torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation,
):
cur.meta["dislike_padding"] = True
continue
op = _get_overload_packet(cur)
if not op:
continue
@@ -1354,8 +1362,9 @@ class GraphLowering(torch.fx.Interpreter):
for name in mutated:
old_arg = old_kwargs["kwargs"][name]
new_arg = new_kwargs["kwargs"][name]
if old_arg is new_args:
if old_arg is new_arg:
continue
self.call_function(torch.ops.aten.copy_.default, (old_arg, new_arg), {})
return
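
The hunk above fixes the identity check in propagate_mutation (`old_arg is new_arg`, not `new_args`) so the write-back copy is skipped when an argument was not re-laid-out; the write-back itself is just an aten copy_ from the new buffer into the original one. A standalone illustration of that op:

```python
import torch

dst = torch.zeros(4)
src = torch.arange(4.0)
torch.ops.aten.copy_.default(dst, src)   # in-place copy of src into dst
print(dst)                               # tensor([0., 1., 2., 3.])
```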
@@ -1438,7 +1447,15 @@
):
old_args = args # type: ignore[possibly-undefined]
old_kwargs = kwargs # type: ignore[possibly-undefined]
args, kwargs = constrain_to_fx_strides(n, *args, **kwargs) # type: ignore[index]
if arg_kwarg_vals := n.meta.get("arg_kwarg_vals"):
inp_args = arg_kwarg_vals[0]
inp_kwargs = arg_kwarg_vals[1]
args, kwargs = constrain_to_fake_tensors(
args, kwargs, inp_args, inp_kwargs
)
else:
args, kwargs = constrain_to_fx_strides(n, *args, **kwargs) # type: ignore[index]
result = self.call_function(n.target, args, kwargs) # type: ignore[arg-type]
self.propagate_mutation(n, old_args, old_kwargs, args, kwargs) # type: ignore[possibly-undefined]
else:

View File

@@ -2413,6 +2413,31 @@ def require_channels_last(_, *args, **kwargs):
return args, kwargs
def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs):
def apply_constraint(arg, fake_arg):
if isinstance(arg, ir.IRNode):
meta_stride_expr = [
s.node.expr if isinstance(s, torch.SymInt) else s
for s in fake_arg.stride()
]
return ir.ExternKernel.require_exact_strides(arg, meta_stride_expr)
if isinstance(arg, dict):
return {
key: apply_constraint(arg[key], fake_arg[key]) for key in arg.keys()
}
elif isinstance(arg, (tuple, list)):
return type(arg)(
apply_constraint(a, f_a) for (a, f_a) in zip(arg, fake_arg)
)
return arg
args = tuple(
apply_constraint(arg, fake_arg) for arg, fake_arg in zip(args, fake_args)
)
kwargs = {k: apply_constraint(v, fake_kwargs[k]) for k, v in kwargs.items()}
return args, kwargs
def constrain_to_fx_strides(fx_node, *args, **kwargs):
def apply_constraint(arg, fx_arg):
if isinstance(arg, ir.IRNode):
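
The new constrain_to_fake_tensors helper reads strides off the recorded fake tensors, unwrapping SymInt strides (from dynamic shapes) to their sympy expressions and passing plain ints through. A standalone sketch of that leaf-level extraction (the stride_exprs name is illustrative, not from the PR):

```python
import torch

def stride_exprs(t: torch.Tensor):
    # Plain int strides pass through; symbolic strides are unwrapped to their
    # sympy expressions, mirroring the extraction in constrain_to_fake_tensors.
    return [s.node.expr if isinstance(s, torch.SymInt) else s for s in t.stride()]

print(stride_exprs(torch.randn(4, 4, 2, 2)))  # [16, 4, 2, 1]
```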

View File

@@ -52,7 +52,11 @@ from torch._subclasses.fake_tensor import (
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx import GraphModule, Proxy, Tracer
from torch.fx.graph_module import _assign_attr
from torch.fx.node import _side_effectful_need_to_be_preserved_pre_dispatch
from torch.fx.node import (
_side_effectful_need_to_be_preserved_pre_dispatch,
Argument,
Target,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.nn import Module
from torch.overrides import TorchFunctionMode
@@ -1087,6 +1091,40 @@ class PythonKeyTracer(Tracer):
else:
return e
def create_node(
self,
kind: str,
target: Target,
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
) -> torch.fx.Node:
node = super().create_node(kind, target, args, kwargs, name, type_expr) # type: ignore[arg-type]
def map_fn(v: Any) -> Optional[_ExtractValType]:
if not isinstance(v, torch.fx.Node) or "val" not in v.meta:
return None
val = v.meta["val"]
# other subclasses like FunctionalTensor error on `extract_val`
# "Attempting to use FunctionalTensor on its own." just store FakeTensors for now
if isinstance(val, torch.Tensor) and not isinstance(val, FakeTensor):
return None
return extract_val(v.meta["val"])
# TODO: opt-in mechanism ?
if isinstance(
target,
(
torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional,
torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation,
),
):
arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn) # type: ignore[misc, arg-type]
node.meta["arg_kwarg_vals"] = (arg_inp, kwarg_inp)
return node
def _make_temp_remove_mode_context_manager(
mode_ty: type[TorchFunctionMode],
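
For reference, torch.fx.node.map_aggregate, which the create_node override above uses to walk (args, kwargs), applies a function to every leaf of a nested structure of tuples, lists, and dicts. A quick standalone demo:

```python
from torch.fx.node import map_aggregate

nested = ((1, 2), {"kwargs": {"x": 3, "grid": [4, 5]}})
print(map_aggregate(nested, lambda v: v * 10 if isinstance(v, int) else v))
# ((10, 20), {'kwargs': {'x': 30, 'grid': [40, 50]}})
```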
@@ -1186,7 +1224,7 @@ def wrap_key(
track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
return get_proxy_slot(t, tracer, t, lambda x: x.proxy)
return get_proxy_slot(t, tracer, t, lambda x: x.proxy) # type: ignore[attr-defined]
out = f(*tensors) # type:ignore[call-arg]
out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out)

View File

@@ -114,6 +114,7 @@ _COPY_META_FIELDS = [
"_numeric_debug_handle", # TODO deprecated
"custom",
"partitioner_tag",
"arg_kwarg_vals",
]
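
_COPY_META_FIELDS is the list of node.meta keys the FX tracer carries over to newly created nodes, so adding "arg_kwarg_vals" here keeps the recorded fake args/kwargs attached when the graph is re-traced. A generic way to poke at node.meta on a traced graph (which keys show up depends on the op and tracing mode):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

gm = make_fx(lambda x: x.sin(), tracing_mode="fake")(torch.randn(3))
for node in gm.graph.nodes:
    print(node.op, node.target, sorted(node.meta.keys()))
```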