mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
When a tensor has unbacked symbols it can be general enough to represent both contiguous and non-contiguous tensors; in that case we can't really evaluate is_contiguous. In many places in the code base we check is_contiguous to take a fast path, but the general path usually works for both contiguous and non-contiguous tensors — in that case we probably want to use the definitely_contiguous API instead. This is applied to reshape in this PR and also to tensor metadata computation: the metadata will now have an attribute saying the tensor is contiguous only when it is always contiguous, i.e. we store that only if definitely_contiguous is true. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153432 Approved by: https://github.com/bobrenjc93
155 lines
4.0 KiB
C++
155 lines
4.0 KiB
C++
#pragma once
|
|
#include <c10/core/SymBool.h>
|
|
#include <c10/core/SymInt.h>
|
|
#include <c10/util/ArrayRef.h>
|
|
#include <c10/util/SmallVector.h>
|
|
#include <c10/util/irange.h>
|
|
|
|
#include <algorithm>
|
|
#include <cstdint>
|
|
|
|
namespace c10 {
|
|
|
|
template <typename T>
|
|
bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
|
|
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) {
|
|
return true;
|
|
}
|
|
|
|
T expected_stride = 1;
|
|
// NB: make sure we do signed arithmetic
|
|
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
|
|
const auto& size_d = sizes[d];
|
|
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) {
|
|
continue;
|
|
}
|
|
|
|
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) {
|
|
return false;
|
|
}
|
|
expected_stride *= size_d;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// This function will return True if the tensor is contiguous, and False if the
|
|
// its not or if we can't determine if it is contiguous due to unbacked symbols
|
|
// (it could be either in that case based on the actual runtime data).
|
|
template <typename T>
|
|
bool definitely_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
|
|
if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) {
|
|
return true;
|
|
}
|
|
|
|
T expected_stride = 1;
|
|
// NB: make sure we do signed arithmetic
|
|
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
|
|
const auto& size_d = sizes[d];
|
|
if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) {
|
|
continue;
|
|
}
|
|
|
|
if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) {
|
|
return false;
|
|
}
|
|
expected_stride *= size_d;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
template <typename T>
|
|
bool _compute_channels_last_contiguous_2d(
|
|
ArrayRef<T> sizes,
|
|
ArrayRef<T> strides) {
|
|
// Please don't combine these code, constant array is used here to let
|
|
// compiler fully unroll the loop to get better performance
|
|
switch (sizes.size()) {
|
|
case 4: {
|
|
T expected = 1;
|
|
for (auto& d : {1, 3, 2, 0}) {
|
|
const auto& size_d = sizes[d];
|
|
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
|
|
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) {
|
|
return false;
|
|
}
|
|
expected *= size_d;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
// NOLINTNEXTLINE(bugprone-branch-clone)
|
|
case 3:
|
|
// TODO dim == 3 case will be enabled once it is fully tested
|
|
return false;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
bool _compute_channels_last_contiguous_3d(
|
|
ArrayRef<T> sizes,
|
|
ArrayRef<T> strides) {
|
|
// Please don't combine these code, constant array is used here to let
|
|
// compiler fully unroll the loop to get better performance
|
|
switch (sizes.size()) {
|
|
case 5: {
|
|
T expected = 1;
|
|
for (auto& d : {1, 4, 3, 2, 0}) {
|
|
const auto& size_d = sizes[d];
|
|
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
|
|
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) {
|
|
return false;
|
|
}
|
|
expected *= size_d;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
// NOLINTNEXTLINE(bugprone-branch-clone)
|
|
case 4:
|
|
// TODO dim == 4 case will be enabled once it is fully tested
|
|
return false;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
bool _compute_non_overlapping_and_dense(
|
|
ArrayRef<T> sizes,
|
|
ArrayRef<T> strides) {
|
|
auto dim = sizes.size();
|
|
if (dim == 1) {
|
|
return sizes[0] < 2 || strides[0] == 1;
|
|
}
|
|
SmallVector<int64_t, 5> perm;
|
|
perm.resize(dim);
|
|
for (const auto i : c10::irange(dim)) {
|
|
perm[i] = i;
|
|
}
|
|
// Sort by strides, leaving 0 and 1 sized dims at the end of the array
|
|
std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
|
|
if (sizes[a] < 2) {
|
|
return false;
|
|
} else if (sizes[b] < 2) {
|
|
return true;
|
|
}
|
|
return strides[a] < strides[b];
|
|
});
|
|
T require_stride = 1;
|
|
for (const auto i : c10::irange(dim)) {
|
|
const auto& size_perm_i = sizes[perm[i]];
|
|
if (size_perm_i < 2) {
|
|
return true;
|
|
}
|
|
if (strides[perm[i]] != require_stride) {
|
|
return false;
|
|
}
|
|
require_stride *= size_perm_i;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
} // namespace c10
|