mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "Refactor layout constraint selection logic (#148104)"
This reverts commit2e7c9d33e7. Reverted https://github.com/pytorch/pytorch/pull/148104 on behalf of https://github.com/atalman due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/14357056427/job/40251630946) [HUD commit link](2e7c9d33e7) ([comment](https://github.com/pytorch/pytorch/pull/148104#issuecomment-2790369493))
This commit is contained in:
parent
a0e796df03
commit
01568cb17a
|
|
@ -126,7 +126,7 @@ sleep_sec_TESTING_ONLY: Optional[int] = None
|
||||||
# If the custom op does not have a layout constraint tag already
|
# If the custom op does not have a layout constraint tag already
|
||||||
# then we assume the following applies.
|
# then we assume the following applies.
|
||||||
custom_op_default_layout_constraint: Literal[
|
custom_op_default_layout_constraint: Literal[
|
||||||
"needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"
|
"needs_fixed_stride_order", "flexible_layout"
|
||||||
] = "needs_fixed_stride_order"
|
] = "needs_fixed_stride_order"
|
||||||
|
|
||||||
# The default layout constraint for user-defined triton kernels.
|
# The default layout constraint for user-defined triton kernels.
|
||||||
|
|
|
||||||
|
|
@ -80,13 +80,11 @@ from .lowering import (
|
||||||
FALLBACK_ALLOW_LIST,
|
FALLBACK_ALLOW_LIST,
|
||||||
fallback_handler,
|
fallback_handler,
|
||||||
fallback_node_due_to_unsupported_type,
|
fallback_node_due_to_unsupported_type,
|
||||||
get_layout_constraint_tag,
|
|
||||||
lowerings,
|
lowerings,
|
||||||
make_fallback,
|
make_fallback,
|
||||||
maybe_layout_constraints,
|
maybe_layout_constraints,
|
||||||
needs_realized_inputs,
|
needs_realized_inputs,
|
||||||
require_contiguous,
|
require_contiguous,
|
||||||
tag_to_layout_constraint,
|
|
||||||
unsupported_output_tensor,
|
unsupported_output_tensor,
|
||||||
)
|
)
|
||||||
from .runtime import autotune_cache
|
from .runtime import autotune_cache
|
||||||
|
|
@ -246,14 +244,6 @@ def mark_nodes_dislike_padding(
|
||||||
cur.meta["dislike_padding"] = True
|
cur.meta["dislike_padding"] = True
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (
|
|
||||||
isinstance(cur.target, torch._ops.OpOverload)
|
|
||||||
and get_layout_constraint_tag(cur.target)
|
|
||||||
== torch._C.Tag.needs_exact_strides
|
|
||||||
):
|
|
||||||
cur.meta["dislike_padding"] = True
|
|
||||||
continue
|
|
||||||
|
|
||||||
op = _get_overload_packet(cur)
|
op = _get_overload_packet(cur)
|
||||||
if not op:
|
if not op:
|
||||||
continue
|
continue
|
||||||
|
|
@ -1160,26 +1150,34 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
error.operator_str(target, args, kwargs),
|
error.operator_str(target, args, kwargs),
|
||||||
)
|
)
|
||||||
|
|
||||||
tag = get_layout_constraint_tag(target, with_default=False)
|
# use contiguous unless the (custom) op asks something else
|
||||||
if (
|
# explicitly
|
||||||
tag is None
|
if torch._C.Tag.needs_exact_strides in target.tags:
|
||||||
and torch._library.utils.is_builtin(target)
|
decided_constraint = constrain_to_fake_tensors # type: ignore[assignment]
|
||||||
and self.is_backward
|
elif torch._C.Tag.needs_fixed_stride_order in target.tags:
|
||||||
):
|
decided_constraint = constrain_to_fx_strides # type: ignore[assignment]
|
||||||
# for implicit fallback ATen ops during backward, if there
|
elif torch._C.Tag.flexible_layout in target.tags:
|
||||||
# is no layout constraint tag, we conservatively require contiguous
|
decided_constraint = None # type: ignore[assignment]
|
||||||
# input since some eager kernels do not
|
|
||||||
# support non-contiguous inputs. Otherwise they may silently cause
|
|
||||||
# accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
|
|
||||||
# We only do this For ATen ops and for backward.
|
|
||||||
#
|
|
||||||
# TODO: should really switch to "needs_fixed_stride" constraint on these
|
|
||||||
# and identify them one by one.
|
|
||||||
decided_constraint = require_contiguous # type: ignore[assignment]
|
|
||||||
else:
|
else:
|
||||||
tag = get_layout_constraint_tag(target, with_default=True)
|
# If there are no tags, we do different things depending on
|
||||||
decided_constraint = tag_to_layout_constraint(tag)
|
# if it's a builtin ATen/prim ops or custom ops.
|
||||||
|
# For ATen ops, we require_contiguous to fix https://github.com/pytorch/pytorch/issues/140452
|
||||||
|
# For custom ops, we constrain_to_fx_strides to maintain the
|
||||||
|
# behavior of PyTorch 2.5: https://github.com/pytorch/pytorch/issues/148356
|
||||||
|
#
|
||||||
|
# For ATen ops, only apply the constraint for backward
|
||||||
|
# ops since fwd ops should work for any strides.
|
||||||
|
if torch._library.utils.is_builtin(target) and self.is_backward:
|
||||||
|
decided_constraint = require_contiguous # type: ignore[assignment]
|
||||||
|
else:
|
||||||
|
# maybe_layout_constraints will decide the layout constraint for the custom op
|
||||||
|
# lazily
|
||||||
|
decided_constraint = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
# for implicitly fallback ops, we conservatively requires
|
||||||
|
# contiguous input since some eager kernels does not
|
||||||
|
# support non-contiguous inputs. They may silently cause
|
||||||
|
# accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
|
||||||
make_fallback(target, layout_constraint=decided_constraint)
|
make_fallback(target, layout_constraint=decided_constraint)
|
||||||
|
|
||||||
elif get_decompositions([target]):
|
elif get_decompositions([target]):
|
||||||
|
|
|
||||||
|
|
@ -157,40 +157,37 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A
|
||||||
return None
|
return None
|
||||||
if fn in _maybe_layout_constraints:
|
if fn in _maybe_layout_constraints:
|
||||||
return _maybe_layout_constraints[fn]
|
return _maybe_layout_constraints[fn]
|
||||||
return None
|
# OpOverload with custom lowerings override tag-based layout constraints
|
||||||
|
if fn in lowerings:
|
||||||
|
_maybe_layout_constraints[fn] = None
|
||||||
|
return None
|
||||||
|
# We lazily register tag-based layout constraints.
|
||||||
|
|
||||||
|
def handle_layout_constraint_tag(tag):
|
||||||
|
if tag is torch._C.Tag.needs_fixed_stride_order:
|
||||||
|
_maybe_layout_constraints[fn] = constrain_to_fx_strides
|
||||||
|
return _maybe_layout_constraints[fn]
|
||||||
|
elif tag is torch._C.Tag.flexible_layout:
|
||||||
|
_maybe_layout_constraints[fn] = None
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise AssertionError(f"Unknown layout constraint tag: {tag}")
|
||||||
|
|
||||||
|
tag = get_layout_constraint_tag(fn)
|
||||||
|
return handle_layout_constraint_tag(tag)
|
||||||
|
|
||||||
|
|
||||||
tags_by_priority = [
|
def get_layout_constraint_tag(fn):
|
||||||
torch._C.Tag.needs_exact_strides,
|
|
||||||
torch._C.Tag.needs_fixed_stride_order,
|
|
||||||
torch._C.Tag.flexible_layout,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_layout_constraint_tag(fn, *, with_default=True):
|
|
||||||
tags_by_priority = [
|
tags_by_priority = [
|
||||||
torch._C.Tag.needs_exact_strides,
|
|
||||||
torch._C.Tag.needs_fixed_stride_order,
|
torch._C.Tag.needs_fixed_stride_order,
|
||||||
torch._C.Tag.flexible_layout,
|
torch._C.Tag.flexible_layout,
|
||||||
]
|
]
|
||||||
for tag in tags_by_priority:
|
for tag in tags_by_priority:
|
||||||
if tag in fn.tags:
|
if tag in fn.tags:
|
||||||
return tag
|
return tag
|
||||||
if with_default:
|
if torch._library.utils.is_builtin(fn):
|
||||||
if torch._library.utils.is_builtin(fn):
|
return torch._C.Tag.flexible_layout
|
||||||
return torch._C.Tag.flexible_layout
|
return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
|
||||||
return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def tag_to_layout_constraint(tag):
|
|
||||||
if tag == torch._C.Tag.needs_exact_strides:
|
|
||||||
return constrain_to_fake_tensors
|
|
||||||
if tag == torch._C.Tag.needs_fixed_stride_order:
|
|
||||||
return constrain_to_fx_strides
|
|
||||||
if tag == torch._C.Tag.flexible_layout:
|
|
||||||
return None
|
|
||||||
raise AssertionError(f"Unknown layout constraint tag: {tag}")
|
|
||||||
|
|
||||||
|
|
||||||
def assert_nyi(cond, msg):
|
def assert_nyi(cond, msg):
|
||||||
|
|
|
||||||
|
|
@ -1169,9 +1169,7 @@ def _should_save_eager_input_vals(
|
||||||
f"propagate the FakeTensor vals. Please file an issue."
|
f"propagate the FakeTensor vals. Please file an issue."
|
||||||
)
|
)
|
||||||
if isinstance(target, torch._ops.OpOverload):
|
if isinstance(target, torch._ops.OpOverload):
|
||||||
from torch._inductor.lowering import get_layout_constraint_tag
|
return torch._C.Tag.needs_exact_strides in target.tags
|
||||||
|
|
||||||
return get_layout_constraint_tag(target) == torch._C.Tag.needs_exact_strides
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user