mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Change our representation of sizes and strides to contain SymInts instead of int64_t. Right now it's not actually possible to create a Tensor with symbolic shape, so this change is intended to be a no-op. But the intended behavior is: - If you create a Tensor with symbolic shape, a `CustomSizes` policy will be set, and the `has_symbolic_sizes_strides_` bit will be set. (not currently implemented) - Calling any TensorImpl function that naively interacts with sizes and strides will throw. For hot-path functions (`sizes()`, `strides()`), we make use of the existing policy check to throw. For others, we just have a regular `TORCH_CHECK(!has_symbolic_sizes_strides_)`. This also undoes the explicit constructor I made in https://github.com/pytorch/pytorch/pull/77666; it ended up being more annoying than useful when making these changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78272 Approved by: https://github.com/Krovatkin, https://github.com/Chillee
196 lines
6.1 KiB
C++
196 lines
6.1 KiB
C++
// This file defines `SymIntArrayRef`, which serves as a non-owning view onto
// a contiguous sequence of SymInt, such as a std::vector<SymInt>. This class is
// conceptually and mostly functionally equivalent to ArrayRef<SymInt>.
|
|
//
|
|
// However, ArrayRef<SymInt> can't be used directly as it introduces ambiguity
|
|
// in the following cases:
|
|
// - a.expand({1, 2, 3}) matches two overloads:
|
|
// 1. `at::Tensor Tensor::expand(c10::SymIntArrayRef size, bool implicit)`
|
|
// 2. `at::Tensor Tensor::expand(at::IntArrayRef size, bool implicit)`
|
|
// Introducing a distinct `SymIntArrayRef` type allows finer-grained control
// over which overload will be selected.
|
|
|
|
#pragma once
|
|
|
|
#include <c10/core/SymInt.h>
|
|
#include <c10/util/ArrayRef.h>
|
|
#include <c10/util/Exception.h>
|
|
|
|
#include <array>
|
|
#include <initializer_list>
|
|
#include <iterator>
|
|
#include <vector>
|
|
|
|
namespace c10 {
|
|
/// SymIntArrayRef - Represent a constant reference to an array (0 or more
|
|
/// elements consecutively in memory), i.e. a start pointer and a length. It
|
|
/// allows various APIs to take consecutive elements easily and conveniently.
|
|
///
|
|
/// This class does not own the underlying data, it is expected to be used in
|
|
/// situations where the data resides in some other buffer, whose lifetime
|
|
/// extends past that of the SymIntArrayRef. For this reason, it is not in
|
|
/// general safe to store an SymIntArrayRef.
|
|
///
|
|
/// This is intended to be trivially copyable, so it should be passed by
|
|
/// value.
|
|
|
|
class SymIntArrayRef final {
|
|
public:
|
|
using iterator = const c10::SymInt*;
|
|
using const_iterator = const c10::SymInt*;
|
|
using size_type = size_t;
|
|
using value_type = c10::SymInt;
|
|
|
|
using reverse_iterator = std::reverse_iterator<iterator>;
|
|
|
|
private:
|
|
ArrayRef<c10::SymInt> wrapped_symint_array_ref;
|
|
|
|
public:
|
|
/// @name Constructors
|
|
/// @{
|
|
|
|
/// Construct an empty SymIntArrayRef.
|
|
/* implicit */ constexpr SymIntArrayRef() {}
|
|
|
|
/* implicit */ SymIntArrayRef(const std::vector<c10::SymInt>& Vec)
|
|
: wrapped_symint_array_ref(Vec) {}
|
|
|
|
/// Construct an SymIntArrayRef from a pointer and length.
|
|
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
|
|
const c10::SymInt* data,
|
|
size_t length)
|
|
: wrapped_symint_array_ref(data, length) {}
|
|
|
|
/// Construct an SymIntArrayRef from a range.
|
|
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
|
|
const c10::SymInt* begin,
|
|
const c10::SymInt* end)
|
|
: wrapped_symint_array_ref(begin, end) {}
|
|
|
|
/// Construct an SymIntArrayRef from a C array.
|
|
template <size_t N>
|
|
/* implicit */ constexpr SymIntArrayRef(const c10::SymInt (&Arr)[N])
|
|
: wrapped_symint_array_ref(Arr) {}
|
|
|
|
static SymIntArrayRef fromIntArrayRef(IntArrayRef array_ref) {
|
|
for (size_t i = 0; i < array_ref.size(); ++i) {
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
|
SymInt::check_range(array_ref[i]),
|
|
"IntArrayRef contains int that cannot be representative as a SymInt",
|
|
array_ref[i]);
|
|
}
|
|
return SymIntArrayRef(
|
|
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
|
|
}
|
|
|
|
/// @}
|
|
/// @name Simple Operations
|
|
/// @{
|
|
|
|
constexpr iterator begin() const {
|
|
return wrapped_symint_array_ref.begin();
|
|
}
|
|
constexpr iterator end() const {
|
|
return wrapped_symint_array_ref.end();
|
|
}
|
|
|
|
// These are actually the same as iterator, since SymIntArrayRef only
|
|
// gives you const iterators.
|
|
constexpr const_iterator cbegin() const {
|
|
return wrapped_symint_array_ref.cbegin();
|
|
}
|
|
constexpr const_iterator cend() const {
|
|
return wrapped_symint_array_ref.cend();
|
|
}
|
|
|
|
/// empty - Check if the array is empty.
|
|
constexpr bool empty() const {
|
|
return size() == 0;
|
|
}
|
|
|
|
constexpr const c10::SymInt* data() const {
|
|
return wrapped_symint_array_ref.data();
|
|
}
|
|
|
|
/// size - Get the array size.
|
|
constexpr size_t size() const {
|
|
return wrapped_symint_array_ref.size();
|
|
}
|
|
|
|
/// front - Get the first element.
|
|
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& front() const {
|
|
return wrapped_symint_array_ref.front();
|
|
}
|
|
|
|
/// back - Get the last element.
|
|
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& back() const {
|
|
return wrapped_symint_array_ref.back();
|
|
}
|
|
|
|
/// equals - Check for element-wise equality.
|
|
constexpr bool equals(SymIntArrayRef RHS) const {
|
|
return this->wrapped_symint_array_ref.equals(RHS.wrapped_symint_array_ref);
|
|
}
|
|
|
|
/// slice(n, m) - Take M elements of the array starting at element N
|
|
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef
|
|
slice(size_t N, size_t M) const {
|
|
return SymIntArrayRef(wrapped_symint_array_ref.data() + N, M);
|
|
}
|
|
|
|
/// slice(n) - Chop off the first N elements of the array.
|
|
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef slice(size_t N) const {
|
|
return slice(N, size() - N);
|
|
}
|
|
|
|
/// @}
|
|
/// @name Operator Overloads
|
|
/// @{
|
|
constexpr const c10::SymInt& operator[](size_t Index) const {
|
|
return wrapped_symint_array_ref[Index];
|
|
}
|
|
|
|
/// Vector compatibility
|
|
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& at(size_t Index) const {
|
|
return wrapped_symint_array_ref.at(Index);
|
|
}
|
|
|
|
/// Disallow accidental assignment from a temporary.
|
|
///
|
|
/// The declaration here is extra complicated so that "arrayRef = {}"
|
|
/// continues to select the move assignment operator.
|
|
template <typename U>
|
|
typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>::
|
|
type&
|
|
operator=(U&& Temporary) = delete;
|
|
|
|
/// Disallow accidental assignment from a temporary.
|
|
///
|
|
/// The declaration here is extra complicated so that "arrayRef = {}"
|
|
/// continues to select the move assignment operator.
|
|
template <typename U>
|
|
typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>::
|
|
type&
|
|
operator=(std::initializer_list<U>) = delete;
|
|
|
|
/// @}
|
|
/// @name Expensive Operations
|
|
/// @{
|
|
std::vector<c10::SymInt> vec() const {
|
|
return wrapped_symint_array_ref.vec();
|
|
}
|
|
|
|
friend std::ostream& operator<<(
|
|
std::ostream& out,
|
|
const SymIntArrayRef& list);
|
|
/// @}
|
|
};
|
|
|
|
/// Convert a SymIntArrayRef to an IntArrayRef. Defined out of line.
/// NOTE(review): presumably this verifies that every element is a concrete
/// (non-symbolic) int before converting; the "Slow" suffix suggests a
/// per-element check. Confirm against the definition.
TORCH_API at::IntArrayRef asIntArrayRefSlow(c10::SymIntArrayRef ar);
|
|
/// Convert a SymIntArrayRef to an IntArrayRef. Defined out of line.
/// NOTE(review): presumably the unchecked counterpart of asIntArrayRefSlow,
/// i.e. it performs no per-element validation, so the caller must guarantee
/// that no element is symbolic. Confirm against the definition.
TORCH_API at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar);
|
|
|
|
/// Write `list` to `out` and return `out`. This is the namespace-scope
/// declaration of the friend declared inside SymIntArrayRef. It is defined out
/// of line, and its output format is not visible in this header.
std::ostream& operator<<(std::ostream& out, const c10::SymIntArrayRef& list);
|
|
|
|
} // namespace c10
|