Revert "Refactor layout constraint selection logic (#148104)"

This reverts commit 2e7c9d33e7.

Reverted https://github.com/pytorch/pytorch/pull/148104 on behalf of https://github.com/atalman due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/14357056427/job/40251630946) [HUD commit link](2e7c9d33e7) ([comment](https://github.com/pytorch/pytorch/pull/148104#issuecomment-2790369493))
This commit is contained in:
PyTorch MergeBot 2025-04-09 16:49:48 +00:00
parent a0e796df03
commit 01568cb17a
4 changed files with 50 additions and 57 deletions

View File

@ -126,7 +126,7 @@ sleep_sec_TESTING_ONLY: Optional[int] = None
# If the custom op does not have a layout constraint tag already
# then we assume the following applies.
custom_op_default_layout_constraint: Literal[
"needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"
"needs_fixed_stride_order", "flexible_layout"
] = "needs_fixed_stride_order"
# The default layout constraint for user-defined triton kernels.

View File

@ -80,13 +80,11 @@ from .lowering import (
FALLBACK_ALLOW_LIST,
fallback_handler,
fallback_node_due_to_unsupported_type,
get_layout_constraint_tag,
lowerings,
make_fallback,
maybe_layout_constraints,
needs_realized_inputs,
require_contiguous,
tag_to_layout_constraint,
unsupported_output_tensor,
)
from .runtime import autotune_cache
@ -246,14 +244,6 @@ def mark_nodes_dislike_padding(
cur.meta["dislike_padding"] = True
continue
if (
isinstance(cur.target, torch._ops.OpOverload)
and get_layout_constraint_tag(cur.target)
== torch._C.Tag.needs_exact_strides
):
cur.meta["dislike_padding"] = True
continue
op = _get_overload_packet(cur)
if not op:
continue
@ -1160,26 +1150,34 @@ class GraphLowering(torch.fx.Interpreter):
error.operator_str(target, args, kwargs),
)
tag = get_layout_constraint_tag(target, with_default=False)
if (
tag is None
and torch._library.utils.is_builtin(target)
and self.is_backward
):
# for implicit fallback ATen ops during backward, if there
# is no layout constraint tag, we conservatively require contiguous
# input since some eager kernels do not
# support non-contiguous inputs. Otherwise they may silently cause
# accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
# We only do this For ATen ops and for backward.
# use contiguous unless the (custom) op asks something else
# explicitly
if torch._C.Tag.needs_exact_strides in target.tags:
decided_constraint = constrain_to_fake_tensors # type: ignore[assignment]
elif torch._C.Tag.needs_fixed_stride_order in target.tags:
decided_constraint = constrain_to_fx_strides # type: ignore[assignment]
elif torch._C.Tag.flexible_layout in target.tags:
decided_constraint = None # type: ignore[assignment]
else:
# If there are no tags, we do different things depending on
# if it's a builtin ATen/prim ops or custom ops.
# For ATen ops, we require_contiguous to fix https://github.com/pytorch/pytorch/issues/140452
# For custom ops, we constrain_to_fx_strides to maintain the
# behavior of PyTorch 2.5: https://github.com/pytorch/pytorch/issues/148356
#
# TODO: should really switch to "needs_fixed_stride" constraint on these
# and identify them one by one.
# For ATen ops, only apply the constraint for backward
# ops since fwd ops should work for any strides.
if torch._library.utils.is_builtin(target) and self.is_backward:
decided_constraint = require_contiguous # type: ignore[assignment]
else:
tag = get_layout_constraint_tag(target, with_default=True)
decided_constraint = tag_to_layout_constraint(tag)
# maybe_layout_constraints will decide the layout constraint for the custom op
# lazily
decided_constraint = None # type: ignore[assignment]
# for implicitly fallback ops, we conservatively requires
# contiguous input since some eager kernels does not
# support non-contiguous inputs. They may silently cause
# accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
make_fallback(target, layout_constraint=decided_constraint)
elif get_decompositions([target]):

View File

@ -157,40 +157,37 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A
return None
if fn in _maybe_layout_constraints:
return _maybe_layout_constraints[fn]
# OpOverload with custom lowerings override tag-based layout constraints
if fn in lowerings:
_maybe_layout_constraints[fn] = None
return None
# We lazily register tag-based layout constraints.
def handle_layout_constraint_tag(tag):
if tag is torch._C.Tag.needs_fixed_stride_order:
_maybe_layout_constraints[fn] = constrain_to_fx_strides
return _maybe_layout_constraints[fn]
elif tag is torch._C.Tag.flexible_layout:
_maybe_layout_constraints[fn] = None
return None
else:
raise AssertionError(f"Unknown layout constraint tag: {tag}")
tag = get_layout_constraint_tag(fn)
return handle_layout_constraint_tag(tag)
tags_by_priority = [
torch._C.Tag.needs_exact_strides,
torch._C.Tag.needs_fixed_stride_order,
torch._C.Tag.flexible_layout,
]
def get_layout_constraint_tag(fn, *, with_default=True):
def get_layout_constraint_tag(fn):
tags_by_priority = [
torch._C.Tag.needs_exact_strides,
torch._C.Tag.needs_fixed_stride_order,
torch._C.Tag.flexible_layout,
]
for tag in tags_by_priority:
if tag in fn.tags:
return tag
if with_default:
if torch._library.utils.is_builtin(fn):
return torch._C.Tag.flexible_layout
return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
return None
def tag_to_layout_constraint(tag):
if tag == torch._C.Tag.needs_exact_strides:
return constrain_to_fake_tensors
if tag == torch._C.Tag.needs_fixed_stride_order:
return constrain_to_fx_strides
if tag == torch._C.Tag.flexible_layout:
return None
raise AssertionError(f"Unknown layout constraint tag: {tag}")
def assert_nyi(cond, msg):

View File

@ -1169,9 +1169,7 @@ def _should_save_eager_input_vals(
f"propagate the FakeTensor vals. Please file an issue."
)
if isinstance(target, torch._ops.OpOverload):
from torch._inductor.lowering import get_layout_constraint_tag
return get_layout_constraint_tag(target) == torch._C.Tag.needs_exact_strides
return torch._C.Tag.needs_exact_strides in target.tags
return False