mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Revert "introduce definitely_contiguous and use it for reshape and tensor meta data computation. (#153432)"
This reverts commit 5c6d7caaaa.
Reverted https://github.com/pytorch/pytorch/pull/153432 on behalf of https://github.com/malfet due to Looks like it broke flex attention tests, see https://hud.pytorch.org/hud/pytorch/pytorch/main/1?per_page=50&name_filter=g6.4xlarge&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/153432#issuecomment-2912562570))
This commit is contained in:
parent c52a002a22
commit 11a51a11af
@@ -24,7 +24,6 @@
 #include <ATen/native/cpu/SerialStackImpl.h>
 #include <ATen/native/cpu/StackKernel.h>
 #include <ATen/quantized/QTensorImpl.h>
-#include <c10/core/Contiguity.h>
 #include <c10/core/GradMode.h>
 #include <c10/util/Exception.h>
 #include <c10/util/SmallVector.h>
@@ -1994,15 +1993,11 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
     TORCH_CHECK(false, "reshape is not implemented for sparse tensors");
   }
 
-  auto sym_sizes = self.sym_sizes();
-  auto sym_strides = self.sym_strides();
-  auto sym_numel = self.sym_numel();
-  if (definitely_contiguous(sym_sizes, sym_strides, sym_numel) &&
-      !self.is_mkldnn()) {
+  if (self.is_contiguous() && !self.is_mkldnn()) {
     return self.view_symint(proposed_shape);
   }
 
-  c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel);
+  c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel());
 
   if (self.is_mkldnn()) {
     return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape));
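Note: the hunk above restores the original fast path in `reshape_symint`: a tensor that is already contiguous (and not MKL-DNN) is reshaped as a pure view, rather than testing `definitely_contiguous` on the symbolic sizes/strides. For concrete tensors the observable behavior is the same either way; a minimal sketch using only public PyTorch APIs (not part of this commit):

```python
import torch

a = torch.randn(4, 6)                  # contiguous
b = a.reshape(2, 12)                   # fast path: just a view, same storage
print(b.data_ptr() == a.data_ptr())    # True

c = a.t()                              # transposed: not contiguous
d = c.reshape(-1)                      # no viewable geometry -> reshape copies
print(d.data_ptr() == a.data_ptr())    # False
```

The revert only changes how the contiguity test is expressed for symbolic shapes; it does not change which tensors take the view fast path at runtime.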
@@ -2010,7 +2005,8 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
 
   // `computeStride` returns the proper strides to use if this
   // `reshape` can be just a view.
-  auto stride = at::detail::computeStride(sym_sizes, sym_strides, shape);
+  auto stride =
+      at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape);
 
   // NB: Even though we have viewable geometry and the target strides here,
   // we do not just call `as_strided` on `self` because the backward
@@ -12,49 +12,24 @@ namespace c10 {
 
 template <typename T>
 bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
+  bool is_contiguous = true;
   if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) {
-    return true;
+    return is_contiguous;
   }
-
-  T expected_stride = 1;
+  T z = 1;
   // NB: make sure we do signed arithmetic
   for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
     const auto& size_d = sizes[d];
-    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) {
-      continue;
+    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
+      if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(strides[d], z))) {
+        z *= size_d;
+      } else {
+        is_contiguous = false;
+        break;
+      }
     }
-
-    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) {
-      return false;
-    }
-    expected_stride *= size_d;
   }
-  return true;
-}
-
-// This function will return True if the tensor is contiguous, and False if the
-// its not or if we can't determine if it is contiguous due to unbacked symbols
-// (it could be either in that case based on the actual runtime data).
-template <typename T>
-bool definitely_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
-  if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) {
-    return true;
-  }
-
-  T expected_stride = 1;
-  // NB: make sure we do signed arithmetic
-  for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
-    const auto& size_d = sizes[d];
-    if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) {
-      continue;
-    }
-
-    if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) {
-      return false;
-    }
-    expected_stride *= size_d;
-  }
-  return true;
+  return is_contiguous;
 }
 
 template <typename T>
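Note: for readers comparing the two checks above, here is an illustrative pure-Python model (not part of the commit) of the restored `_compute_contiguous` loop, evaluated on concrete sizes and strides. On concrete values the restored loop and the removed `definitely_contiguous` agree; they only diverge for unbacked symbolic sizes/strides, where `definitely_contiguous` replaced each guarded comparison with a "guard-or" check that conservatively reports non-contiguous instead of adding a data-dependent guard.

```python
def compute_contiguous(sizes, strides, numel):
    # Mirrors the restored C++ loop: walk dims right-to-left and check that each
    # non-trivial dim has the stride a packed (row-major) layout would have.
    if numel == 0:
        return True
    z = 1
    for d in range(len(sizes) - 1, -1, -1):
        if sizes[d] != 1:
            if strides[d] != z:
                return False
            z *= sizes[d]
    return True

print(compute_contiguous((2, 3), (3, 1), 6))          # True: packed row-major layout
print(compute_contiguous((3, 2), (1, 3), 6))          # False: transposed strides
print(compute_contiguous((5, 1, 4), (4, 9, 1), 20))   # True: size-1 dims are skipped
```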
@@ -3281,39 +3281,6 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
         self.assertEqual(result_compiled, result_eager)
         self.assertEqual(cnt.frame_count, 1)
 
-        # Pass a contiguous tensor. A recompilation will happen due to 0/1 speciialization on stride.
-        log_stream, ctx = logs_to_string(
-            "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
-        )
-        with ctx():
-            # This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3].
-            # but not anymore since we use definitely_contiguous .
-            # We need a way to mark strides unbacked to avoid the recompilation here.
-            x = torch.randn(10, 10)
-            torch._dynamo.decorators.mark_unbacked(x, 0)
-            torch._dynamo.decorators.mark_unbacked(x, 1)
-
-            aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
-            self.assertExpectedInline(
-                aot_graphs,
-                """""",  # noqa: B950
-                ignore_comments=True,
-                ignore_empty_lines=True,
-            )
-
-            result_compiled = compiled_func(x, torch.tensor([2, 50]))
-            result_eager = func(x, torch.tensor([2, 50]))
-
-        self.assertEqual(result_compiled, result_eager)
-        self.assertEqual(cnt.frame_count, 2)
-
-        x = torch.randn(4, 4)
-
-        result_eager = func(x, torch.tensor([2, 8]))
-        result_compiled = compiled_func(x, torch.tensor([2, 8]))
-        self.assertEqual(result_compiled, result_eager)
-        self.assertEqual(cnt.frame_count, 2)
-
     @unittest.skip("this test fails due to inductor/autograd issue #153041")
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
     def test_unbacked_non_contigious_reshape_failing(self):
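Note: the deleted test above was checking recompilation counts when dimensions are marked unbacked. As a rough, standalone illustration of how such frame counts are typically observed (this sketch is not part of the commit and uses `CompileCounter` rather than the test's own helpers; exact counts can vary between builds):

```python
import torch
from torch._dynamo.testing import CompileCounter

cnt = CompileCounter()

@torch.compile(backend=cnt)
def f(x):
    return (x + 1).reshape(-1)

f(torch.randn(10, 10))   # first call compiles
f(torch.randn(10, 10))   # same shape: reuses the compiled graph
f(torch.randn(4, 4))     # new static shape: typically triggers one recompile
print(cnt.frame_count)   # usually 2; the deleted test asserted counts like this
```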
@@ -1370,8 +1370,8 @@ def forward(self, crop_camera_1, mask_1):
     view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None
     bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
     view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None
-    mul_6 = sym_size_int * 3
-    view_3 = torch.ops.aten.view.default(view_2, [mul_6, 3]); view_2 = mul_6 = None
+    mul_4 = sym_size_int * 3
+    view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None
     mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None
     _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None
     index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None
@@ -259,64 +259,47 @@ def check_all_strides(
 
 
 # This function is equivalent to compute_contiguous() from TensorImpl.cpp
-def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
+def is_contiguous(a: TensorLikeType) -> bool:
     """
     Tests whether a tensor is contiguous or not.
 
     Tensors are contiguous when they have no elements,
     one element, or when they have "nested" strides.
     """
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_or_true,
-        guard_size_oblivious,
-    )
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
 
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
-
-    if maybe_guard_or_false(a.numel() == 0):
+    if guard_size_oblivious(a.numel() < 2):
         return True
 
     expected_stride = 1
     for x, y in reversed(tuple(zip(a.shape, a.stride()))):
         # Skips checking strides when a dimension has length 1
-        if maybe_guard_or_false(x == 1):
+        if guard_size_oblivious(x == 1):
            continue
 
-        if maybe_guard_or_true(y != expected_stride):
+        if guard_size_oblivious(y != expected_stride):
            return False
-
-        # if x is 0 then a is contiguous anyway. So in the check above for non-contiguity condition we can
-        # can assume x is not 0 in expected_stride equation. This is also consistent with make_contiguous_strides_for.
-        expected_stride = expected_stride * sym_max(x, 1)
+        expected_stride = expected_stride * x
 
     return True
 
 
 # This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
-def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
+def is_channels_last_contiguous_2d(a: Tensor) -> bool:
     # NHWC or not channels last 2D contiguous
     if a.ndim != 4:
         return False
 
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_or_true,
-        guard_size_oblivious,
-    )
-
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
 
     expected_stride = 1
     for idx in (1, 3, 2, 0):
         length = a.shape[idx]
-        if maybe_guard_or_false(length == 1):
+        if guard_size_oblivious(length == 1):
            continue
 
         stride = a.stride()[idx]
-        if maybe_guard_or_true(stride != expected_stride):
+        if guard_size_oblivious(stride != expected_stride):
            return False
 
         expected_stride *= length
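Note: both versions of `is_contiguous` above implement the same "nested strides" rule from the docstring; the reverted variant only changed how the symbolic comparisons are guarded. A small standalone check of that rule against PyTorch's own notion of contiguity (illustrative only, not part of the commit):

```python
import torch

def expected_contig_strides(shape):
    # Strides of a C-contiguous tensor: running product of trailing dim sizes.
    strides, acc = [], 1
    for s in reversed(shape):
        strides.append(acc)
        acc *= max(s, 1)  # size-0/1 dims still get a well-defined stride
    return tuple(reversed(strides))

t = torch.empty(2, 3, 4)
assert t.stride() == expected_contig_strides(t.shape) == (12, 4, 1)
assert t.is_contiguous()

# A transposed view keeps the same storage but breaks the nested-stride pattern.
u = t.transpose(0, 2)
assert u.stride() == (1, 4, 12)
assert not u.is_contiguous()
```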
@@ -324,28 +307,21 @@ def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
     return True
 
 
-def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool:
+def is_channels_last_contiguous_3d(a: Tensor) -> bool:
     # NDHWC or not channels last 3D contiguous
     if a.ndim != 5:
         return False
 
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_or_true,
-        guard_size_oblivious,
-    )
-
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
 
     expected_stride = 1
     for idx in (1, 4, 3, 2, 0):
         length = a.shape[idx]
-        if maybe_guard_or_false(length == 1):
+        if guard_size_oblivious(length == 1):
            continue
 
         stride = a.stride()[idx]
-        if maybe_guard_or_true(stride != expected_stride):
+        if guard_size_oblivious(stride != expected_stride):
            return False
 
         expected_stride *= length
@@ -369,16 +345,16 @@ def validate_memory_format(memory_format: torch.memory_format):
 
 
 def is_contiguous_for_memory_format(  # type: ignore[return]
-    a: Tensor, *, memory_format: torch.memory_format, false_if_dde=False
+    a: Tensor, *, memory_format: torch.memory_format
 ) -> bool:
     validate_memory_format(memory_format)
 
     if memory_format == torch.contiguous_format:
-        return is_contiguous(a, false_if_dde)
+        return is_contiguous(a)
     if memory_format == torch.channels_last:
-        return is_channels_last_contiguous_2d(a, false_if_dde)
+        return is_channels_last_contiguous_2d(a)
     if memory_format == torch.channels_last_3d:
-        return is_channels_last_contiguous_3d(a, false_if_dde)
+        return is_channels_last_contiguous_3d(a)
 
     torch._check(
         False,
@@ -386,29 +362,6 @@ def is_contiguous_for_memory_format(  # type: ignore[return]
     )
 
 
-def definitely_contiguous(a: TensorLikeType) -> bool:
-    return is_contiguous(a, false_if_dde=True)
-
-
-# similar to is_channels_last_contiguous_2d but return false on data dependency.
-def is_known_channels_last_contiguous_2d(a: Tensor) -> bool:
-    return is_channels_last_contiguous_2d(a, false_if_dde=True)
-
-
-# similar to is_channels_last_contiguous_3d but return false on data dependency.
-def is_known_channels_last_contiguous_3d(a: Tensor) -> bool:
-    return is_channels_last_contiguous_3d(a, false_if_dde=True)
-
-
-# similar to is_contiguous_for_memory_format but return false on data dependency.
-def definitely_contiguous_for_memory_format(  # type: ignore[return]
-    a: Tensor, *, memory_format: torch.memory_format
-) -> bool:
-    return is_contiguous_for_memory_format(
-        a, memory_format=memory_format, false_if_dde=True
-    )
-
-
 # NOTE: that tensors with no elements and channels last is ???
 def is_channels_last_contiguous(a: Tensor) -> bool:
     """
@@ -426,13 +379,6 @@ def is_channels_last_contiguous(a: Tensor) -> bool:
     return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a)
 
 
-# similar to is_channels_last_contiguous but return false on data dependency.
-def is_known_channels_last_contiguous(a: Tensor) -> bool:
-    return is_known_channels_last_contiguous_2d(
-        a
-    ) or is_known_channels_last_contiguous_3d(a)
-
-
 def is_non_overlapping_and_dense(a: Tensor) -> bool:
     """
     True when a tensor is non-overlapping and dense.
@@ -19,7 +19,6 @@ import torch.utils._pytree as pytree
 from torch import sym_float, sym_int
 from torch._prims_common import (
     BoolLike,
-    definitely_contiguous,
     DeviceLikeType,
     Dim,
     DimsSequenceType,
@@ -3825,7 +3824,7 @@ def _view_simple(a: TensorLikeType, shape, data_dependent_error) -> TensorLikeTy
     if new_strides is not None:
         return a.as_strided(shape, new_strides)
 
-    if definitely_contiguous(a):
+    if a.is_contiguous():
         return a.as_strided(shape, utils.make_contiguous_strides_for(shape))
 
     raise data_dependent_error
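Note: `_view_simple` above chooses between reusing computed strides, the contiguous fast path, and raising a data-dependent error. The user-visible contrast it mediates, `view` (must be stride-compatible) versus `reshape` (may fall back to a copy), can be seen with plain tensors; this sketch is illustrative and not part of the commit:

```python
import torch

a = torch.arange(6).reshape(2, 3)
b = a.t()                      # non-contiguous view, strides (1, 3)

try:
    b.view(-1)                 # view requires compatible strides; this fails
except RuntimeError as e:
    print("view failed:", e)

c = b.reshape(-1)              # reshape falls back to a copy
print(c.data_ptr() == a.data_ptr())  # False: new storage
```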
@@ -7,7 +7,6 @@ import torch
 import torch.fx
 from torch._dispatch.python import enable_python_dispatcher
 from torch._guards import detect_fake_mode
-from torch._prims_common import definitely_contiguous_for_memory_format
 from torch._subclasses.meta_utils import is_sparse_any
 from torch.fx._compatibility import compatibility
 from torch.fx.node import map_aggregate, Node
@@ -33,10 +32,6 @@ class TensorMetadata(NamedTuple):
     qparams: dict[str, Any]
 
 
-# When include_contiguity is True, we will set contiguity when its always true for the tensor.
-# Some tensors can represent both contiguous and non-contiguous tensors. e.g: (u0, u1) with (u2, u3).
-# In such situation contiguity is not set. We could also make it a tri-state i.e: (definitely_contiguous,
-# contiguous, and unknown).
 def _extract_tensor_metadata(
     result: torch.Tensor, include_contiguity=True
 ) -> TensorMetadata:
@@ -57,9 +52,7 @@ def _extract_tensor_metadata(
             torch.channels_last_3d,
         }
         for query_format in memory_formats:
-            if definitely_contiguous_for_memory_format(
-                result, memory_format=query_format
-            ):
+            if result.is_contiguous(memory_format=query_format):
                 memory_format = query_format
                 break
 
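Note: after the revert, `_extract_tensor_metadata` goes back to probing memory formats with `Tensor.is_contiguous(memory_format=...)` in the order shown above. A quick illustration with concrete tensors (not part of the commit):

```python
import torch

x = torch.empty(2, 3, 4, 5)                    # default NCHW-style layout
y = x.to(memory_format=torch.channels_last)    # NHWC strides, same shape

for fmt in (torch.contiguous_format, torch.channels_last):
    print(fmt,
          x.is_contiguous(memory_format=fmt),  # True, then False
          y.is_contiguous(memory_format=fmt))  # False, then True
```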