pytorch/c10/core/ConstantSymNodeImpl.cpp
soulitzer eb8feb8ff8 Support SingletonSymNode mul with coefficient (#110369)
We want to be able to use SingletonSymNode to represent strides for a jagged-layout tensor. The following describes the 3D case, but it generalizes easily to higher dimensions.

Constraints:
- [B, x, D] (where x represents the "variable-length dim") can be strided in two ways: [x, 1, sum(x)] or [dx, d, 1]. We need two different placeholder values depending on how the jagged tensor is strided.
- When performing operations, we need the strides of output tensors to be expressible in terms of the strides and sizes of the inner tensors. Given [B, x, D] @ [D, D'], the output strides are [x * D', D', 1] rather than some opaque [x2, D', 1]. This constraint exists because, when tracing, I need a symint to represent the output stride, and that symint has to come from somewhere. There are several ways to get one: (1) create a constant, (2) create an unbacked symint, (3) create a new input using a source, or (4) take the output of an operation on an existing symint. Clearly (4) is what we want here, which brings us to the design below.

Design:

Given the two constraints, the most straightforward way to implement this is to update SingletonSymNode to include a scalar factor. Morally, SingletonSymNode then represents `factor * [s_0, s_1, …, s_n]`, which lets us compute strides symbolically from sizes.
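A minimal sketch of the "singleton times a scalar factor" idea, assuming a singleton can be reduced to an id (which ragged length sequence it stands for) plus an integer coefficient. `ToySingleton`, `kRaggedX`, and `contiguous_outer_stride` are hypothetical names for illustration only; this is not the c10 `SingletonSymNodeImpl` API.

```cpp
#include <cstdint>

// Hypothetical toy model, not the c10 API: a singleton symbolic int standing
// for a ragged sequence of lengths [s_0, ..., s_n], scaled by an integer
// coefficient, i.e. coeff * [s_0, ..., s_n].
struct ToySingleton {
  int64_t id;     // identifies which ragged length sequence this is
  int64_t coeff;  // scalar factor applied to every length in the sequence

  // Multiplying by a plain int scales the factor; the sequence is unchanged.
  constexpr ToySingleton operator*(int64_t c) const { return {id, coeff * c}; }

  // Two singletons are equal only if both the sequence and the factor match.
  constexpr bool operator==(const ToySingleton& o) const {
    return id == o.id && coeff == o.coeff;
  }
  constexpr bool operator!=(const ToySingleton& o) const {
    return !(*this == o);
  }
};

// int * singleton: multiplication commutes, so reuse singleton * int.
constexpr ToySingleton operator*(int64_t c, const ToySingleton& s) {
  return s * c;
}

// Hypothetical placeholder for the ragged dim x, with factor 1.
constexpr ToySingleton kRaggedX{/*id=*/0, /*coeff=*/1};

// Outer stride of a contiguous [B, x, inner] tensor: x * inner. The stride is
// derived from an existing symbol instead of being a fresh opaque value.
constexpr ToySingleton contiguous_outer_stride(ToySingleton x, int64_t inner) {
  return x * inner;
}
```

Under this toy model, the input [B, x, D] gets outer stride `contiguous_outer_stride(kRaggedX, D)`, and the output of [B, x, D] @ [D, D'] gets `contiguous_outer_stride(kRaggedX, D')`, i.e. x * D' obtained by an operation on an existing symint (option (4)) rather than an opaque x2.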
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110369
Approved by: https://github.com/ezyang
ghstack dependencies: #110044
2023-10-04 22:56:15 +00:00

#include <c10/core/ConstantSymNodeImpl.h>
namespace c10 {
// This is used to support the case where the lhs is a constant symnode
// and the rhs is a singleton symnode. This situation occurs today when we
// perform a binary op between singleton int and plain int and the
// singleton promotes the int into a constant symnode. If we'd like to
// support more combinations in the future, we may need to implement some
// kind of multiple dispatch.
#define DEFINE_BINARY_OP(OP, ROP)                                        \
  template <typename T>                                                  \
  c10::SymNode ConstantSymNodeImpl<T>::OP(const c10::SymNode& other) {   \
    TORCH_INTERNAL_ASSERT(other->singleton_int().has_value());           \
    return other->ROP(                                                   \
        c10::intrusive_ptr<ConstantSymNodeImpl<T>>::reclaim_copy(this)); \
  }
DEFINE_BINARY_OP(eq, eq)
DEFINE_BINARY_OP(ne, ne)
DEFINE_BINARY_OP(ge, le)
DEFINE_BINARY_OP(le, ge)
DEFINE_BINARY_OP(lt, gt)
DEFINE_BINARY_OP(gt, lt)
DEFINE_BINARY_OP(mul, mul)
#undef DEFINE_BINARY_OP
template class ConstantSymNodeImpl<bool>;
template class ConstantSymNodeImpl<int64_t>;
} // namespace c10
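The forwarding that DEFINE_BINARY_OP sets up can be sketched outside of c10 as follows. This is a hypothetical toy hierarchy (`ToyNode`, `ToyConstant`, and `ToySingleton` are not real c10 types), with `std::shared_ptr` standing in for `c10::intrusive_ptr` and `assert` standing in for `TORCH_INTERNAL_ASSERT`. The constant lhs never computes a result itself: it asserts that the rhs is a singleton and calls the rhs's mirrored op, since `c >= s` asks the same question as `s <= c`. To keep the dispatch observable, comparisons return the name of the method that ran, and multiplication uses a toy rule that scales the singleton's coefficient.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>

// Hypothetical toy hierarchy, not the c10::SymNodeImpl API.
// std::shared_ptr stands in for c10::intrusive_ptr.
struct ToyNode;
using ToyNodePtr = std::shared_ptr<const ToyNode>;

struct ToyNode : std::enable_shared_from_this<ToyNode> {
  virtual ~ToyNode() = default;
  virtual bool is_singleton() const { return false; }
  // Comparisons return the name of the method that actually ran, so the
  // forwarding is observable; a real node would return a symbolic bool.
  virtual std::string ge(const ToyNodePtr& other) const = 0;
  virtual std::string le(const ToyNodePtr& other) const = 0;
  virtual ToyNodePtr mul(const ToyNodePtr& other) const = 0;
};

// A ragged-length placeholder: an id plus a scalar coefficient.
struct ToySingleton : ToyNode {
  int64_t id;
  int64_t coeff;
  ToySingleton(int64_t i, int64_t k) : id(i), coeff(k) {}
  bool is_singleton() const override { return true; }
  std::string ge(const ToyNodePtr&) const override { return "singleton.ge"; }
  std::string le(const ToyNodePtr&) const override { return "singleton.le"; }
  ToyNodePtr mul(const ToyNodePtr& other) const override;
};

// A plain int that was promoted to a node. Like DEFINE_BINARY_OP, each op
// asserts the rhs is a singleton and forwards to the rhs's mirrored op.
struct ToyConstant : ToyNode {
  int64_t value;
  explicit ToyConstant(int64_t v) : value(v) {}
  std::string ge(const ToyNodePtr& other) const override {
    assert(other->is_singleton());
    return other->le(shared_from_this());  // c >= s  <=>  s <= c
  }
  std::string le(const ToyNodePtr& other) const override {
    assert(other->is_singleton());
    return other->ge(shared_from_this());  // c <= s  <=>  s >= c
  }
  ToyNodePtr mul(const ToyNodePtr& other) const override {
    assert(other->is_singleton());
    return other->mul(shared_from_this());  // c * s == s * c
  }
};

// Toy rule: singleton * constant scales the coefficient and keeps the id.
ToyNodePtr ToySingleton::mul(const ToyNodePtr& other) const {
  auto c = std::dynamic_pointer_cast<const ToyConstant>(other);
  assert(c != nullptr);
  return std::make_shared<ToySingleton>(id, coeff * c->value);
}
```

With `c` a `ToyConstant` holding 3 and `s` a `ToySingleton` with coefficient 1, `c->ge(s)` ends up in `ToySingleton::le`, and `c->mul(s)` yields a singleton with the same id and coefficient 3. Only the singleton side knows how to combine the operands, which is why the constant side can stay a thin forwarder until more combinations require real multiple dispatch.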