mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Change our representation of sizes and strides to contain SymInts instead of int64_t. Right now it's not actually possible to create a Tensor with symbolic shape, so this change is intended to be a no-op. But the intended behavior is: - If you create a Tensor with symbolic shape, a `CustomSizes` policy will be set, and the `has_symbolic_sizes_strides_` bit will be set. (not currently implemented) - Calling any TensorImpl function that naively interacts with sizes and strides will throw. For hot-path functions (`sizes()`, `strides()`), we make use of the existing policy check to throw. For others, we just have a regular `TORCH_CHECK(!has_symbolic_sizes_strides_)`. This also undoes the explicit constructor I made in https://github.com/pytorch/pytorch/pull/77666; it ended up being more annoying than useful when making these changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78272 Approved by: https://github.com/Krovatkin, https://github.com/Chillee
101 lines
3.2 KiB
C++
101 lines
3.2 KiB
C++
#pragma once
|
|
|
|
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <cstdint>
#include <iosfwd>
#include <memory>
|
|
|
|
namespace c10 {
|
|
|
|
class SymbolicIntNode;
|
|
|
|
// `SymInt` is a C++ wrapper class around int64_t data_ which is used to
|
|
// represent concrete dimension values.
|
|
//
|
|
// `SymInt` is also a data type in PyTorch that can be used in function schemas
|
|
// to enable tracing.
|
|
//
|
|
// `SymInt` is introduced to enable tracing arithmetic
|
|
// operations on symbolic integers (e.g. sizes). Tracing symbolic sizes will
|
|
// allow LTC and AOTAutograd to represent dynamic shapes in expression graphs
|
|
// faithfully without baking in concrete dimension values.
|
|
//
|
|
// To trace the operations, SymInt will overload arithmetic operators (e.g. +,
|
|
// -, *) and will provide overloads taking SymInt for commonly used math
|
|
// functions.
|
|
//
|
|
// SymInt will be extended to represent a union structure Union[int64_t,
|
|
// SymbolicIntNode*] which will be implemented as a single packed int64_t field
|
|
// named data_.
|
|
class C10_API SymInt {
|
|
public:
|
|
/*implicit*/ SymInt(int64_t d) : data_(d){};
|
|
SymInt() = default;
|
|
|
|
int64_t expect_int() const {
|
|
TORCH_CHECK(!is_symbolic());
|
|
return data_;
|
|
}
|
|
|
|
bool is_symbolic() const {
|
|
return (MASK & static_cast<uint64_t>(this->data_)) == IS_SYM;
|
|
}
|
|
|
|
bool operator==(const SymInt& p2) const {
|
|
return data_ == p2.data_;
|
|
}
|
|
|
|
bool operator!=(const SymInt& p2) const {
|
|
return data_ != p2.data_;
|
|
}
|
|
|
|
SymInt operator+(SymInt sci) const;
|
|
bool operator<(SymInt sci) const;
|
|
void operator*=(SymInt sci);
|
|
|
|
SymInt operator*(int64_t sci) const;
|
|
bool operator<(int64_t sci) const;
|
|
bool operator==(int64_t sci) const;
|
|
bool operator!=(int64_t sci) const;
|
|
|
|
std::shared_ptr<SymbolicIntNode> toSymbolicIntNode();
|
|
static c10::SymInt toSymInt(std::shared_ptr<SymbolicIntNode> sin);
|
|
|
|
int64_t as_int_unchecked() const {
|
|
return data_;
|
|
}
|
|
|
|
// This is needed for interoperability with IValue
|
|
int64_t data() const {
|
|
return data_;
|
|
}
|
|
|
|
// Return whether the integer is representable as a SymInt.
|
|
static bool check_range(int64_t i) {
|
|
return i > MIN_INT;
|
|
}
|
|
|
|
private:
|
|
// Constraints on the internal representation:
|
|
// - Should represent positive and negative ints
|
|
// - No conversion necessary for operations on ints.
|
|
// - We reserve some values to act as indices into our sym int table.
|
|
//
|
|
// So, the scheme is to reserve large negative numbers:
|
|
// - 0b0.... means we are a positive int (following two's complement)
|
|
// - 0b11... means we are a negative int (following two's complement)
|
|
// - 0b10... means we are index into the sym table. This means that
|
|
// [-2^63, -2^62-1] are not representable as ints.
|
|
static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62;
|
|
static constexpr uint64_t IS_SYM = 1ULL << 63;
|
|
// Since we use the top two bits to determine whether something is symbolic,
|
|
// we cannot represent symbolic indices that are large enough to use those
|
|
// bits. This will probably never happen.
|
|
static constexpr uint64_t MAX_SYM_IDX = 1ULL << 62;
|
|
// Since 0b10... is reserved for symbolic indices, any integers lower than
|
|
// this value would collide with our representation.
|
|
static constexpr int64_t MIN_INT = -1LL & ~(1ULL << 62);
|
|
int64_t data_;
|
|
};
|
|
|
|
C10_API std::ostream& operator<<(std::ostream& os, SymInt s);
|
|
} // namespace c10
|