Revert "Refactor layout constraint selection logic (#148104)"

This reverts commit 2e7c9d33e7.

Reverted https://github.com/pytorch/pytorch/pull/148104 on behalf of https://github.com/atalman due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/14357056427/job/40251630946) [HUD commit link](2e7c9d33e7) ([comment](https://github.com/pytorch/pytorch/pull/148104#issuecomment-2790369493))
PyTorch MergeBot 2025-04-09 16:49:48 +00:00
parent a0e796df03
commit 01568cb17a
4 changed files with 50 additions and 57 deletions

torch/_inductor/config.py

@@ -126,7 +126,7 @@ sleep_sec_TESTING_ONLY: Optional[int] = None
 # If the custom op does not have a layout constraint tag already
 # then we assume the following applies.
 custom_op_default_layout_constraint: Literal[
-    "needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"
+    "needs_fixed_stride_order", "flexible_layout"
 ] = "needs_fixed_stride_order"
 # The default layout constraint for user-defined triton kernels.

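For context, a minimal sketch (not part of this diff) of what the knob above controls: a custom op that carries no layout-constraint tag falls back to `config.custom_op_default_layout_constraint`, which this revert narrows back to the two pre-#148104 values. The `mylib::double_it` op below is hypothetical.

```python
import torch
from torch._inductor import config


# Hypothetical custom op that carries no layout constraint tag.
@torch.library.custom_op("mylib::double_it", mutates_args=())
def double_it(x: torch.Tensor) -> torch.Tensor:
    return x * 2


@double_it.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


op = torch.ops.mylib.double_it.default
# Untagged, so inductor consults the config default instead.
assert torch._C.Tag.needs_fixed_stride_order not in op.tags
print(config.custom_op_default_layout_constraint)  # "needs_fixed_stride_order"
```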
torch/_inductor/graph.py

@@ -80,13 +80,11 @@ from .lowering import (
     FALLBACK_ALLOW_LIST,
     fallback_handler,
     fallback_node_due_to_unsupported_type,
-    get_layout_constraint_tag,
     lowerings,
     make_fallback,
     maybe_layout_constraints,
     needs_realized_inputs,
     require_contiguous,
-    tag_to_layout_constraint,
     unsupported_output_tensor,
 )
 from .runtime import autotune_cache
@@ -246,14 +244,6 @@ def mark_nodes_dislike_padding(
             cur.meta["dislike_padding"] = True
             continue
-        if (
-            isinstance(cur.target, torch._ops.OpOverload)
-            and get_layout_constraint_tag(cur.target)
-            == torch._C.Tag.needs_exact_strides
-        ):
-            cur.meta["dislike_padding"] = True
-            continue
         op = _get_overload_packet(cur)
         if not op:
             continue
@@ -1160,26 +1150,34 @@ class GraphLowering(torch.fx.Interpreter):
                    error.operator_str(target, args, kwargs),
                )
-            tag = get_layout_constraint_tag(target, with_default=False)
-            if (
-                tag is None
-                and torch._library.utils.is_builtin(target)
-                and self.is_backward
-            ):
-                # for implicit fallback ATen ops during backward, if there
-                # is no layout constraint tag, we conservatively require contiguous
-                # input since some eager kernels do not
-                # support non-contiguous inputs. Otherwise they may silently cause
-                # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
-                # We only do this for ATen ops and for backward.
-                #
-                # TODO: should really switch to "needs_fixed_stride" constraint on these
-                # and identify them one by one.
-                decided_constraint = require_contiguous  # type: ignore[assignment]
+            # use contiguous unless the (custom) op explicitly asks for
+            # something else
+            if torch._C.Tag.needs_exact_strides in target.tags:
+                decided_constraint = constrain_to_fake_tensors  # type: ignore[assignment]
+            elif torch._C.Tag.needs_fixed_stride_order in target.tags:
+                decided_constraint = constrain_to_fx_strides  # type: ignore[assignment]
+            elif torch._C.Tag.flexible_layout in target.tags:
+                decided_constraint = None  # type: ignore[assignment]
             else:
-                tag = get_layout_constraint_tag(target, with_default=True)
-                decided_constraint = tag_to_layout_constraint(tag)
+                # If there are no tags, we do different things depending on
+                # whether it is a builtin ATen/prim op or a custom op.
+                # For ATen ops, we require_contiguous to fix https://github.com/pytorch/pytorch/issues/140452
+                # For custom ops, we constrain_to_fx_strides to maintain the
+                # behavior of PyTorch 2.5: https://github.com/pytorch/pytorch/issues/148356
+                #
+                # For ATen ops, only apply the constraint for backward
+                # ops since fwd ops should work for any strides.
+                if torch._library.utils.is_builtin(target) and self.is_backward:
+                    decided_constraint = require_contiguous  # type: ignore[assignment]
+                else:
+                    # maybe_layout_constraints will decide the layout
+                    # constraint for the custom op lazily
+                    decided_constraint = None  # type: ignore[assignment]
+
+            # for implicit fallback ops, we conservatively require contiguous
+            # input since some eager kernels do not support non-contiguous
+            # inputs. They may silently cause accuracy problems. Check
+            # https://github.com/pytorch/pytorch/issues/140452
            make_fallback(target, layout_constraint=decided_constraint)
        elif get_decompositions([target]):

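To make the restored control flow above easier to scan, here is a hedged, self-contained mirror of the same decision order, with strings standing in for the real constraint functions (`constrain_to_fake_tensors` and friends); it is a sketch, not the inductor code path itself.

```python
import torch
import torch._library.utils


def pick_layout_constraint(target: torch._ops.OpOverload, is_backward: bool):
    # Explicit tags on the op win, in the same priority order as the diff.
    if torch._C.Tag.needs_exact_strides in target.tags:
        return "constrain_to_fake_tensors"
    if torch._C.Tag.needs_fixed_stride_order in target.tags:
        return "constrain_to_fx_strides"
    if torch._C.Tag.flexible_layout in target.tags:
        return None
    # Untagged: ATen builtins get the conservative contiguous constraint
    # only in backward; custom ops are resolved lazily elsewhere.
    if torch._library.utils.is_builtin(target) and is_backward:
        return "require_contiguous"
    return None


# An untagged ATen op encountered during backward:
print(pick_layout_constraint(torch.ops.aten.mm.default, is_backward=True))
```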
torch/_inductor/lowering.py

@@ -157,40 +157,37 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]:
         return None
     if fn in _maybe_layout_constraints:
         return _maybe_layout_constraints[fn]
-    return None
+    # OpOverload with custom lowerings override tag-based layout constraints
+    if fn in lowerings:
+        _maybe_layout_constraints[fn] = None
+        return None
+
+    # We lazily register tag-based layout constraints.
+    def handle_layout_constraint_tag(tag):
+        if tag is torch._C.Tag.needs_fixed_stride_order:
+            _maybe_layout_constraints[fn] = constrain_to_fx_strides
+            return _maybe_layout_constraints[fn]
+        elif tag is torch._C.Tag.flexible_layout:
+            _maybe_layout_constraints[fn] = None
+            return None
+        else:
+            raise AssertionError(f"Unknown layout constraint tag: {tag}")
+
+    tag = get_layout_constraint_tag(fn)
+    return handle_layout_constraint_tag(tag)


-tags_by_priority = [
-    torch._C.Tag.needs_exact_strides,
-    torch._C.Tag.needs_fixed_stride_order,
-    torch._C.Tag.flexible_layout,
-]
-
-
-def get_layout_constraint_tag(fn, *, with_default=True):
+def get_layout_constraint_tag(fn):
     tags_by_priority = [
-        torch._C.Tag.needs_exact_strides,
         torch._C.Tag.needs_fixed_stride_order,
         torch._C.Tag.flexible_layout,
     ]
     for tag in tags_by_priority:
         if tag in fn.tags:
             return tag
-    if with_default:
-        if torch._library.utils.is_builtin(fn):
-            return torch._C.Tag.flexible_layout
-        return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
-    return None
-
-
-def tag_to_layout_constraint(tag):
-    if tag == torch._C.Tag.needs_exact_strides:
-        return constrain_to_fake_tensors
-    if tag == torch._C.Tag.needs_fixed_stride_order:
-        return constrain_to_fx_strides
-    if tag == torch._C.Tag.flexible_layout:
-        return None
-    raise AssertionError(f"Unknown layout constraint tag: {tag}")
+    if torch._library.utils.is_builtin(fn):
+        return torch._C.Tag.flexible_layout
+    return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)


 def assert_nyi(cond, msg):

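A short, hedged usage sketch of the restored lazy resolution: ops that already have an inductor lowering are pinned to `None` (unconstrained) on first query, while a tagged custom op would be registered with `constrain_to_fx_strides` instead. This assumes the post-revert `torch/_inductor/lowering.py` shown above.

```python
import torch
from torch._inductor.lowering import lowerings, maybe_layout_constraints

op = torch.ops.aten.add.Tensor
# add.Tensor has a registered inductor lowering, so the
# constraint caches to None (unconstrained).
print(op in lowerings)               # True
print(maybe_layout_constraints(op))  # None
```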
torch/_library/utils.py

@@ -1169,9 +1169,7 @@ def _should_save_eager_input_vals(
             f"propagate the FakeTensor vals. Please file an issue."
         )
     if isinstance(target, torch._ops.OpOverload):
-        from torch._inductor.lowering import get_layout_constraint_tag
-
-        return get_layout_constraint_tag(target) == torch._C.Tag.needs_exact_strides
+        return torch._C.Tag.needs_exact_strides in target.tags
     return False
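The reverted-to check reduces to plain tag membership; a hedged standalone restatement:

```python
import torch


def wants_exact_strides(target: object) -> bool:
    # Mirrors the restored branch: tag membership on the OpOverload
    # replaces the round-trip through get_layout_constraint_tag.
    if isinstance(target, torch._ops.OpOverload):
        return torch._C.Tag.needs_exact_strides in target.tags
    return False


print(wants_exact_strides(torch.ops.aten.mm.default))  # False: no layout tag
```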