pytorch/torch/csrc/utils/pybind.cpp
Edward Z. Yang 756a86d52c Support large negative SymInt (#99157)
The strategy is that we will heap allocate a LargeNegativeIntSymNodeImpl whenever we have a large negative int, so that we can keep the old `is_symbolic` test (now called `is_heap_allocated`) on SymInt. Whenever we need to do something with these ints, though, we convert them back into a plain `int64_t` (and then, e.g., wrap it in whatever user specified SymNodeImpl they need.) We cannot wrap directly in the user specified SymNodeImpl as we generally do not know what the "tracing context" is from C++. We expect large negative ints to be rare, so we don't apply optimizations like singleton-ifying INT_MIN.  Here's the order to review:

* c10/core/SymInt.h and cpp
  * `is_symbolic` renamed to `is_heap_allocated` as I needed to audit all use sites: the old `is_symbolic` test would return true for large negative int, but it would be wrong to then try to dispatch on the LargeNegativeIntSymNodeImpl which supports very few operations. In this file, I had to update expect_int,
  * If you pass in a large negative integer, we instead heap allocate it in `promote_to_negative`. The function is written in a funny way to keep compact constructor code for SymInt (the heap allocation happens out of line)
  * clone is now moved out-of-line
  * New method maybe_as_int which will give you a constant int if it is possible, either because it's stored inline or in LargeNegativeIntSymNodeImpl. This is the preferred replacement for previous use of is_symbolic() and then as_int_unchecked().
  * Rename toSymNodeImpl to toSymNode, which is more correct (since it returns a SymNode)
  * Complete rewrite of `normalize_symints.cpp` to use new `maybe_as_int`. Cannot easily use the old code structure, so it's now done using a macro and typing out each case manually (it's actually not that bad.)
  * Reimplementations of all the unary operators by hand to use `maybe_as_int`, relatively simple.
* c10/core/LargeNegativeIntSymNodeImpl.h - Just stores an int64_t value, but it has to be big and negative. Most methods are not implemented, since we will rewrap the large negative int in the real SymNodeImpl subclass before doing operations with it
* The rest of the files are just rewriting code to use `maybe_as_int`. There is a nontrivial comment in c10/core/SymIntArrayRef.h

Very minor test adjustment in c10/test/core/SymInt_test.cpp . Plan to exercise this properly in next PR.

Companion XLA PR: https://github.com/pytorch/xla/pull/4882

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99157
Approved by: https://github.com/albanD
2023-04-15 22:43:51 +00:00

151 lines
4.5 KiB
C++

#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_symnode.h>
namespace pybind11 {
namespace detail {
// Python -> C++ conversion for c10::SymInt.
//
// Two kinds of input are accepted:
//   * an object recognized by torch::is_symint, whose `node` attribute
//     supplies the backing SymNode;
//   * any object that passes THPUtils_checkIndex, which is unpacked into a
//     constant SymInt.
// Returns true after storing the result in `value`; returns false (without
// touching `value`) when the object is neither of the above.
bool type_caster<c10::SymInt>::load(py::handle src, bool) {
  if (!torch::is_symint(src)) {
    // Not a symbolic int: fall back to the plain index-like path.
    auto py_obj = src.ptr();
    if (!THPUtils_checkIndex(py_obj)) {
      return false;
    }
    value = c10::SymInt{THPUtils_unpackIndex(py_obj)};
    return true;
  }

  auto sym_node = src.attr("node");
  if (py::isinstance<c10::SymNodeImpl>(sym_node)) {
    // The node is itself an instance of c10::SymNodeImpl: cast it straight
    // to a SymNode instead of adding a Python wrapper around it.
    value = c10::SymInt(py::cast<c10::SymNode>(sym_node));
  } else {
    // Any other node object gets wrapped in a PythonSymNodeImpl.
    value = c10::SymInt(static_cast<c10::SymNode>(
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(sym_node)));
  }
  return true;
}
// C++ -> Python conversion for c10::SymInt.
//
// Resolution order:
//   1. If maybe_as_int() yields a value, return it as a plain Python int.
//   2. If the backing node is a PythonSymNodeImpl, hand its Python object to
//      the SymInt class (unwrapping rather than re-wrapping).
//   3. Otherwise cast the C++ SymNode to Python and hand that to the SymInt
//      class; python_error is thrown if that cast produces a null object.
py::handle type_caster<c10::SymInt>::cast(
    c10::SymInt si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (auto constant = si.maybe_as_int()) {
    return py::cast(*constant).release();
  }

  auto* python_backed = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
      si.toSymNodeImplUnowned());
  if (!python_backed) {
    // C++-backed node: expose it to Python first, then wrap it.
    auto wrapped_node = py::cast(si.toSymNode());
    if (!wrapped_node) {
      throw python_error();
    }
    return torch::get_symint_class()(wrapped_node).release();
  }

