Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
[Resubmit] Record input strides at time of tracing, constrain to them for triton fn (#147861)
Resubmit of https://github.com/pytorch/pytorch/pull/145448, which lost its changes on a rebase.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147861
Approved by: https://github.com/zou3519
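At a high level, the change records the FakeTensor values of a user-defined Triton kernel's inputs on the FX node at trace time (node.meta["arg_kwarg_vals"]) and has inductor constrain the lowered inputs back to those recorded strides. The snippet below is a minimal pure-PyTorch sketch of why that matters, not code from this PR; flat_offset_add is a hypothetical stand-in for a Triton kernel that loads with flat offsets (tl.load(ptr + offsets)) and therefore bakes in the strides of the tensors it was traced with.

import torch


def flat_offset_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Stand-in for a flat-offset Triton kernel: walk the underlying storage
    # with stride-1 offsets, ignoring the tensors' logical strides.
    n = x.numel()
    x_raw = torch.as_strided(x, (n,), (1,))
    y_raw = torch.as_strided(y, (n,), (1,))
    return x_raw + y_raw


x = torch.randn(4, 4)
y = torch.randn(4, 4)

# With the contiguous layout seen at trace time, the flat walk matches the
# logical element order and the result is correct.
assert torch.allclose(flat_offset_add(x, y).reshape(4, 4), x + y)

# If a layout-changing pass later hands the kernel transposed views (same
# storage, different strides), the same flat walk pairs up the wrong elements,
# so for random inputs the result no longer matches the logical transpose sum.
wrong = flat_offset_add(x.t(), y.t()).reshape(4, 4)
assert not torch.allclose(wrong, x.t() + y.t())

Constraining the kernel's inputs to the strides recorded at tracing (inserting a copy when the runtime layout differs) keeps that assumption valid even after post-grad passes rewrite the graph, which is exactly the situation the new test_preserves_strides test sets up with a channels_last-producing custom pass.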
parent ba25e26baa
commit c839fa4dd2
@@ -8,6 +8,7 @@ import logging
 import torch
 import torch._dynamo.testing
 import torch._inductor.test_case
+import torch.utils._pytree as pytree
 from torch._dynamo import config as dynamo_config
 from torch._higher_order_ops.triton_kernel_wrap import (
     generate_ttir,
@@ -15,6 +16,11 @@ from torch._higher_order_ops.triton_kernel_wrap import (
     triton_kernel_wrapper_mutation,
 )
 from torch._inductor import config as inductor_config, metrics
+from torch._inductor.pattern_matcher import (
+    CallFunctionVarArgs,
+    PatternMatcherPass,
+    register_graph_pattern,
+)
 from torch._inductor.utils import run_and_get_code, triton_version_uses_attrs_dict
 from torch._library import capture_triton
 from torch.testing import FileCheck
@@ -3374,6 +3380,75 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
         self.assertEqual(status[-1], False)
         self.assertEqual(z, (x + y) * 2)

+    @requires_gpu
+    def test_preserves_strides(self):
+        import triton
+        import triton.language as tl
+
+        @triton.jit
+        def add_kernel(
+            in_ptr0,
+            in_ptr1,
+            out_ptr,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask)
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        x = torch.randn(4, 4, 2, 2, device="cuda")
+        other = torch.randn(4, 4, 2, 2, device="cuda")
+
+        def f(x, other):
+            y = x.transpose(2, 3).contiguous().transpose(2, 3)
+            z = y.sin().transpose(2, 3)
+            grid = (z.numel(),)
+            out = torch.empty_like(other)
+            add_kernel[grid](z, other, out, z.numel(), BLOCK_SIZE=16)
+            return out
+
+        class _CustomPass(PatternMatcherPass):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def __call__(self, g: torch.fx.Graph):
+                self.apply(g)
+
+        g = _CustomPass()
+        called = False
+
+        @register_graph_pattern(
+            CallFunctionVarArgs(torch.ops.aten.permute),
+            pass_dict=g,
+        )
+        def _(match, *args, **kwargs):
+            flat_args, spec = pytree.tree_flatten((args, kwargs))
+
+            def decomp(*flat_args):
+                args, kwargs = pytree.tree_unflatten(flat_args, spec)
+                return torch.ops.aten.permute(*args, **kwargs).clone(
+                    memory_format=torch.channels_last
+                )
+
+            nonlocal called
+            called = True
+            match.replace_by_example(decomp, flat_args)
+
+        from torch._inductor import config
+
+        with config.patch(
+            post_grad_custom_post_pass=g,
+        ):
+            f_compile = torch.compile(f)
+            self.assertEqual(f(x, other), f_compile(x, other))
+            self.assertTrue(called)
+
     @requires_gpu
     @common_utils.parametrize("dynamic", [False, True])
     @common_utils.parametrize("autotune", [False, True])
@@ -75,6 +75,7 @@ from .ir import (
     TorchBindObject,
 )
 from .lowering import (
+    constrain_to_fake_tensors,
     constrain_to_fx_strides,
     FALLBACK_ALLOW_LIST,
     fallback_handler,
@@ -232,6 +233,13 @@ def mark_nodes_dislike_padding(
     )

     for cur in reversed(g.nodes):
+        if isinstance(
+            cur.target,
+            torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation,
+        ):
+            cur.meta["dislike_padding"] = True
+            continue
+
         op = _get_overload_packet(cur)
         if not op:
             continue
@@ -1354,8 +1362,9 @@ class GraphLowering(torch.fx.Interpreter):
            for name in mutated:
                old_arg = old_kwargs["kwargs"][name]
                new_arg = new_kwargs["kwargs"][name]
-                if old_arg is new_args:
+                if old_arg is new_arg:
                    continue

                self.call_function(torch.ops.aten.copy_.default, (old_arg, new_arg), {})
            return
@@ -1438,6 +1447,14 @@ class GraphLowering(torch.fx.Interpreter):
            ):
                old_args = args  # type: ignore[possibly-undefined]
                old_kwargs = kwargs  # type: ignore[possibly-undefined]
-                args, kwargs = constrain_to_fx_strides(n, *args, **kwargs)  # type: ignore[index]
+
+                if arg_kwarg_vals := n.meta.get("arg_kwarg_vals"):
+                    inp_args = arg_kwarg_vals[0]
+                    inp_kwargs = arg_kwarg_vals[1]
+                    args, kwargs = constrain_to_fake_tensors(
+                        args, kwargs, inp_args, inp_kwargs
+                    )
+                else:
+                    args, kwargs = constrain_to_fx_strides(n, *args, **kwargs)  # type: ignore[index]
                result = self.call_function(n.target, args, kwargs)  # type: ignore[arg-type]
                self.propagate_mutation(n, old_args, old_kwargs, args, kwargs)  # type: ignore[possibly-undefined]
@@ -2413,6 +2413,31 @@ def require_channels_last(_, *args, **kwargs):
     return args, kwargs


+def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs):
+    def apply_constraint(arg, fake_arg):
+        if isinstance(arg, ir.IRNode):
+            meta_stride_expr = [
+                s.node.expr if isinstance(s, torch.SymInt) else s
+                for s in fake_arg.stride()
+            ]
+            return ir.ExternKernel.require_exact_strides(arg, meta_stride_expr)
+        if isinstance(arg, dict):
+            return {
+                key: apply_constraint(arg[key], fake_arg[key]) for key in arg.keys()
+            }
+        elif isinstance(arg, (tuple, list)):
+            return type(arg)(
+                apply_constraint(a, f_a) for (a, f_a) in zip(arg, fake_arg)
+            )
+        return arg
+
+    args = tuple(
+        apply_constraint(arg, fake_arg) for arg, fake_arg in zip(args, fake_args)
+    )
+    kwargs = {k: apply_constraint(v, fake_kwargs[k]) for k, v in kwargs.items()}
+    return args, kwargs
+
+
 def constrain_to_fx_strides(fx_node, *args, **kwargs):
     def apply_constraint(arg, fx_arg):
         if isinstance(arg, ir.IRNode):
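The new constrain_to_fake_tensors helper above (added next to constrain_to_fx_strides, which lives in torch/_inductor/lowering.py) recursively walks the args/kwargs and, for every inductor IR node, requires the exact strides of the FakeTensor recorded at trace time. Below is a rough eager-mode analogue of what "require exact strides" amounts to for a single tensor; require_exact_strides_eager is a hypothetical illustration, not the inductor implementation (ir.ExternKernel.require_exact_strides operates on inductor IR, not on eager tensors).

import torch


def require_exact_strides_eager(t: torch.Tensor, strides: tuple) -> torch.Tensor:
    # If the runtime strides already match the recorded ones, pass the tensor
    # through unchanged; otherwise materialize a copy with the recorded layout.
    if tuple(t.stride()) == tuple(strides):
        return t
    out = torch.empty_strided(t.size(), strides, dtype=t.dtype, device=t.device)
    out.copy_(t)
    return out


x = torch.randn(4, 4).t()                    # strides (1, 4): a column-major view
y = require_exact_strides_eager(x, (4, 1))   # force the row-major layout recorded at trace time
assert y.stride() == (4, 1) and torch.equal(y, x)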
@@ -52,7 +52,11 @@ from torch._subclasses.fake_tensor import (
 from torch._subclasses.meta_utils import is_sparse_any
 from torch.fx import GraphModule, Proxy, Tracer
 from torch.fx.graph_module import _assign_attr
-from torch.fx.node import _side_effectful_need_to_be_preserved_pre_dispatch
+from torch.fx.node import (
+    _side_effectful_need_to_be_preserved_pre_dispatch,
+    Argument,
+    Target,
+)
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from torch.nn import Module
 from torch.overrides import TorchFunctionMode
@@ -1087,6 +1091,40 @@ class PythonKeyTracer(Tracer):
         else:
             return e

+    def create_node(
+        self,
+        kind: str,
+        target: Target,
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Argument],
+        name: Optional[str] = None,
+        type_expr: Optional[Any] = None,
+    ) -> torch.fx.Node:
+        node = super().create_node(kind, target, args, kwargs, name, type_expr)  # type: ignore[arg-type]
+
+        def map_fn(v: Any) -> Optional[_ExtractValType]:
+            if not isinstance(v, torch.fx.Node) or "val" not in v.meta:
+                return None
+            val = v.meta["val"]
+            # other subclasses like FunctionalTensor error on `extract_val`
+            # "Attempting to use FunctionalTensor on its own." just store FakeTensors for now
+            if isinstance(val, torch.Tensor) and not isinstance(val, FakeTensor):
+                return None
+            return extract_val(v.meta["val"])
+
+        # TODO: opt-in mechanism ?
+        if isinstance(
+            target,
+            (
+                torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional,
+                torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation,
+            ),
+        ):
+            arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn)  # type: ignore[misc, arg-type]
+            node.meta["arg_kwarg_vals"] = (arg_inp, kwarg_inp)
+
+        return node
+

 def _make_temp_remove_mode_context_manager(
     mode_ty: type[TorchFunctionMode],
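The create_node override above (PythonKeyTracer lives in torch/fx/experimental/proxy_tensor.py) uses torch.fx.node.map_aggregate to walk the nested (args, kwargs) structure and record each node's extracted FakeTensor val under node.meta["arg_kwarg_vals"]. A tiny standalone illustration of map_aggregate's behavior, with an integer-doubling map function chosen purely for demonstration:

from torch.fx.node import map_aggregate

nested = ((1, 2), {"scale": (3, [4, 5])})
doubled = map_aggregate(nested, lambda v: v * 2 if isinstance(v, int) else v)
print(doubled)  # ((2, 4), {'scale': (6, [8, 10])}) -- lists/dicts come back as fx immutable variants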
@@ -1186,7 +1224,7 @@ def wrap_key(
         track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)

         def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
-            return get_proxy_slot(t, tracer, t, lambda x: x.proxy)
+            return get_proxy_slot(t, tracer, t, lambda x: x.proxy)  # type: ignore[attr-defined]

         out = f(*tensors)  # type:ignore[call-arg]
         out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out)
@@ -114,6 +114,7 @@ _COPY_META_FIELDS = [
     "_numeric_debug_handle",  # TODO deprecated
     "custom",
     "partitioner_tag",
+    "arg_kwarg_vals",
 ]