Revert "introduce definitely_contiguous and use it for reshape and tensor meta data computation. (#153432)"

This reverts commit 5c6d7caaaa.

Reverted https://github.com/pytorch/pytorch/pull/153432 on behalf of https://github.com/malfet due to Looks like it broke flex attention tests, see https://hud.pytorch.org/hud/pytorch/pytorch/main/1?per_page=50&name_filter=g6.4xlarge&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/153432#issuecomment-2912562570))
Author: PyTorch MergeBot
Date: 2025-05-27 13:42:34 +00:00
Commit: 11a51a11af (parent c52a002a22)
7 changed files with 37 additions and 161 deletions


@@ -24,7 +24,6 @@
 #include <ATen/native/cpu/SerialStackImpl.h>
 #include <ATen/native/cpu/StackKernel.h>
 #include <ATen/quantized/QTensorImpl.h>
-#include <c10/core/Contiguity.h>
 #include <c10/core/GradMode.h>
 #include <c10/util/Exception.h>
 #include <c10/util/SmallVector.h>
@@ -1994,15 +1993,11 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
     TORCH_CHECK(false, "reshape is not implemented for sparse tensors");
   }
-  auto sym_sizes = self.sym_sizes();
-  auto sym_strides = self.sym_strides();
-  auto sym_numel = self.sym_numel();
-  if (definitely_contiguous(sym_sizes, sym_strides, sym_numel) &&
-      !self.is_mkldnn()) {
+  if (self.is_contiguous() && !self.is_mkldnn()) {
     return self.view_symint(proposed_shape);
   }
-  c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel);
+  c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel());
   if (self.is_mkldnn()) {
     return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape));
@@ -2010,7 +2005,8 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
   // `computeStride` returns the proper strides to use if this
   // `reshape` can be just a view.
-  auto stride = at::detail::computeStride(sym_sizes, sym_strides, shape);
+  auto stride =
+      at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape);
   // NB: Even though we have viewable geometry and the target strides here,
   // we do not just call `as_strided` on `self` because the backward

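For context on the reshape_symint change above: when the input is contiguous, reshape can simply return a view that aliases the input's storage; otherwise it consults computeStride and, if no view is possible, falls back to a copy. A small eager-mode illustration (independent of this change, using only public APIs):

    import torch

    x = torch.arange(12)
    y = x.reshape(3, 4)
    print(y.data_ptr() == x.data_ptr())  # True: contiguous input, reshape returned a view

    t = x.reshape(3, 4).t()              # transposed layout cannot be viewed as 1-D
    z = t.reshape(12)
    print(z.data_ptr() == t.data_ptr())  # False: reshape had to copy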

@@ -12,49 +12,24 @@ namespace c10 {
 template <typename T>
 bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
+  bool is_contiguous = true;
   if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) {
-    return true;
+    return is_contiguous;
   }
-  T expected_stride = 1;
+  T z = 1;
   // NB: make sure we do signed arithmetic
   for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
     const auto& size_d = sizes[d];
-    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) {
-      continue;
+    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
+      if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(strides[d], z))) {
+        z *= size_d;
+      } else {
+        is_contiguous = false;
+        break;
+      }
     }
-    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) {
-      return false;
-    }
-    expected_stride *= size_d;
   }
-  return true;
-}
-// This function will return True if the tensor is contiguous, and False if the
-// its not or if we can't determine if it is contiguous due to unbacked symbols
-// (it could be either in that case based on the actual runtime data).
-template <typename T>
-bool definitely_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
-  if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) {
-    return true;
-  }
-  T expected_stride = 1;
-  // NB: make sure we do signed arithmetic
-  for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
-    const auto& size_d = sizes[d];
-    if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) {
-      continue;
-    }
-    if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) {
-      return false;
-    }
-    expected_stride *= size_d;
-  }
-  return true;
+  return is_contiguous;
 }
 template <typename T>

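The restored _compute_contiguous walks dimensions right to left, skips size-1 dimensions, and requires each remaining stride to equal the product of the sizes to its right; the removed definitely_contiguous differed only in using TORCH_GUARD_OR_FALSE / TORCH_GUARD_OR_TRUE, so a comparison that cannot be decided (unbacked symbols) makes it answer "not definitely contiguous" instead of guarding. A plain-Python sketch of the same stride walk (illustrative only; the real helpers operate on SymInts):

    import torch

    def compute_contiguous(sizes, strides, numel):
        # Skip size-1 dims; every other stride must equal the running product
        # of the sizes to its right (the variable named z in the C++ code).
        if numel == 0:
            return True
        z = 1
        for size, stride in zip(reversed(sizes), reversed(strides)):
            if size != 1:
                if stride != z:
                    return False
                z *= size
        return True

    x = torch.randn(2, 3, 4)
    p = x.permute(2, 0, 1)
    assert compute_contiguous(x.shape, x.stride(), x.numel()) == x.is_contiguous()  # True
    assert compute_contiguous(p.shape, p.stride(), p.numel()) == p.is_contiguous()  # False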

@@ -3281,39 +3281,6 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
 self.assertEqual(result_compiled, result_eager)
 self.assertEqual(cnt.frame_count, 1)
-# Pass a contiguous tensor. A recompilation will happen due to 0/1 speciialization on stride.
-log_stream, ctx = logs_to_string(
-"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
-)
-with ctx():
-# This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3].
-# but not anymore since we use definitely_contiguous .
-# We need a way to mark strides unbacked to avoid the recompilation here.
-x = torch.randn(10, 10)
-torch._dynamo.decorators.mark_unbacked(x, 0)
-torch._dynamo.decorators.mark_unbacked(x, 1)
-aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
-self.assertExpectedInline(
-aot_graphs,
-"""""", # noqa: B950
-ignore_comments=True,
-ignore_empty_lines=True,
-)
-result_compiled = compiled_func(x, torch.tensor([2, 50]))
-result_eager = func(x, torch.tensor([2, 50]))
-self.assertEqual(result_compiled, result_eager)
-self.assertEqual(cnt.frame_count, 2)
-x = torch.randn(4, 4)
-result_eager = func(x, torch.tensor([2, 8]))
-result_compiled = compiled_func(x, torch.tensor([2, 8]))
-self.assertEqual(result_compiled, result_eager)
-self.assertEqual(cnt.frame_count, 2)
 @unittest.skip("this test fails due to inductor/autograd issue #153041")
 @torch._dynamo.config.patch("capture_scalar_outputs", True)
 def test_unbacked_non_contigious_reshape_failing(self):


@@ -1370,8 +1370,8 @@ def forward(self, crop_camera_1, mask_1):
     view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
     bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
     view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
-    mul_6 = sym_size_int * 3
-    view_3 = torch.ops.aten.view.default(view_2, [mul_6, 3]); view_2 = mul_6 = None
+    mul_4 = sym_size_int * 3
+    view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None
     mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
     _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
     index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None


@@ -259,64 +259,47 @@ def check_all_strides(
 # This function is equivalent to compute_contiguous() from TensorImpl.cpp
-def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
+def is_contiguous(a: TensorLikeType) -> bool:
     """
     Tests whether a tensor is contiguous or not.
     Tensors are contiguous when they have no elements,
     one element, or when they have "nested" strides.
     """
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_or_true,
-        guard_size_oblivious,
-    )
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
-    if maybe_guard_or_false(a.numel() == 0):
+    if guard_size_oblivious(a.numel() < 2):
         return True
     expected_stride = 1
     for x, y in reversed(tuple(zip(a.shape, a.stride()))):
         # Skips checking strides when a dimension has length 1
-        if maybe_guard_or_false(x == 1):
+        if guard_size_oblivious(x == 1):
             continue
-        if maybe_guard_or_true(y != expected_stride):
+        if guard_size_oblivious(y != expected_stride):
             return False
-        # if x is 0 then a is contiguous anyway. So in the check above for non-contiguity condition we can
-        # can assume x is not 0 in expected_stride equation. This is also consistent with make_contiguous_strides_for.
-        expected_stride = expected_stride * sym_max(x, 1)
+        expected_stride = expected_stride * x
     return True
 # This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
-def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
+def is_channels_last_contiguous_2d(a: Tensor) -> bool:
     # NHWC or not channels last 2D contiguous
     if a.ndim != 4:
         return False
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_or_true,
-        guard_size_oblivious,
-    )
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
     expected_stride = 1
     for idx in (1, 3, 2, 0):
         length = a.shape[idx]
-        if maybe_guard_or_false(length == 1):
+        if guard_size_oblivious(length == 1):
             continue
         stride = a.stride()[idx]
-        if maybe_guard_or_true(stride != expected_stride):
+        if guard_size_oblivious(stride != expected_stride):
             return False
         expected_stride *= length
@@ -324,28 +307,21 @@ def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
     return True
-def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool:
+def is_channels_last_contiguous_3d(a: Tensor) -> bool:
     # NDHWC or not channels last 3D contiguous
     if a.ndim != 5:
         return False
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_or_true,
-        guard_size_oblivious,
-    )
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
     expected_stride = 1
     for idx in (1, 4, 3, 2, 0):
         length = a.shape[idx]
-        if maybe_guard_or_false(length == 1):
+        if guard_size_oblivious(length == 1):
            continue
         stride = a.stride()[idx]
-        if maybe_guard_or_true(stride != expected_stride):
+        if guard_size_oblivious(stride != expected_stride):
            return False
         expected_stride *= length
@@ -369,16 +345,16 @@ def validate_memory_format(memory_format: torch.memory_format):
 def is_contiguous_for_memory_format(  # type: ignore[return]
-    a: Tensor, *, memory_format: torch.memory_format, false_if_dde=False
+    a: Tensor, *, memory_format: torch.memory_format
 ) -> bool:
     validate_memory_format(memory_format)
     if memory_format == torch.contiguous_format:
-        return is_contiguous(a, false_if_dde)
+        return is_contiguous(a)
     if memory_format == torch.channels_last:
-        return is_channels_last_contiguous_2d(a, false_if_dde)
+        return is_channels_last_contiguous_2d(a)
     if memory_format == torch.channels_last_3d:
-        return is_channels_last_contiguous_3d(a, false_if_dde)
+        return is_channels_last_contiguous_3d(a)
     torch._check(
         False,
@@ -386,29 +362,6 @@ def is_contiguous_for_memory_format(  # type: ignore[return]
     )
-def definitely_contiguous(a: TensorLikeType) -> bool:
-    return is_contiguous(a, false_if_dde=True)
-# similar to is_channels_last_contiguous_2d but return false on data dependency.
-def is_known_channels_last_contiguous_2d(a: Tensor) -> bool:
-    return is_channels_last_contiguous_2d(a, false_if_dde=True)
-# similar to is_channels_last_contiguous_3d but return false on data dependency.
-def is_known_channels_last_contiguous_3d(a: Tensor) -> bool:
-    return is_channels_last_contiguous_3d(a, false_if_dde=True)
-# similar to is_contiguous_for_memory_format but return false on data dependency.
-def definitely_contiguous_for_memory_format(  # type: ignore[return]
-    a: Tensor, *, memory_format: torch.memory_format
-) -> bool:
-    return is_contiguous_for_memory_format(
-        a, memory_format=memory_format, false_if_dde=True
-    )
 # NOTE: that tensors with no elements and channels last is ???
 def is_channels_last_contiguous(a: Tensor) -> bool:
     """
@@ -426,13 +379,6 @@ def is_channels_last_contiguous(a: Tensor) -> bool:
     return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a)
-# similar to is_channels_last_contiguous but return false on data dependency.
-def is_known_channels_last_contiguous(a: Tensor) -> bool:
-    return is_known_channels_last_contiguous_2d(
-        a
-    ) or is_known_channels_last_contiguous_3d(a)
 def is_non_overlapping_and_dense(a: Tensor) -> bool:
     """
     True when a tensor is non-overlapping and dense.

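The reference helpers above go back to their single-argument form. A quick check of how they behave on ordinary eager tensors (note these live in torch._prims_common, an internal module, not public API):

    import torch
    from torch._prims_common import is_contiguous, is_contiguous_for_memory_format

    x = torch.randn(2, 3, 4)
    print(is_contiguous(x))                   # True
    print(is_contiguous(x.permute(2, 0, 1)))  # False: strides no longer nest

    nhwc = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
    print(is_contiguous_for_memory_format(nhwc, memory_format=torch.channels_last))  # True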

@@ -19,7 +19,6 @@ import torch.utils._pytree as pytree
 from torch import sym_float, sym_int
 from torch._prims_common import (
     BoolLike,
-    definitely_contiguous,
     DeviceLikeType,
     Dim,
     DimsSequenceType,
@@ -3825,7 +3824,7 @@ def _view_simple(a: TensorLikeType, shape, data_dependent_error) -> TensorLikeType:
     if new_strides is not None:
         return a.as_strided(shape, new_strides)
-    if definitely_contiguous(a):
+    if a.is_contiguous():
         return a.as_strided(shape, utils.make_contiguous_strides_for(shape))
     raise data_dependent_error

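In the restored _view_simple fast path, a contiguous tensor is re-strided with utils.make_contiguous_strides_for, i.e. plain row-major strides for the target shape. For example:

    from torch._prims_common import make_contiguous_strides_for

    # Each stride is the product of the trailing dimension sizes.
    print(make_contiguous_strides_for((2, 3, 4)))  # (12, 4, 1)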

@@ -7,7 +7,6 @@ import torch
 import torch.fx
 from torch._dispatch.python import enable_python_dispatcher
 from torch._guards import detect_fake_mode
-from torch._prims_common import definitely_contiguous_for_memory_format
 from torch._subclasses.meta_utils import is_sparse_any
 from torch.fx._compatibility import compatibility
 from torch.fx.node import map_aggregate, Node
@@ -33,10 +32,6 @@ class TensorMetadata(NamedTuple):
     qparams: dict[str, Any]
-# When include_contiguity is True, we will set contiguity when its always true for the tensor.
-# Some tensors can represent both contiguous and non-contiguous tensors. e.g: (u0, u1) with (u2, u3).
-# In such situation contiguity is not set. We could also make it a tri-state i.e: (definitely_contiguous,
-# contiguous, and unknown).
 def _extract_tensor_metadata(
     result: torch.Tensor, include_contiguity=True
 ) -> TensorMetadata:
@@ -57,9 +52,7 @@ def _extract_tensor_metadata(
             torch.channels_last_3d,
         }
        for query_format in memory_formats:
-            if definitely_contiguous_for_memory_format(
-                result, memory_format=query_format
-            ):
+            if result.is_contiguous(memory_format=query_format):
                 memory_format = query_format
                 break
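
With the revert, tensor metadata extraction goes back to probing memory formats with Tensor.is_contiguous(memory_format=...). Roughly what that probe does for a concrete tensor (simplified: the real code iterates a set, so the probing order is not fixed):

    import torch

    x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
    memory_format = None
    for query_format in (torch.contiguous_format, torch.channels_last, torch.channels_last_3d):
        if x.is_contiguous(memory_format=query_format):
            memory_format = query_format
            break
    print(memory_format)  # torch.channels_last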