[inductor] use eager stride for custom op if no tags (#148367)
Fixes https://github.com/pytorch/pytorch/issues/148356.

This is a short-term fix that restores the default behavior of applying a layout constraint to custom ops when they carry no tags. A longer-term effort to make sure Inductor always gets correct eager strides is here: https://github.com/pytorch/pytorch/pull/148104

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148367
Approved by: https://github.com/eellison, https://github.com/zou3519
This commit is contained in:
parent 703176e538
commit 6cc3e69103
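For background, the layout-constraint tags referenced throughout this change can be set when a custom op is defined. A minimal sketch of the explicit opt-in path, using the same torch.library._scoped_library API as the new test below (the "mylib" namespace and schema mirror that test and are illustrative only):

    import torch

    with torch.library._scoped_library("mylib", "DEF") as lib:
        # needs_fixed_stride_order asks Inductor to pass inputs with the
        # same strides they had in eager mode.
        lib.define(
            "copy_(Tensor(a!) dst, Tensor src) -> ()",
            tags=(torch.Tag.needs_fixed_stride_order,),
        )
        # Alternatively, tags=(torch.Tag.flexible_layout,) declares that the
        # op accepts any input layout, letting Inductor choose strides freely.

This commit concerns the default that applies when neither tag is given.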
@@ -11629,6 +11629,51 @@ class CommonTemplate:
         net = torch.compile(model)
         out = net(input_t)
 
+    @skip_if_cpp_wrapper(
+        "Without major redesign, cpp_wrapper will not support custom ops that are "
+        "defined in Python."
+    )
+    @config.patch(implicit_fallbacks=True)
+    def test_custom_op_default_layout_constraint(self):
+        with torch.library._scoped_library("mylib", "DEF") as lib:
+            lib.define(
+                "copy_(Tensor(a!) dst, Tensor src) -> ()",
+                # No need to pass in an explicit tag since the default
+                # behavior for custom op works.
+                # tags=torch.Tag.needs_fixed_stride_order,
+            )
+
+            @torch.library.impl(lib, "copy_", "Meta")
+            def _(dst, src):
+                return None
+
+            @torch.library.impl(lib, "copy_", "CompositeExplicitAutograd")
+            def _(dst, src):
+                if src.is_contiguous():
+                    dst.copy_(src + 1)
+                else:
+                    dst.copy_(src)
+
+            def f(x):
+                full_default_3 = torch.full([3, 3], 7.0, device=self.device)
+                chunk_cat_default_1 = torch.ops.mylib.copy_.default(full_default_3, x)
+                mul_out = torch.mul(full_default_3, full_default_3)
+                return mul_out
+
+            x = (
+                torch.arange(9, dtype=torch.float, device=self.device)
+                .view(3, 3)
+                .t()
+                .contiguous()
+                .t()
+            )
+            eager_out = f(x)
+
+            compiled_inductor_f = torch.compile(f, backend="inductor", fullgraph=True)
+            compiled_inductor_out = compiled_inductor_f(x)
+
+            self.assertTrue(torch.allclose(compiled_inductor_out, eager_out))
+
     @skip_if_gpu_halide  # cuda error
     def test_buffer_use_after_remove(self):
         # https://github.com/pytorch/pytorch/issues/102857
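Why the test builds x this way: .view(3, 3) yields a contiguous tensor, the first .t() makes it a non-contiguous transposed view, .contiguous() materializes that layout, and the final .t() leaves a non-contiguous tensor with strides (1, 3). In eager mode the CompositeExplicitAutograd kernel therefore takes the plain dst.copy_(src) branch; if Inductor forced the input contiguous, the compiled graph would take the src + 1 branch and diverge from eager_out. A standalone sketch of the stride check (CPU, outside the test harness):

    import torch

    x = torch.arange(9, dtype=torch.float).view(3, 3).t().contiguous().t()
    print(x.stride())         # (1, 3): column-major, not the default (3, 1)
    print(x.is_contiguous())  # False, so the eager kernel skips the "+ 1" branch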
@@ -1141,14 +1141,27 @@ class GraphLowering(torch.fx.Interpreter):
                     error.operator_str(target, args, kwargs),
                 )
 
-            decided_constraint = require_contiguous
-
             # use contiguous unless the (custom) op asks something else
             # explicitly
             if torch._C.Tag.needs_fixed_stride_order in target.tags:
                 decided_constraint = constrain_to_fx_strides  # type: ignore[assignment]
             elif torch._C.Tag.flexible_layout in target.tags:
                 decided_constraint = None  # type: ignore[assignment]
+            else:
+                # If there are no tags, we do different things depending on
+                # if it's a builtin ATen/prim ops or custom ops.
+                # For ATen ops, we require_contiguous to fix https://github.com/pytorch/pytorch/issues/140452
+                # For custom ops, we constrain_to_fx_strides to maintain the
+                # behavior of PyTorch 2.5: https://github.com/pytorch/pytorch/issues/148356
+                #
+                # For ATen ops, only apply the constraint for backward
+                # ops since fwd ops should work for any strides.
+                if torch._library.utils.is_builtin(target) and self.is_backward:
+                    decided_constraint = require_contiguous  # type: ignore[assignment]
+                else:
+                    # maybe_layout_constraints will decide the layout constraint for the custom op
+                    # lazily
+                    decided_constraint = None  # type: ignore[assignment]
 
             # for implicitly fallback ops, we conservatively requires
             # contiguous input since some eager kernels does not
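The builtin-vs-custom split above relies on torch._library.utils.is_builtin, which checks whether an OpOverload lives in a builtin namespace such as aten or prims. A small illustration, assuming a recent PyTorch build:

    import torch
    from torch._library.utils import is_builtin

    # A builtin ATen op: in backward graphs, Inductor applies require_contiguous.
    print(is_builtin(torch.ops.aten.mul.Tensor))  # True

    # A custom op such as torch.ops.mylib.copy_ from the test above would
    # report False, deferring the decision to maybe_layout_constraints.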