#pragma once #include #include #include #include #include #include #include #include #include #include #include #include namespace c10 { class SymFloat; // SymInt represents either a regular int64_t, or a symbolic integer // (represented in a type erased way as SymNode). The intention is for SymInt // to represent symbolic sizes that arise when doing shape computation in // operator kernels. This allows for tracing through programs without baking in // concrete sizes into kernel calls. // // SymInt has an API equivalent to int64_t. In particular, it is a value type. // Internally, SymInt is represented in a clever packed way, so that it only // occupies one word of space; but morally, it is a union between an int64_t // and an intrusive pointer to SymNodeImpl. // // Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where // is_int() returns true class C10_API SymInt { public: enum Unchecked { UNCHECKED, }; /*implicit*/ SymInt(int64_t d) : data_(d) { if (is_heap_allocated()) { // Large negative number, heap allocate it promote_to_negative(); } } SymInt() : data_(0) {} SymInt(SymNode n); // unchecked c-tor accepting raw `data_` // One appropriate use for this is when you are constructing a symint // in a situation where you know it is non-negative (or, if it is negative, // the negative value is -1; i.e., not user controlled) SymInt(Unchecked, int64_t d) : data_(d) {} // TODO: these implementations are not optimal because they allocate a // temporary and then use the move constructor/assignment SymInt(const SymInt& s) : data_(0) { if (s.is_heap_allocated()) { *this = SymInt(s.toSymNode()); } else { data_ = s.data_; } } SymInt(SymInt&& s) noexcept : data_(s.data_) { s.data_ = 0; } SymInt& operator=(const SymInt& s) { if (this != &s) { if (s.is_heap_allocated()) { *this = SymInt(s.toSymNode()); } else { data_ = s.data_; } } return *this; } SymInt& operator=(SymInt&& s) noexcept { if (this != &s) { release_(); // release the current SymNode if any data_ = s.data_; if (s.is_heap_allocated()) s.data_ = 0; }; return *this; } SymNodeImpl* toSymNodeImplUnowned() const { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_heap_allocated()); uint64_t unextended_bits = static_cast(data_) & ~MASK; uint64_t sign_bit_mask = 1ULL << (62 - 1); // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask; return static_cast( // NOLINTNEXTLINE(performance-no-int-to-ptr, bugprone*) reinterpret_cast(static_cast(extended_bits))); } void release_() { if (is_heap_allocated()) { SymNode::reclaim(toSymNodeImplUnowned()); // steal } } SymNodeImpl* release() && { #ifndef C10_MOBILE TORCH_INTERNAL_ASSERT(is_heap_allocated()); auto* r = toSymNodeImplUnowned(); data_ = 0; // transfer ownership return r; #else TORCH_INTERNAL_ASSERT(false); #endif } // Only valid if is_heap_allocated() SymNode toSymNode() const; // Guaranteed to return a SymNode, wrapping using base if necessary SymNode wrap_node(const SymNode& base) const; ~SymInt() { release_(); } // Require the int to be non-symbolic, and if it is symbolic raise an // error. This is safe to use for C++ code that doesn't work for symbolic // shapes, and you don't have time to fix it immediately, as if we // try to trigger the path in C++ you'll appropriately get an error int64_t expect_int() const { if (auto r = maybe_as_int()) { return *r; } TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE( false, "when unpacking SymInt, expected int but got ", *this); } // Test if we have a hint for this int (e.g., guard_int would work). // Most of the time this is true; it is only false when you have // an unbacked SymInt. bool has_hint() const; // Insert a guard for the int to be its concrete value, and then return // that value. This operation always works, even if the int is symbolic, // so long as we know what the underlying value is (e.g., this won't work // if you call it on the size of nonzero output). Don't blindly put this // everywhere; you can cause overspecialization of PyTorch programs with // this method. // // It should be called as guard_int(__FILE__, __LINE__). The file and line // number can be used to diagnose overspecialization. int64_t guard_int(const char* file, int64_t line) const; // Insert a guard that this SymInt must be size-like, returning true if // the integer actually is >= 0. Unlike manually performing a >= 0 test, // if the SymInt in question is an unbacked SymInt (or, potentially in the // future, if it contains unbacked SymInts), we will also treat the // unbacked SymInt as statically testing >= 2 (which will prevent us from // choking on, e.g., contiguity checks.) bool expect_size(const char* file, int64_t line) const; // Distinguish actual symbolic values from constants stored on the heap bool is_symbolic() const { return is_heap_allocated() && !toSymNodeImplUnowned()->constant_int().has_value(); } // N.B. It's important to keep this definition in the header // as we expect if checks to be folded for mobile builds // where `is_heap_allocated` is always false and optimize dead code paths C10_ALWAYS_INLINE bool is_heap_allocated() const { #ifdef C10_MOBILE return false; #else return !check_range(data_); #endif } SymInt operator+(const SymInt& sci) const; SymInt operator-(const SymInt& sci) const; SymInt operator*(const SymInt& sci) const; SymInt operator/(const SymInt& sci) const; SymInt operator%(const SymInt& sci) const; void operator*=(const SymInt& sci); void operator+=(const SymInt& sci); void operator/=(const SymInt& sci); SymInt clone() const; SymBool sym_eq(const SymInt&) const; SymBool sym_ne(const SymInt&) const; SymBool sym_lt(const SymInt&) const; SymBool sym_le(const SymInt&) const; SymBool sym_gt(const SymInt&) const; SymBool sym_ge(const SymInt&) const; bool operator==(const SymInt& o) const { return sym_eq(o).guard_bool(__FILE__, __LINE__); } bool operator!=(const SymInt& o) const { return sym_ne(o).guard_bool(__FILE__, __LINE__); } bool operator<(const SymInt& o) const { return sym_lt(o).guard_bool(__FILE__, __LINE__); } bool operator<=(const SymInt& o) const { return sym_le(o).guard_bool(__FILE__, __LINE__); } bool operator>(const SymInt& o) const { return sym_gt(o).guard_bool(__FILE__, __LINE__); } bool operator>=(const SymInt& o) const { return sym_ge(o).guard_bool(__FILE__, __LINE__); } SymInt min(const SymInt& sci) const; SymInt max(const SymInt& sci) const; // If both are symbolic, this checks if // they share the same node. // If both are not symbolic this just checks normal equality. bool is_same(const SymInt& other) const; operator SymFloat() const; void unsafe_set_data(size_t nbytes) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_heap_allocated()); data_ = static_cast(nbytes); } // Don't use this. Prefer maybe_as_int instead int64_t as_int_unchecked() const { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_heap_allocated()); return data_; } std::optional maybe_as_int() const { if (!is_heap_allocated()) { return data_; } auto* node = toSymNodeImplUnowned(); if (auto c = node->constant_int()) { return c; } return node->maybe_as_int(); } // Return whether the integer is directly coercible to a SymInt // without requiring heap allocation. You don't need to use this // to check if you can pass an integer to SymInt; this is guaranteed // to work (it just might heap allocate!) static bool check_range(int64_t i) { return i > MAX_UNREPRESENTABLE_INT; } // Return the min representable integer as a SymInt without // heap allocation. For quantities that count bytes (or larger), // this is still much larger than you need, so you may consider // using this as a more efficient version of MIN_INT static constexpr int64_t min_representable_int() { return MAX_UNREPRESENTABLE_INT + 1; } private: void promote_to_negative(); // Constraints on the internal representation: // // - Should represent positive and small negative ints // - No conversion necessary for operations on ints // - Must represent valid 64-bit pointers // - Is symbolic test should be FAST (two arithmetic instructions is too // much). // This code being a hotpath is based on Strobelight profiles of // is_heap_allocated(). FB only: https://fburl.com/strobelight/5l50ncxd // (you will need to change the time window). // // So, the scheme is to reserve large negative numbers (assuming // two's complement): // // - 0b0.... means we are a positive int // - 0b11... means we are a small negative int // - 0b10... means we are are a pointer. This means that // [-2^63, -2^62-1] are not representable as ints. // We don't actually need all of this space as on x86_64 // as the top 16bits aren't used for anything static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62 | 1ULL << 61; static constexpr uint64_t IS_SYM = 1ULL << 63 | 1ULL << 61; // We must manually translate the bit pattern test into a greater // than test because compiler doesn't figure it out: // https://godbolt.org/z/356aferaW static constexpr int64_t MAX_UNREPRESENTABLE_INT = -1LL & static_cast(~(1ULL << 62)); int64_t data_; }; /// Sum of a list of SymInt; accumulates into the c10::SymInt expression template < typename C, typename std::enable_if_t< std::is_same_v, int> = 0> inline c10::SymInt multiply_integers(const C& container) { return std::accumulate( container.begin(), container.end(), c10::SymInt(1), [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); } template < typename Iter, typename = std::enable_if_t::value_type, c10::SymInt>>> inline c10::SymInt multiply_integers(Iter begin, Iter end) { return std::accumulate( begin, end, c10::SymInt(1), [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); } #define DECLARE_SYMINT_OP_INTONLY(scalar_t, RetTy) \ C10_API RetTy operator%(const SymInt& a, scalar_t b); \ C10_API RetTy operator%(scalar_t a, const SymInt& b); #define DECLARE_SYMINT_OP(scalar_t, RetTy) \ C10_API RetTy operator+(const SymInt& a, scalar_t b); \ C10_API RetTy operator-(const SymInt& a, scalar_t b); \ C10_API RetTy operator*(const SymInt& a, scalar_t b); \ C10_API RetTy operator/(const SymInt& a, scalar_t b); \ C10_API RetTy operator+(scalar_t a, const SymInt& b); \ C10_API RetTy operator-(scalar_t a, const SymInt& b); \ C10_API RetTy operator*(scalar_t a, const SymInt& b); \ C10_API RetTy operator/(scalar_t a, const SymInt& b); \ C10_API bool operator==(const SymInt& a, scalar_t b); \ C10_API bool operator!=(const SymInt& a, scalar_t b); \ C10_API bool operator<(const SymInt& a, scalar_t b); \ C10_API bool operator<=(const SymInt& a, scalar_t b); \ C10_API bool operator>(const SymInt& a, scalar_t b); \ C10_API bool operator>=(const SymInt& a, scalar_t b); \ C10_API bool operator==(scalar_t a, const SymInt& b); \ C10_API bool operator!=(scalar_t a, const SymInt& b); \ C10_API bool operator<(scalar_t a, const SymInt& b); \ C10_API bool operator<=(scalar_t a, const SymInt& b); \ C10_API bool operator>(scalar_t a, const SymInt& b); \ C10_API bool operator>=(scalar_t a, const SymInt& b); DECLARE_SYMINT_OP_INTONLY(int64_t, SymInt) DECLARE_SYMINT_OP_INTONLY(int32_t, SymInt) DECLARE_SYMINT_OP_INTONLY(uint64_t, SymInt) DECLARE_SYMINT_OP_INTONLY(uint32_t, SymInt) DECLARE_SYMINT_OP(int64_t, SymInt) DECLARE_SYMINT_OP(int32_t, SymInt) // make sure constants work DECLARE_SYMINT_OP(uint64_t, SymInt) DECLARE_SYMINT_OP(uint32_t, SymInt) DECLARE_SYMINT_OP(double, SymFloat) DECLARE_SYMINT_OP(float, SymFloat) // just for completeness // On OSX size_t is different than uint64_t so we have to // define it separately #if defined(__APPLE__) DECLARE_SYMINT_OP_INTONLY(size_t, SymInt) DECLARE_SYMINT_OP(size_t, SymInt) #endif #undef DECLARE_SYMINT_OP C10_API std::ostream& operator<<(std::ostream& os, const SymInt& s); C10_API SymInt operator-(const SymInt& s); inline bool sym_eq(int64_t a, int64_t b) { return a == b; } inline SymBool sym_eq(const SymInt& a, const SymInt& b) { return a.sym_eq(b); } inline bool sym_ne(int64_t a, int64_t b) { return a != b; } inline SymBool sym_ne(const SymInt& a, const SymInt& b) { return a.sym_ne(b); } inline bool sym_lt(int64_t a, int64_t b) { return a < b; } inline SymBool sym_lt(const SymInt& a, const SymInt& b) { return a.sym_lt(b); } inline bool sym_le(int64_t a, int64_t b) { return a <= b; } inline SymBool sym_le(const SymInt& a, const SymInt& b) { return a.sym_le(b); } inline bool sym_gt(int64_t a, int64_t b) { return a > b; } inline SymBool sym_gt(const SymInt& a, const SymInt& b) { return a.sym_gt(b); } inline bool sym_ge(int64_t a, int64_t b) { return a >= b; } inline SymBool sym_ge(const SymInt& a, const SymInt& b) { return a.sym_ge(b); } } // namespace c10