  // Python-backed node: reuse the Python object it already carries.
  return torch::get_symint_class()(python_backed->getPyObj()).release();
}
// Python -> C++ conversion for c10::SymFloat.
//
// An object recognized by torch::is_symfloat has its `node` attribute wrapped
// in a PythonSymNodeImpl; an object passing THPUtils_checkDouble is unpacked
// into a constant SymFloat. Anything else is rejected with `false`.
bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
  if (!torch::is_symfloat(src)) {
    auto py_obj = src.ptr();
    if (!THPUtils_checkDouble(py_obj)) {
      return false;
    }
    value = c10::SymFloat{THPUtils_unpackDouble(py_obj)};
    return true;
  }

  auto node_attr = src.attr("node");
  value = c10::SymFloat(static_cast<c10::SymNode>(
      c10::make_intrusive<torch::impl::PythonSymNodeImpl>(node_attr)));
  return true;
}
// C++ -> Python conversion for c10::SymFloat.
//
// A non-symbolic value becomes a plain Python float. A symbolic value must be
// backed by a PythonSymNodeImpl (internally asserted); its Python object is
// passed to the SymFloat class.
py::handle type_caster<c10::SymFloat>::cast(
    c10::SymFloat si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (!si.is_symbolic()) {
    return py::cast(si.as_float_unchecked()).release();
  }

  // TODO: generalize this to work with C++ backed class
  auto* python_backed =
      dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
  TORCH_INTERNAL_ASSERT(python_backed);
  return torch::get_symfloat_class()(python_backed->getPyObj()).release();
}
// Python -> C++ conversion for c10::SymBool.
//
// An object recognized by torch::is_symbool has its `node` attribute wrapped
// in a PythonSymNodeImpl; an object passing THPUtils_checkBool is unpacked
// into a constant SymBool. Anything else is rejected with `false`.
bool type_caster<c10::SymBool>::load(py::handle src, bool) {
  if (!torch::is_symbool(src)) {
    auto py_obj = src.ptr();
    if (!THPUtils_checkBool(py_obj)) {
      return false;
    }
    value = c10::SymBool{THPUtils_unpackBool(py_obj)};
    return true;
  }

  auto node_attr = src.attr("node");
  value = c10::SymBool(static_cast<c10::SymNode>(
      c10::make_intrusive<torch::impl::PythonSymNodeImpl>(node_attr)));
  return true;
}
// C++ -> Python conversion for c10::SymBool.
//
// A non-symbolic value becomes a plain Python bool. A symbolic value must be
// backed by a PythonSymNodeImpl (internally asserted); its Python object is
// passed to the SymBool class.
py::handle type_caster<c10::SymBool>::cast(
    c10::SymBool si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (!si.is_symbolic()) {
    return py::cast(si.as_bool_unchecked()).release();
  }

  // TODO: generalize this to work with C++ backed class
  auto* python_backed =
      dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
  TORCH_INTERNAL_ASSERT(python_backed);
  return torch::get_symbool_class()(python_backed->getPyObj()).release();
}
// Python -> C++ conversion for c10::Scalar.
//
// Not yet implemented: the body is a single unconditional internal assertion
// (the condition is the literal 0), so every call reports the NYI message
// below instead of converting `src`.
bool type_caster<c10::Scalar>::load(py::handle src, bool) {
TORCH_INTERNAL_ASSERT(
0, "pybind11 loading for c10::Scalar NYI (file a bug if you need it)");
}
// C++ -> Python conversion for c10::Scalar.
//
// Dispatches on the scalar's category (integral excluding bool, floating
// point, boolean, complex). For the first three, a symbolic scalar is cast
// via the matching Sym* type and a concrete one via its plain C++ value;
// complex scalars are always cast from toComplexDouble(). Any other category
// trips an internal assertion.
py::handle type_caster<c10::Scalar>::cast(
    const c10::Scalar& scalar,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (scalar.isIntegral(/*includeBool*/ false)) {
    // Care is needed here: concrete integers must not be routed through
    // SymInt unconditionally, since integer data coming out of Tensors can
    // easily be MIN_INT or very negative, which conflicts with the
    // allocated range.
    return scalar.isSymbolic() ? py::cast(scalar.toSymInt()).release()
                               : py::cast(scalar.toLong()).release();
  }

  if (scalar.isFloatingPoint()) {
    // The symbolic check isn't strictly necessary; it is kept for symmetry
    // with the integral case.
    return scalar.isSymbolic() ? py::cast(scalar.toSymFloat()).release()
                               : py::cast(scalar.toDouble()).release();
  }

  if (scalar.isBoolean()) {
    return scalar.isSymbolic() ? py::cast(scalar.toSymBool()).release()
                               : py::cast(scalar.toBool()).release();
  }

  if (scalar.isComplex()) {
    return py::cast(scalar.toComplexDouble()).release();
  }

  TORCH_INTERNAL_ASSERT(0, "unrecognized scalar type ", scalar.type());
}
} // namespace detail
} // namespace pybind11