mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "python definitely_contiguous-> is_contiguous_or_false (#156515)"
This reverts commit 4c0091fda6.
Reverted https://github.com/pytorch/pytorch/pull/156515 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to cause some torch.export failures internally ([comment](https://github.com/pytorch/pytorch/pull/156515#issuecomment-3014104570))
This commit is contained in:
parent
2860f5c4f5
commit
75a7d9e868
|
|
@ -3337,7 +3337,7 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
|||
)
|
||||
with ctx():
|
||||
# This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3].
|
||||
# but not anymore since we use contiguous_or_false .
|
||||
# but not anymore since we use definitely_contiguous .
|
||||
# We need a way to mark strides unbacked to avoid the recompilation here.
|
||||
x = torch.randn(10, 10)
|
||||
torch._dynamo.decorators.mark_unbacked(x, 0)
|
||||
|
|
|
|||
|
|
@ -26,12 +26,12 @@ from torch._prims_common import (
|
|||
BoolLike,
|
||||
corresponding_complex_dtype,
|
||||
corresponding_real_dtype,
|
||||
definitely_contiguous,
|
||||
elementwise_dtypes,
|
||||
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
FloatLike,
|
||||
IntLike,
|
||||
is_contiguous,
|
||||
is_contiguous_or_false,
|
||||
make_contiguous_strides_for,
|
||||
Number,
|
||||
suggest_memory_format,
|
||||
|
|
@ -328,7 +328,7 @@ def _view_unbacked_meta(a, shape, size_oblivious_enabled=True):
|
|||
if len(shape) == len(a.shape) and guard_or_false(sym_eq(shape, a.shape)):
|
||||
return view_of(a)
|
||||
|
||||
if is_contiguous_or_false(a) if size_oblivious_enabled else is_contiguous(a):
|
||||
if definitely_contiguous(a) if size_oblivious_enabled else is_contiguous(a):
|
||||
strides = utils.make_contiguous_strides_for(shape)
|
||||
return a.as_strided(shape, strides)
|
||||
|
||||
|
|
|
|||
|
|
@ -276,25 +276,19 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
|
|||
return True
|
||||
|
||||
expected_stride = 1
|
||||
expected_stride_max = 1
|
||||
|
||||
for x, y in reversed(tuple(zip(a.shape, a.stride()))):
|
||||
# Skips checking strides when a dimension has length 1.
|
||||
if maybe_guard_or_false(x == 1):
|
||||
continue
|
||||
|
||||
if maybe_guard_or_true(y != expected_stride) and maybe_guard_or_true(
|
||||
y != expected_stride_max
|
||||
):
|
||||
if maybe_guard_or_true(y != expected_stride):
|
||||
return False
|
||||
|
||||
# We symbolically check both paths to maximize the cases where this function
|
||||
# returns true. This is because make_contiguous_strides_for adds the max
|
||||
# symbolically, and in some other situations the max might not be there.
|
||||
# And we want to ensure we return true in both cases.
|
||||
expected_stride_max *= x if is_nested_int(x) else sym_max(x, 1) # type:ignore[assignment]
|
||||
|
||||
expected_stride *= x
|
||||
# if x is 0 then a is contiguous anyway. So in the check above for non-contiguity condition we can
|
||||
# can assume x is not 0 in expected_stride equation. This make the check consistent with
|
||||
# make_contiguous_strides_for. If we make a tensor and used strides from make_contiguous_strides_for
|
||||
# and then called definitely_contiguous we should get True.
|
||||
expected_stride *= x if is_nested_int(x) else sym_max(x, 1) # type:ignore[assignment]
|
||||
|
||||
return True
|
||||
|
||||
|
|
@ -391,22 +385,22 @@ def is_contiguous_for_memory_format( # type: ignore[return]
|
|||
)
|
||||
|
||||
|
||||
def is_contiguous_or_false(a: TensorLikeType) -> bool:
|
||||
def definitely_contiguous(a: TensorLikeType) -> bool:
|
||||
return is_contiguous(a, false_if_dde=True)
|
||||
|
||||
|
||||
# similar to is_channels_last_contiguous_2d but return false on data dependency.
|
||||
def is_channels_last_contiguous_or_false_2d(a: Tensor) -> bool:
|
||||
def definitely_channels_last_contiguous_2d(a: Tensor) -> bool:
|
||||
return is_channels_last_contiguous_2d(a, false_if_dde=True)
|
||||
|
||||
|
||||
# similar to is_channels_last_contiguous_3d but return false on data dependency.
|
||||
def is_channels_last_contiguous_or_false_3d(a: Tensor) -> bool:
|
||||
def definitely_channels_last_contiguous_3d(a: Tensor) -> bool:
|
||||
return is_channels_last_contiguous_3d(a, false_if_dde=True)
|
||||
|
||||
|
||||
# similar to is_contiguous_for_memory_format but return false on data dependency.
|
||||
def contiguous_for_memory_format_or_false( # type: ignore[return]
|
||||
def definitely_contiguous_for_memory_format( # type: ignore[return]
|
||||
a: Tensor, *, memory_format: torch.memory_format
|
||||
) -> bool:
|
||||
return is_contiguous_for_memory_format(
|
||||
|
|
@ -432,10 +426,10 @@ def is_channels_last_contiguous(a: Tensor) -> bool:
|
|||
|
||||
|
||||
# similar to is_channels_last_contiguous but return false on data dependency.
|
||||
def is_channels_last_contiguous_or_false(a: Tensor) -> bool:
|
||||
return is_channels_last_contiguous_or_false_2d(
|
||||
def definitely_channels_last_contiguous(a: Tensor) -> bool:
|
||||
return definitely_channels_last_contiguous_2d(
|
||||
a
|
||||
) or is_channels_last_contiguous_or_false_3d(a)
|
||||
) or definitely_channels_last_contiguous_3d(a)
|
||||
|
||||
|
||||
def is_non_overlapping_and_dense(a: Tensor) -> bool:
|
||||
|
|
@ -452,7 +446,7 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
|
|||
return False
|
||||
|
||||
# Short-circuits if the tensor is already contiguous or channels-last contiguous
|
||||
if is_contiguous_or_false(a) or is_channels_last_contiguous_or_false(a):
|
||||
if definitely_contiguous(a) or definitely_channels_last_contiguous(a):
|
||||
return True
|
||||
|
||||
# The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp
|
||||
|
|
@ -547,10 +541,10 @@ def compute_elementwise_output_logical_to_physical_perm(
|
|||
is_contiguous = True
|
||||
is_channels_last = True
|
||||
for t in tensors:
|
||||
is_contiguous = is_contiguous and contiguous_for_memory_format_or_false(
|
||||
is_contiguous = is_contiguous and definitely_contiguous_for_memory_format(
|
||||
t, memory_format=torch.contiguous_format
|
||||
)
|
||||
is_channels_last = is_channels_last and contiguous_for_memory_format_or_false(
|
||||
is_channels_last = is_channels_last and definitely_contiguous_for_memory_format(
|
||||
t, memory_format=torch.channels_last
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,8 @@ import torch.utils._pytree as pytree
|
|||
from torch import sym_float, sym_int
|
||||
from torch._prims_common import (
|
||||
BoolLike,
|
||||
contiguous_for_memory_format_or_false,
|
||||
definitely_contiguous,
|
||||
definitely_contiguous_for_memory_format,
|
||||
DeviceLikeType,
|
||||
Dim,
|
||||
DimsSequenceType,
|
||||
|
|
@ -29,7 +30,6 @@ from torch._prims_common import (
|
|||
FloatLike,
|
||||
FloatWithoutSymFloat,
|
||||
IntLike,
|
||||
is_contiguous_or_false,
|
||||
is_weakly_lesser_type,
|
||||
Number,
|
||||
NumberType,
|
||||
|
|
@ -2984,7 +2984,7 @@ def contiguous(
|
|||
)
|
||||
|
||||
# TODO: make logic consistent with aten contiguous
|
||||
if contiguous_for_memory_format_or_false(a, memory_format=memory_format):
|
||||
if definitely_contiguous_for_memory_format(a, memory_format=memory_format):
|
||||
return a
|
||||
|
||||
return torch.clone(a, memory_format=memory_format)
|
||||
|
|
@ -3852,7 +3852,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
|
|||
else:
|
||||
return _a
|
||||
|
||||
if is_contiguous_or_false(a):
|
||||
if definitely_contiguous(a):
|
||||
# Special-cases for nd_to_1d
|
||||
if len(shape) == 1 and a.ndim > 1:
|
||||
return torch.as_strided(a, [a.numel()], [1])
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import torch._logging
|
|||
from torch._dispatch.python import no_python_dispatcher
|
||||
from torch._ops import OpOverload
|
||||
from torch._prims_common import (
|
||||
contiguous_for_memory_format_or_false,
|
||||
definitely_contiguous_for_memory_format,
|
||||
elementwise_dtypes,
|
||||
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
is_boolean_dtype,
|
||||
|
|
@ -1020,8 +1020,7 @@ def make_fast_binary_impl(
|
|||
# compute_fast_setup_type
|
||||
definitely_contiguous = True
|
||||
definitely_channels_last = True
|
||||
|
||||
# TODO: is_non-overlapping_and_dense not bound from Python
|
||||
# TODO: is_non-overlapping_and_dense (not bound from Python
|
||||
# no inplace, no out, everything defined
|
||||
|
||||
if is_noncontiguous_supported(common_device):
|
||||
|
|
@ -1030,13 +1029,13 @@ def make_fast_binary_impl(
|
|||
continue
|
||||
definitely_contiguous = (
|
||||
definitely_contiguous
|
||||
and contiguous_for_memory_format_or_false(
|
||||
and definitely_contiguous_for_memory_format(
|
||||
op, memory_format=torch.contiguous_format
|
||||
)
|
||||
)
|
||||
definitely_channels_last = (
|
||||
definitely_channels_last
|
||||
and contiguous_for_memory_format_or_false(
|
||||
and definitely_contiguous_for_memory_format(
|
||||
op, memory_format=torch.channels_last
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import torch
|
|||
import torch.fx
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._prims_common import contiguous_for_memory_format_or_false
|
||||
from torch._prims_common import definitely_contiguous_for_memory_format
|
||||
from torch._subclasses.meta_utils import is_sparse_any
|
||||
from torch.fx._compatibility import compatibility
|
||||
from torch.fx.node import map_aggregate, Node
|
||||
|
|
@ -35,8 +35,8 @@ class TensorMetadata(NamedTuple):
|
|||
|
||||
# When include_contiguity is True, we will set contiguity when its always true for the tensor.
|
||||
# Some tensors can represent both contiguous and non-contiguous tensors. e.g: (u0, u1) with (u2, u3).
|
||||
# In such situation contiguity is not set. We could also make it a tri-state i.e: (def_contiguous,
|
||||
# def_not_contiguous and unknown).
|
||||
# In such situation contiguity is not set. We could also make it a tri-state i.e: (definitely_contiguous,
|
||||
# contiguous, and unknown).
|
||||
def _extract_tensor_metadata(
|
||||
result: torch.Tensor, include_contiguity=True
|
||||
) -> TensorMetadata:
|
||||
|
|
@ -57,7 +57,7 @@ def _extract_tensor_metadata(
|
|||
torch.channels_last_3d,
|
||||
}
|
||||
for query_format in memory_formats:
|
||||
if contiguous_for_memory_format_or_false(
|
||||
if definitely_contiguous_for_memory_format(
|
||||
result, memory_format=query_format
|
||||
):
|
||||
memory_format = query_format
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user