Revert "python definitely_contiguous-> is_contiguous_or_false (#156515)"

This reverts commit 4c0091fda6.

Reverted https://github.com/pytorch/pytorch/pull/156515 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to cause some torch.export failures internally ([comment](https://github.com/pytorch/pytorch/pull/156515#issuecomment-3014104570))
PyTorch MergeBot 2025-06-27 19:07:06 +00:00
parent 2860f5c4f5
commit 75a7d9e868
6 changed files with 31 additions and 38 deletions


@@ -3337,7 +3337,7 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
)
with ctx():
# This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3].
-# but not anymore since we use contiguous_or_false .
+# but not anymore since we use definitely_contiguous .
# We need a way to mark strides unbacked to avoid the recompilation here.
x = torch.randn(10, 10)
torch._dynamo.decorators.mark_unbacked(x, 0)
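
For context on the comment above, a minimal sketch of the scenario it describes; the compiled function is illustrative, not the actual test body, and assumes a recent torch.compile build:

import torch

@torch.compile(fullgraph=True)
def f(x):
    # flattening consults contiguity; with dim 0 unbacked, a non-size-oblivious
    # check could try to guard on a data-dependent expression such as Eq(10, u3)
    return x.reshape(-1) + 1

x = torch.randn(10, 10)
torch._dynamo.decorators.mark_unbacked(x, 0)
f(x)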


@@ -26,12 +26,12 @@ from torch._prims_common import (
BoolLike,
corresponding_complex_dtype,
corresponding_real_dtype,
+definitely_contiguous,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
FloatLike,
IntLike,
is_contiguous,
-is_contiguous_or_false,
make_contiguous_strides_for,
Number,
suggest_memory_format,
@@ -328,7 +328,7 @@ def _view_unbacked_meta(a, shape, size_oblivious_enabled=True):
if len(shape) == len(a.shape) and guard_or_false(sym_eq(shape, a.shape)):
return view_of(a)
-if is_contiguous_or_false(a) if size_oblivious_enabled else is_contiguous(a):
+if definitely_contiguous(a) if size_oblivious_enabled else is_contiguous(a):
strides = utils.make_contiguous_strides_for(shape)
return a.as_strided(shape, strides)


@@ -276,25 +276,19 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
return True
expected_stride = 1
-expected_stride_max = 1
for x, y in reversed(tuple(zip(a.shape, a.stride()))):
# Skips checking strides when a dimension has length 1.
if maybe_guard_or_false(x == 1):
continue
-if maybe_guard_or_true(y != expected_stride) and maybe_guard_or_true(
-y != expected_stride_max
-):
+if maybe_guard_or_true(y != expected_stride):
return False
-# We symbolically check both paths to maximize the cases where this function
-# returns true. This is because make_contiguous_strides_for adds the max
-# symbolically, and in some other situations the max might not be there.
-# And we want to ensure we return true in both cases.
-expected_stride_max *= x if is_nested_int(x) else sym_max(x, 1) # type:ignore[assignment]
-expected_stride *= x
+# if x is 0 then a is contiguous anyway. So in the check above for non-contiguity condition we can
+# can assume x is not 0 in expected_stride equation. This make the check consistent with
+# make_contiguous_strides_for. If we make a tensor and used strides from make_contiguous_strides_for
+# and then called definitely_contiguous we should get True.
+expected_stride *= x if is_nested_int(x) else sym_max(x, 1) # type:ignore[assignment]
return True
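
The restored comment is easiest to see with concrete numbers: make_contiguous_strides_for clamps each size to at least 1 when accumulating strides, so the check must do the same or a shape containing 0 would spuriously fail. A small sketch with plain ints instead of SymInts, using the helper from torch._prims_common:

from torch._prims_common import make_contiguous_strides_for

shape = (2, 0, 3)
strides = make_contiguous_strides_for(shape)  # (3, 3, 1): the 0-sized dim is treated as 1

expected_stride = 1
contiguous = True
for x, y in reversed(tuple(zip(shape, strides))):
    if x == 1:
        continue                   # the stride of a length-1 dim is irrelevant
    if y != expected_stride:
        contiguous = False
        break
    expected_stride *= max(x, 1)   # mirrors sym_max(x, 1) in the restored code
print(strides, contiguous)         # (3, 3, 1) True; with plain `expected_stride *= x` this would be False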
@@ -391,22 +385,22 @@ def is_contiguous_for_memory_format( # type: ignore[return]
)
-def is_contiguous_or_false(a: TensorLikeType) -> bool:
+def definitely_contiguous(a: TensorLikeType) -> bool:
return is_contiguous(a, false_if_dde=True)
# similar to is_channels_last_contiguous_2d but return false on data dependency.
-def is_channels_last_contiguous_or_false_2d(a: Tensor) -> bool:
+def definitely_channels_last_contiguous_2d(a: Tensor) -> bool:
return is_channels_last_contiguous_2d(a, false_if_dde=True)
# similar to is_channels_last_contiguous_3d but return false on data dependency.
-def is_channels_last_contiguous_or_false_3d(a: Tensor) -> bool:
+def definitely_channels_last_contiguous_3d(a: Tensor) -> bool:
return is_channels_last_contiguous_3d(a, false_if_dde=True)
# similar to is_contiguous_for_memory_format but return false on data dependency.
-def contiguous_for_memory_format_or_false( # type: ignore[return]
+def definitely_contiguous_for_memory_format( # type: ignore[return]
a: Tensor, *, memory_format: torch.memory_format
) -> bool:
return is_contiguous_for_memory_format(
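
All of the definitely_* helpers restored here are thin wrappers that pass false_if_dde=True to the corresponding is_* check: they answer "is this definitely contiguous?" and return False rather than guarding when the answer depends on unbacked values. A small usage sketch with ordinary tensors, assuming this reverted state of torch._prims_common:

import torch
from torch._prims_common import definitely_contiguous, is_contiguous

a = torch.randn(4, 5)
assert is_contiguous(a) and definitely_contiguous(a)

b = a.t()                              # transposed view, not row-major contiguous
assert not definitely_contiguous(b)    # definitely not contiguous, no guard needed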
@@ -432,10 +426,10 @@ def is_channels_last_contiguous(a: Tensor) -> bool:
# similar to is_channels_last_contiguous but return false on data dependency.
-def is_channels_last_contiguous_or_false(a: Tensor) -> bool:
-return is_channels_last_contiguous_or_false_2d(
+def definitely_channels_last_contiguous(a: Tensor) -> bool:
+return definitely_channels_last_contiguous_2d(
a
-) or is_channels_last_contiguous_or_false_3d(a)
+) or definitely_channels_last_contiguous_3d(a)
def is_non_overlapping_and_dense(a: Tensor) -> bool:
@@ -452,7 +446,7 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
return False
# Short-circuits if the tensor is already contiguous or channels-last contiguous
-if is_contiguous_or_false(a) or is_channels_last_contiguous_or_false(a):
+if definitely_contiguous(a) or definitely_channels_last_contiguous(a):
return True
# The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp
@@ -547,10 +541,10 @@ def compute_elementwise_output_logical_to_physical_perm(
is_contiguous = True
is_channels_last = True
for t in tensors:
-is_contiguous = is_contiguous and contiguous_for_memory_format_or_false(
+is_contiguous = is_contiguous and definitely_contiguous_for_memory_format(
t, memory_format=torch.contiguous_format
)
-is_channels_last = is_channels_last and contiguous_for_memory_format_or_false(
+is_channels_last = is_channels_last and definitely_contiguous_for_memory_format(
t, memory_format=torch.channels_last
)


@@ -19,7 +19,8 @@ import torch.utils._pytree as pytree
from torch import sym_float, sym_int
from torch._prims_common import (
BoolLike,
-contiguous_for_memory_format_or_false,
+definitely_contiguous,
+definitely_contiguous_for_memory_format,
DeviceLikeType,
Dim,
DimsSequenceType,
@@ -29,7 +30,6 @@ from torch._prims_common import (
FloatLike,
FloatWithoutSymFloat,
IntLike,
-is_contiguous_or_false,
is_weakly_lesser_type,
Number,
NumberType,
@@ -2984,7 +2984,7 @@ def contiguous(
)
# TODO: make logic consistent with aten contiguous
-if contiguous_for_memory_format_or_false(a, memory_format=memory_format):
+if definitely_contiguous_for_memory_format(a, memory_format=memory_format):
return a
return torch.clone(a, memory_format=memory_format)
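
The restored fast path matches the documented contract of contiguous(): if the input already satisfies the requested memory format it is returned unchanged, otherwise it is cloned into that format. A short eager-mode illustration of that contract:

import torch

a = torch.randn(2, 3)
assert a.contiguous() is a              # already contiguous: returned as-is

b = a.t()                               # non-contiguous view
c = b.contiguous()                      # materialized by a clone
assert c.is_contiguous() and c.data_ptr() != b.data_ptr()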
@@ -3852,7 +3852,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
else:
return _a
-if is_contiguous_or_false(a):
+if definitely_contiguous(a):
# Special-cases for nd_to_1d
if len(shape) == 1 and a.ndim > 1:
return torch.as_strided(a, [a.numel()], [1])
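
The branch above is the nd-to-1d special case: when the input is definitely contiguous, flattening it is just a stride-1 view over the same storage, which is what the as_strided call constructs. A minimal illustration:

import torch

a = torch.randn(3, 4)                        # contiguous
flat = torch.as_strided(a, [a.numel()], [1])
assert flat.data_ptr() == a.data_ptr()       # a view, no copy
assert torch.equal(flat, a.reshape(-1))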


@@ -12,7 +12,7 @@ import torch._logging
from torch._dispatch.python import no_python_dispatcher
from torch._ops import OpOverload
from torch._prims_common import (
-contiguous_for_memory_format_or_false,
+definitely_contiguous_for_memory_format,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
is_boolean_dtype,
@@ -1020,8 +1020,7 @@ def make_fast_binary_impl(
# compute_fast_setup_type
definitely_contiguous = True
definitely_channels_last = True
-# TODO: is_non-overlapping_and_dense not bound from Python
+# TODO: is_non-overlapping_and_dense (not bound from Python
# no inplace, no out, everything defined
if is_noncontiguous_supported(common_device):
@@ -1030,13 +1029,13 @@
continue
definitely_contiguous = (
definitely_contiguous
-and contiguous_for_memory_format_or_false(
+and definitely_contiguous_for_memory_format(
op, memory_format=torch.contiguous_format
)
)
definitely_channels_last = (
definitely_channels_last
-and contiguous_for_memory_format_or_false(
+and definitely_contiguous_for_memory_format(
op, memory_format=torch.channels_last
)
)
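
This restores the fold that drives the fast stride setup for binary fake-tensor ops: a memory-format fast path applies only if every operand is definitely contiguous in that format. A simplified sketch of that fold; the helper name below is illustrative, not the actual fake-tensor code:

import torch
from torch._prims_common import definitely_contiguous_for_memory_format

def all_definitely_contiguous(operands, memory_format=torch.contiguous_format):
    result = True
    for op in operands:
        result = result and definitely_contiguous_for_memory_format(
            op, memory_format=memory_format
        )
    return result

print(all_definitely_contiguous([torch.randn(2, 3), torch.randn(2, 3)]))      # True
print(all_definitely_contiguous([torch.randn(2, 3), torch.randn(3, 2).t()]))  # False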


@@ -7,7 +7,7 @@ import torch
import torch.fx
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import detect_fake_mode
-from torch._prims_common import contiguous_for_memory_format_or_false
+from torch._prims_common import definitely_contiguous_for_memory_format
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx._compatibility import compatibility
from torch.fx.node import map_aggregate, Node
@@ -35,8 +35,8 @@ class TensorMetadata(NamedTuple):
# When include_contiguity is True, we will set contiguity when its always true for the tensor.
# Some tensors can represent both contiguous and non-contiguous tensors. e.g: (u0, u1) with (u2, u3).
-# In such situation contiguity is not set. We could also make it a tri-state i.e: (def_contiguous,
-# def_not_contiguous and unknown).
+# In such situation contiguity is not set. We could also make it a tri-state i.e: (definitely_contiguous,
+# contiguous, and unknown).
def _extract_tensor_metadata(
result: torch.Tensor, include_contiguity=True
) -> TensorMetadata:
@@ -57,7 +57,7 @@ def _extract_tensor_metadata(
torch.channels_last_3d,
}
for query_format in memory_formats:
-if contiguous_for_memory_format_or_false(
+if definitely_contiguous_for_memory_format(
result, memory_format=query_format
):
memory_format = query_format
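
Applied to an ordinary tensor, the loop above records the first memory format that is definitely satisfied, and records nothing when no format can be decided (for example, fully unbacked shapes). A small sketch of what it computes for a channels-last input:

import torch
from torch._prims_common import definitely_contiguous_for_memory_format

x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
memory_format = None
for query_format in (torch.contiguous_format, torch.channels_last, torch.channels_last_3d):
    if definitely_contiguous_for_memory_format(x, memory_format=query_format):
        memory_format = query_format
        break
print(memory_format)  # torch.channels_last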