[inductor] use eager stride for custom op if no tags (#148367)

Fix https://github.com/pytorch/pytorch/issues/148356

This is a short-term fix that restores the default behavior of applying a layout constraint (use the eager strides) to custom ops when they have no layout tags.

A longer-term effort to make sure Inductor always gets correct eager strides is here: https://github.com/pytorch/pytorch/pull/148104
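
For reference, an op author can still choose the layout behavior explicitly by tagging the op at definition time. The following is a minimal, hypothetical sketch (the op `mylib::add_one` and its implementations are invented for illustration), assuming `Library.define` accepts a `tags=` argument as hinted by the commented-out line in the new test:

```python
import torch

with torch.library._scoped_library("mylib", "DEF") as lib:
    # needs_fixed_stride_order asks Inductor to pass this op inputs with the
    # same strides it would see in eager mode; flexible_layout would instead
    # let Inductor choose any layout. With no tag at all, this PR makes the
    # eager-stride behavior the default for custom ops.
    lib.define(
        "add_one(Tensor x) -> Tensor",
        tags=torch.Tag.needs_fixed_stride_order,
    )

    @torch.library.impl(lib, "add_one", "Meta")
    def _(x):
        return torch.empty_like(x)

    @torch.library.impl(lib, "add_one", "CompositeExplicitAutograd")
    def _(x):
        return x + 1

    compiled = torch.compile(
        lambda x: torch.ops.mylib.add_one.default(x), fullgraph=True
    )
    out = compiled(torch.arange(9, dtype=torch.float).view(3, 3).t())
```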

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148367
Approved by: https://github.com/eellison, https://github.com/zou3519
Author: Shunting Zhang
Date: 2025-03-04 14:19:43 -08:00 (committed by PyTorch MergeBot)
Commit: 6cc3e69103 (parent 703176e538)
2 changed files with 60 additions and 2 deletions


@@ -11629,6 +11629,51 @@ class CommonTemplate:
        net = torch.compile(model)
        out = net(input_t)

    @skip_if_cpp_wrapper(
        "Without major redesign, cpp_wrapper will not support custom ops that are "
        "defined in Python."
    )
    @config.patch(implicit_fallbacks=True)
    def test_custom_op_default_layout_constraint(self):
        with torch.library._scoped_library("mylib", "DEF") as lib:
            lib.define(
                "copy_(Tensor(a!) dst, Tensor src) -> ()",
                # No need to pass in an explicit tag since the default
                # behavior for custom ops works.
                # tags=torch.Tag.needs_fixed_stride_order,
            )

            @torch.library.impl(lib, "copy_", "Meta")
            def _(dst, src):
                return None

            @torch.library.impl(lib, "copy_", "CompositeExplicitAutograd")
            def _(dst, src):
                if src.is_contiguous():
                    dst.copy_(src + 1)
                else:
                    dst.copy_(src)

            def f(x):
                full_default_3 = torch.full([3, 3], 7.0, device=self.device)
                chunk_cat_default_1 = torch.ops.mylib.copy_.default(full_default_3, x)
                mul_out = torch.mul(full_default_3, full_default_3)
                return mul_out

            x = (
                torch.arange(9, dtype=torch.float, device=self.device)
                .view(3, 3)
                .t()
                .contiguous()
                .t()
            )
            eager_out = f(x)

            compiled_inductor_f = torch.compile(f, backend="inductor", fullgraph=True)
            compiled_inductor_out = compiled_inductor_f(x)
            self.assertTrue(torch.allclose(compiled_inductor_out, eager_out))

    @skip_if_gpu_halide  # cuda error
    def test_buffer_use_after_remove(self):
        # https://github.com/pytorch/pytorch/issues/102857
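
A side note on why this particular input catches the regression (not part of the diff): `.t().contiguous().t()` builds a tensor whose data is laid out column-major, and the custom kernel above branches on `src.is_contiguous()`, so the compiled and eager results only agree if Inductor passes the custom op the same non-contiguous strides it would see in eager mode. A quick check of the layout:

```python
import torch

x = torch.arange(9, dtype=torch.float).view(3, 3).t().contiguous().t()
print(x.is_contiguous())  # False: data is stored column-major
print(x.stride())         # (1, 3) instead of the contiguous (3, 1)
```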


@@ -1141,14 +1141,27 @@ class GraphLowering(torch.fx.Interpreter):
                        error.operator_str(target, args, kwargs),
                    )

                    decided_constraint = require_contiguous
                    # use contiguous unless the (custom) op explicitly asks for
                    # something else
                    if torch._C.Tag.needs_fixed_stride_order in target.tags:
                        decided_constraint = constrain_to_fx_strides  # type: ignore[assignment]
                    elif torch._C.Tag.flexible_layout in target.tags:
                        decided_constraint = None  # type: ignore[assignment]
                    else:
                        # If there are no tags, we do different things depending on
                        # whether this is a builtin ATen/prim op or a custom op.
                        # For ATen ops, we require_contiguous to fix
                        # https://github.com/pytorch/pytorch/issues/140452
                        # For custom ops, we constrain_to_fx_strides to maintain the
                        # behavior of PyTorch 2.5: https://github.com/pytorch/pytorch/issues/148356
                        #
                        # For ATen ops, only apply the constraint to backward
                        # ops since forward ops should work for any strides.
                        if torch._library.utils.is_builtin(target) and self.is_backward:
                            decided_constraint = require_contiguous  # type: ignore[assignment]
                        else:
                            # maybe_layout_constraints will decide the layout
                            # constraint for the custom op lazily
                            decided_constraint = None  # type: ignore[assignment]

                    # for implicit fallback ops, we conservatively require
                    # contiguous inputs since some eager kernels do not