[Resubmit] Record input strides at time of tracing, constrain to them for triton fn (#147861)
Resubmit of https://github.com/pytorch/pytorch/pull/145448, which lost its changes on a rebase.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147861
Approved by: https://github.com/zou3519
This commit is contained in:
parent ba25e26baa
commit c839fa4dd2
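For context, here is a minimal sketch (not part of the diff) of the behavior this change is meant to guarantee: under torch.compile, a user-defined Triton kernel should now see its inputs with the same strides that were recorded at trace time, rather than whatever layout Inductor would otherwise pick. The kernel, shapes, and names below are illustrative only and assume a CUDA device with Triton installed.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    # Copies raw memory element by element, so the strides of the buffers
    # handed to the kernel determine what the output means logically.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(out_ptr + offsets, tl.load(in_ptr + offsets, mask=mask), mask=mask)


def f(x):
    # y is dense but carries transposed strides; eager passes it to the
    # kernel as-is, and the compiled path is expected to do the same.
    y = x.transpose(0, 1).contiguous().transpose(0, 1)
    out = torch.empty_like(y)
    n = y.numel()
    copy_kernel[(triton.cdiv(n, 16),)](y, out, n, BLOCK_SIZE=16)
    return out


x = torch.randn(8, 4, device="cuda")
torch.testing.assert_close(torch.compile(f)(x), f(x))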
@@ -8,6 +8,7 @@ import logging
 import torch
 import torch._dynamo.testing
 import torch._inductor.test_case
+import torch.utils._pytree as pytree
 from torch._dynamo import config as dynamo_config
 from torch._higher_order_ops.triton_kernel_wrap import (
     generate_ttir,
@@ -15,6 +16,11 @@ from torch._higher_order_ops.triton_kernel_wrap import (
     triton_kernel_wrapper_mutation,
 )
 from torch._inductor import config as inductor_config, metrics
+from torch._inductor.pattern_matcher import (
+    CallFunctionVarArgs,
+    PatternMatcherPass,
+    register_graph_pattern,
+)
 from torch._inductor.utils import run_and_get_code, triton_version_uses_attrs_dict
 from torch._library import capture_triton
 from torch.testing import FileCheck
@@ -3374,6 +3380,75 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
         self.assertEqual(status[-1], False)
         self.assertEqual(z, (x + y) * 2)
 
+    @requires_gpu
+    def test_preserves_strides(self):
+        import triton
+        import triton.language as tl
+
+        @triton.jit
+        def add_kernel(
+            in_ptr0,
+            in_ptr1,
+            out_ptr,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask)
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        x = torch.randn(4, 4, 2, 2, device="cuda")
+        other = torch.randn(4, 4, 2, 2, device="cuda")
+
+        def f(x, other):
+            y = x.transpose(2, 3).contiguous().transpose(2, 3)
+            z = y.sin().transpose(2, 3)
+            grid = (z.numel(),)
+            out = torch.empty_like(other)
+            add_kernel[grid](z, other, out, z.numel(), BLOCK_SIZE=16)
+            return out
+
+        class _CustomPass(PatternMatcherPass):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def __call__(self, g: torch.fx.Graph):
+                self.apply(g)
+
+        g = _CustomPass()
+        called = False
+
+        @register_graph_pattern(
+            CallFunctionVarArgs(torch.ops.aten.permute),
+            pass_dict=g,
+        )
+        def _(match, *args, **kwargs):
+            flat_args, spec = pytree.tree_flatten((args, kwargs))
+
+            def decomp(*flat_args):
+                args, kwargs = pytree.tree_unflatten(flat_args, spec)
+                return torch.ops.aten.permute(*args, **kwargs).clone(
+                    memory_format=torch.channels_last
+                )
+
+            nonlocal called
+            called = True
+            match.replace_by_example(decomp, flat_args)
+
+        from torch._inductor import config
+
+        with config.patch(
+            post_grad_custom_post_pass=g,
+        ):
+            f_compile = torch.compile(f)
+            self.assertEqual(f(x, other), f_compile(x, other))
+            self.assertTrue(called)
+
     @requires_gpu
     @common_utils.parametrize("dynamic", [False, True])
     @common_utils.parametrize("autotune", [False, True])
@@ -75,6 +75,7 @@ from .ir import (
     TorchBindObject,
 )
 from .lowering import (
+    constrain_to_fake_tensors,
     constrain_to_fx_strides,
     FALLBACK_ALLOW_LIST,
     fallback_handler,
@@ -232,6 +233,13 @@ def mark_nodes_dislike_padding(
     )
 
     for cur in reversed(g.nodes):
+        if isinstance(
+            cur.target,
+            torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation,
+        ):
+            cur.meta["dislike_padding"] = True
+            continue
+
         op = _get_overload_packet(cur)
         if not op:
             continue
@@ -1354,8 +1362,9 @@ class GraphLowering(torch.fx.Interpreter):
             for name in mutated:
                 old_arg = old_kwargs["kwargs"][name]
                 new_arg = new_kwargs["kwargs"][name]
-                if old_arg is new_args:
+                if old_arg is new_arg:
                     continue
+
                 self.call_function(torch.ops.aten.copy_.default, (old_arg, new_arg), {})
             return
 
@@ -1438,7 +1447,15 @@ class GraphLowering(torch.fx.Interpreter):
            ):
                old_args = args  # type: ignore[possibly-undefined]
                old_kwargs = kwargs  # type: ignore[possibly-undefined]
-               args, kwargs = constrain_to_fx_strides(n, *args, **kwargs)  # type: ignore[index]
+
+               if arg_kwarg_vals := n.meta.get("arg_kwarg_vals"):
+                   inp_args = arg_kwarg_vals[0]
+                   inp_kwargs = arg_kwarg_vals[1]
+                   args, kwargs = constrain_to_fake_tensors(
+                       args, kwargs, inp_args, inp_kwargs
+                   )
+               else:
+                   args, kwargs = constrain_to_fx_strides(n, *args, **kwargs)  # type: ignore[index]
                result = self.call_function(n.target, args, kwargs)  # type: ignore[arg-type]
                self.propagate_mutation(n, old_args, old_kwargs, args, kwargs)  # type: ignore[possibly-undefined]
            else:
@@ -2413,6 +2413,31 @@ def require_channels_last(_, *args, **kwargs):
     return args, kwargs
 
 
+def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs):
+    def apply_constraint(arg, fake_arg):
+        if isinstance(arg, ir.IRNode):
+            meta_stride_expr = [
+                s.node.expr if isinstance(s, torch.SymInt) else s
+                for s in fake_arg.stride()
+            ]
+            return ir.ExternKernel.require_exact_strides(arg, meta_stride_expr)
+        if isinstance(arg, dict):
+            return {
+                key: apply_constraint(arg[key], fake_arg[key]) for key in arg.keys()
+            }
+        elif isinstance(arg, (tuple, list)):
+            return type(arg)(
+                apply_constraint(a, f_a) for (a, f_a) in zip(arg, fake_arg)
+            )
+        return arg
+
+    args = tuple(
+        apply_constraint(arg, fake_arg) for arg, fake_arg in zip(args, fake_args)
+    )
+    kwargs = {k: apply_constraint(v, fake_kwargs[k]) for k, v in kwargs.items()}
+    return args, kwargs
+
+
 def constrain_to_fx_strides(fx_node, *args, **kwargs):
     def apply_constraint(arg, fx_arg):
         if isinstance(arg, ir.IRNode):
@@ -52,7 +52,11 @@ from torch._subclasses.fake_tensor import (
 from torch._subclasses.meta_utils import is_sparse_any
 from torch.fx import GraphModule, Proxy, Tracer
 from torch.fx.graph_module import _assign_attr
-from torch.fx.node import _side_effectful_need_to_be_preserved_pre_dispatch
+from torch.fx.node import (
+    _side_effectful_need_to_be_preserved_pre_dispatch,
+    Argument,
+    Target,
+)
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from torch.nn import Module
 from torch.overrides import TorchFunctionMode
@@ -1087,6 +1091,40 @@ class PythonKeyTracer(Tracer):
         else:
             return e
 
+    def create_node(
+        self,
+        kind: str,
+        target: Target,
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Argument],
+        name: Optional[str] = None,
+        type_expr: Optional[Any] = None,
+    ) -> torch.fx.Node:
+        node = super().create_node(kind, target, args, kwargs, name, type_expr)  # type: ignore[arg-type]
+
+        def map_fn(v: Any) -> Optional[_ExtractValType]:
+            if not isinstance(v, torch.fx.Node) or "val" not in v.meta:
+                return None
+            val = v.meta["val"]
+            # other subclasses like FunctionalTensor error on `extract_val` with
+            # "Attempting to use FunctionalTensor on its own."; just store FakeTensors for now
+            if isinstance(val, torch.Tensor) and not isinstance(val, FakeTensor):
+                return None
+            return extract_val(v.meta["val"])
+
+        # TODO: opt-in mechanism ?
+        if isinstance(
+            target,
+            (
+                torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional,
+                torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation,
+            ),
+        ):
+            arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn)  # type: ignore[misc, arg-type]
+            node.meta["arg_kwarg_vals"] = (arg_inp, kwarg_inp)
+
+        return node
+
+
 def _make_temp_remove_mode_context_manager(
     mode_ty: type[TorchFunctionMode],
@@ -1186,7 +1224,7 @@ def wrap_key(
         track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
 
         def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
-            return get_proxy_slot(t, tracer, t, lambda x: x.proxy)
+            return get_proxy_slot(t, tracer, t, lambda x: x.proxy)  # type: ignore[attr-defined]
 
         out = f(*tensors)  # type:ignore[call-arg]
         out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out)
@@ -114,6 +114,7 @@ _COPY_META_FIELDS = [
     "_numeric_debug_handle",  # TODO deprecated
     "custom",
     "partitioner_tag",
+    "arg_kwarg_vals",
 ]
 
 