mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "Refactor layout constraint selection logic (#148104)"
This reverts commit2e7c9d33e7. Reverted https://github.com/pytorch/pytorch/pull/148104 on behalf of https://github.com/atalman due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/14357056427/job/40251630946) [HUD commit link](2e7c9d33e7) ([comment](https://github.com/pytorch/pytorch/pull/148104#issuecomment-2790369493))
This commit is contained in:
parent
a0e796df03
commit
01568cb17a
|
|
@ -126,7 +126,7 @@ sleep_sec_TESTING_ONLY: Optional[int] = None
|
|||
# If the custom op does not have a layout constraint tag already
|
||||
# then we assume the following applies.
|
||||
custom_op_default_layout_constraint: Literal[
|
||||
"needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"
|
||||
"needs_fixed_stride_order", "flexible_layout"
|
||||
] = "needs_fixed_stride_order"
|
||||
|
||||
# The default layout constraint for user-defined triton kernels.
|
||||
|
|
|
|||
|
|
@ -80,13 +80,11 @@ from .lowering import (
|
|||
FALLBACK_ALLOW_LIST,
|
||||
fallback_handler,
|
||||
fallback_node_due_to_unsupported_type,
|
||||
get_layout_constraint_tag,
|
||||
lowerings,
|
||||
make_fallback,
|
||||
maybe_layout_constraints,
|
||||
needs_realized_inputs,
|
||||
require_contiguous,
|
||||
tag_to_layout_constraint,
|
||||
unsupported_output_tensor,
|
||||
)
|
||||
from .runtime import autotune_cache
|
||||
|
|
@ -246,14 +244,6 @@ def mark_nodes_dislike_padding(
|
|||
cur.meta["dislike_padding"] = True
|
||||
continue
|
||||
|
||||
if (
|
||||
isinstance(cur.target, torch._ops.OpOverload)
|
||||
and get_layout_constraint_tag(cur.target)
|
||||
== torch._C.Tag.needs_exact_strides
|
||||
):
|
||||
cur.meta["dislike_padding"] = True
|
||||
continue
|
||||
|
||||
op = _get_overload_packet(cur)
|
||||
if not op:
|
||||
continue
|
||||
|
|
@ -1160,26 +1150,34 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
error.operator_str(target, args, kwargs),
|
||||
)
|
||||
|
||||
tag = get_layout_constraint_tag(target, with_default=False)
|
||||
if (
|
||||
tag is None
|
||||
and torch._library.utils.is_builtin(target)
|
||||
and self.is_backward
|
||||
):
|
||||
# for implicit fallback ATen ops during backward, if there
|
||||
# is no layout constraint tag, we conservatively require contiguous
|
||||
# input since some eager kernels do not
|
||||
# support non-contiguous inputs. Otherwise they may silently cause
|
||||
# accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
|
||||
# We only do this For ATen ops and for backward.
|
||||
#
|
||||
# TODO: should really switch to "needs_fixed_stride" constraint on these
|
||||
# and identify them one by one.
|
||||
decided_constraint = require_contiguous # type: ignore[assignment]
|
||||
# use contiguous unless the (custom) op asks something else
|
||||
# explicitly
|
||||
if torch._C.Tag.needs_exact_strides in target.tags:
|
||||
decided_constraint = constrain_to_fake_tensors # type: ignore[assignment]
|
||||
elif torch._C.Tag.needs_fixed_stride_order in target.tags:
|
||||
decided_constraint = constrain_to_fx_strides # type: ignore[assignment]
|
||||
elif torch._C.Tag.flexible_layout in target.tags:
|
||||
decided_constraint = None # type: ignore[assignment]
|
||||
else:
|
||||
tag = get_layout_constraint_tag(target, with_default=True)
|
||||
decided_constraint = tag_to_layout_constraint(tag)
|
||||
# If there are no tags, we do different things depending on
|
||||
# if it's a builtin ATen/prim ops or custom ops.
|
||||
# For ATen ops, we require_contiguous to fix https://github.com/pytorch/pytorch/issues/140452
|
||||
# For custom ops, we constrain_to_fx_strides to maintain the
|
||||
# behavior of PyTorch 2.5: https://github.com/pytorch/pytorch/issues/148356
|
||||
#
|
||||
# For ATen ops, only apply the constraint for backward
|
||||
# ops since fwd ops should work for any strides.
|
||||
if torch._library.utils.is_builtin(target) and self.is_backward:
|
||||
decided_constraint = require_contiguous # type: ignore[assignment]
|
||||
else:
|
||||
# maybe_layout_constraints will decide the layout constraint for the custom op
|
||||
# lazily
|
||||
decided_constraint = None # type: ignore[assignment]
|
||||
|
||||
# for implicitly fallback ops, we conservatively requires
|
||||
# contiguous input since some eager kernels does not
|
||||
# support non-contiguous inputs. They may silently cause
|
||||
# accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
|
||||
make_fallback(target, layout_constraint=decided_constraint)
|
||||
|
||||
elif get_decompositions([target]):
|
||||
|
|
|
|||
|
|
@ -157,40 +157,37 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A
|
|||
return None
|
||||
if fn in _maybe_layout_constraints:
|
||||
return _maybe_layout_constraints[fn]
|
||||
return None
|
||||
# OpOverload with custom lowerings override tag-based layout constraints
|
||||
if fn in lowerings:
|
||||
_maybe_layout_constraints[fn] = None
|
||||
return None
|
||||
# We lazily register tag-based layout constraints.
|
||||
|
||||
def handle_layout_constraint_tag(tag):
|
||||
if tag is torch._C.Tag.needs_fixed_stride_order:
|
||||
_maybe_layout_constraints[fn] = constrain_to_fx_strides
|
||||
return _maybe_layout_constraints[fn]
|
||||
elif tag is torch._C.Tag.flexible_layout:
|
||||
_maybe_layout_constraints[fn] = None
|
||||
return None
|
||||
else:
|
||||
raise AssertionError(f"Unknown layout constraint tag: {tag}")
|
||||
|
||||
tag = get_layout_constraint_tag(fn)
|
||||
return handle_layout_constraint_tag(tag)
|
||||
|
||||
|
||||
tags_by_priority = [
|
||||
torch._C.Tag.needs_exact_strides,
|
||||
torch._C.Tag.needs_fixed_stride_order,
|
||||
torch._C.Tag.flexible_layout,
|
||||
]
|
||||
|
||||
|
||||
def get_layout_constraint_tag(fn, *, with_default=True):
|
||||
def get_layout_constraint_tag(fn):
|
||||
tags_by_priority = [
|
||||
torch._C.Tag.needs_exact_strides,
|
||||
torch._C.Tag.needs_fixed_stride_order,
|
||||
torch._C.Tag.flexible_layout,
|
||||
]
|
||||
for tag in tags_by_priority:
|
||||
if tag in fn.tags:
|
||||
return tag
|
||||
if with_default:
|
||||
if torch._library.utils.is_builtin(fn):
|
||||
return torch._C.Tag.flexible_layout
|
||||
return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
|
||||
return None
|
||||
|
||||
|
||||
def tag_to_layout_constraint(tag):
|
||||
if tag == torch._C.Tag.needs_exact_strides:
|
||||
return constrain_to_fake_tensors
|
||||
if tag == torch._C.Tag.needs_fixed_stride_order:
|
||||
return constrain_to_fx_strides
|
||||
if tag == torch._C.Tag.flexible_layout:
|
||||
return None
|
||||
raise AssertionError(f"Unknown layout constraint tag: {tag}")
|
||||
if torch._library.utils.is_builtin(fn):
|
||||
return torch._C.Tag.flexible_layout
|
||||
return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
|
||||
|
||||
|
||||
def assert_nyi(cond, msg):
|
||||
|
|
|
|||
|
|
@ -1169,9 +1169,7 @@ def _should_save_eager_input_vals(
|
|||
f"propagate the FakeTensor vals. Please file an issue."
|
||||
)
|
||||
if isinstance(target, torch._ops.OpOverload):
|
||||
from torch._inductor.lowering import get_layout_constraint_tag
|
||||
|
||||
return get_layout_constraint_tag(target) == torch._C.Tag.needs_exact_strides
|
||||
return torch._C.Tag.needs_exact_strides in target.tags
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user