pytorch/c10/core/Contiguity.h
Laith Sakka 39df901b2a introduce definitely_contiguous and use it for reshape and tensor meta data computation. (#153432)
When a tensor has unbacked symbols, it can be general enough to represent both contiguous and non-contiguous tensors.
In that case we can't really evaluate is_contiguous. In many places in the code base, we check is_contiguous to take a fast path, but the general path usually works for both contiguous and non-contiguous tensors; in those cases we probably want
to use the definitely_contiguous API.

This is applied to reshape in this PR and also to tensor metadata computation: the metadata will now have an attribute that says the tensor is contiguous only when it is always contiguous. We now store that attribute only if definitely_contiguous is true.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153432
Approved by: https://github.com/bobrenjc93
2025-05-28 03:41:26 +00:00

155 lines
4.0 KiB
C++

#pragma once
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/SmallVector.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <cstdint>
namespace c10 {
// Checks whether `strides` are the dense row-major strides implied by
// `sizes`.
//
// Dimensions are visited from the last to the first. A dimension of size 1
// is skipped; every other dimension must have a stride equal to the running
// product of the sizes of the later (non-skipped) dimensions. A `numel` of 0
// short-circuits to true.
//
// Every comparison is evaluated through TORCH_GUARD_SIZE_OBLIVIOUS, in the
// order written below.
template <typename T>
bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
  // Nothing to lay out: report contiguous without looking at the strides.
  if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) {
    return true;
  }

  T running_stride = 1;
  // NB: the index is deliberately signed so the countdown terminates cleanly
  // (including for an empty `sizes`).
  for (auto dim = static_cast<int64_t>(sizes.size()); dim-- > 0;) {
    const auto& extent = sizes[dim];
    if (!TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(extent, 1))) {
      // Non-unit dimension: its stride must equal the product accumulated
      // from the dimensions already visited.
      if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[dim], running_stride))) {
        return false;
      }
      running_stride *= extent;
    }
  }
  return true;
}
// Returns true if the tensor is known to be contiguous. Returns false if it
// is not contiguous, or if contiguity cannot be determined because of
// unbacked symbols (in that case the answer could go either way depending on
// the actual runtime data).
//
// The traversal mirrors a plain contiguity check: dimensions are visited from
// last to first, size-1 dimensions are skipped, and every other dimension
// must have a stride equal to the running product of the later sizes. The
// "is it 0 / is it 1" questions go through TORCH_GUARD_OR_FALSE and the
// stride-mismatch question goes through TORCH_GUARD_OR_TRUE.
template <typename T>
bool definitely_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
  if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) {
    return true;
  }

  T running_stride = 1;
  // NB: the index is deliberately signed so the countdown terminates cleanly
  // (including for an empty `sizes`).
  for (auto dim = static_cast<int64_t>(sizes.size()); dim-- > 0;) {
    const auto& extent = sizes[dim];
    if (!TORCH_GUARD_OR_FALSE(sym_eq(extent, 1))) {
      // Dimension not established to be of size 1: its stride has to match
      // the product accumulated so far.
      if (TORCH_GUARD_OR_TRUE(sym_ne(strides[dim], running_stride))) {
        return false;
      }
      running_stride *= extent;
    }
  }
  return true;
}
// Checks a rank-4 `sizes`/`strides` pair against the dimension order
// {1, 3, 2, 0} (the channels-last order, going by the function name).
//
// Starting from an expected stride of 1, each dimension in that order whose
// size is not 1 must have exactly the expected stride; the expectation is
// then multiplied by that dimension's size. Any rank other than 4 yields
// false.
template <typename T>
bool _compute_channels_last_contiguous_2d(
    ArrayRef<T> sizes,
    ArrayRef<T> strides) {
  // TODO dim == 3 case will be enabled once it is fully tested; until then
  // rank 3 is rejected together with every other rank that is not 4.
  if (sizes.size() != 4) {
    return false;
  }

  // Please don't merge this with the 3d variant: the constant index list is
  // spelled out here to let the compiler fully unroll the loop for better
  // performance.
  T running_stride = 1;
  for (const int axis : {1, 3, 2, 0}) {
    const auto& extent = sizes[axis];
    if (!TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(extent, 1))) {
      // Size-1 dimension: its stride is not constrained.
      continue;
    }
    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[axis], running_stride))) {
      return false;
    }
    running_stride *= extent;
  }
  return true;
}
// Checks a rank-5 `sizes`/`strides` pair against the dimension order
// {1, 4, 3, 2, 0} (the 3d channels-last order, going by the function name).
//
// Starting from an expected stride of 1, each dimension in that order whose
// size is not 1 must have exactly the expected stride; the expectation is
// then multiplied by that dimension's size. Any rank other than 5 yields
// false.
template <typename T>
bool _compute_channels_last_contiguous_3d(
    ArrayRef<T> sizes,
    ArrayRef<T> strides) {
  // TODO dim == 4 case will be enabled once it is fully tested; until then
  // rank 4 is rejected together with every other rank that is not 5.
  if (sizes.size() != 5) {
    return false;
  }

  // Please don't merge this with the 2d variant: the constant index list is
  // spelled out here to let the compiler fully unroll the loop for better
  // performance.
  T running_stride = 1;
  for (const int axis : {1, 4, 3, 2, 0}) {
    const auto& extent = sizes[axis];
    if (!TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(extent, 1))) {
      // Size-1 dimension: its stride is not constrained.
      continue;
    }
    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[axis], running_stride))) {
      return false;
    }
    running_stride *= extent;
  }
  return true;
}
// Checks whether some ordering of the dimensions makes `strides` the dense
// strides implied by `sizes`.
//
// Rank 1 is answered directly: true when the size is below 2 or the stride
// is 1. Otherwise the dimension indices are sorted by ascending stride, with
// dimensions of size 0 or 1 pushed to the end, and walked in that order:
// each visited dimension must have a stride equal to the product of the
// sizes of the dimensions visited before it. Reaching a dimension of size
// below 2 ends the walk with true.
template <typename T>
bool _compute_non_overlapping_and_dense(
    ArrayRef<T> sizes,
    ArrayRef<T> strides) {
  const auto rank = sizes.size();
  if (rank == 1) {
    return sizes[0] < 2 || strides[0] == 1;
  }

  // Identity permutation 0, 1, ..., rank - 1.
  SmallVector<int64_t, 5> order;
  order.resize(rank);
  int64_t next_index = 0;
  for (auto& slot : order) {
    slot = next_index++;
  }

  // Sort by strides, leaving 0 and 1 sized dims at the end of the array.
  const auto stride_less = [&](int64_t lhs, int64_t rhs) {
    if (sizes[lhs] < 2) {
      return false;
    }
    if (sizes[rhs] < 2) {
      return true;
    }
    return strides[lhs] < strides[rhs];
  };
  std::sort(order.begin(), order.end(), stride_less);

  T expected_stride = 1;
  for (const auto axis : order) {
    const auto& extent = sizes[axis];
    if (extent < 2) {
      // Only size-0/1 dimensions remain from here on (they were sorted to
      // the back), so the layout is accepted.
      return true;
    }
    if (strides[axis] != expected_stride) {
      return false;
    }
    expected_stride *= extent;
  }
  return true;
}
} // namespace c10