Fix pybind11 problems with c10::SymInt unregistered (#88011)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88011
Approved by: https://github.com/weiwangmeta, https://github.com/albanD
This commit is contained in:
Edward Z. Yang 2022-10-28 17:20:10 -04:00 committed by PyTorch MergeBot
parent e667c00656
commit d3c01c722d
4 changed files with 94 additions and 76 deletions

View File

@ -960,6 +960,7 @@ libtorch_python_core_sources = [
"torch/csrc/utils/python_arg_parser.cpp",
"torch/csrc/utils/python_dispatch.cpp",
"torch/csrc/utils/python_symnode.cpp",
"torch/csrc/utils/pybind.cpp",
"torch/csrc/utils/structseq.cpp",
"torch/csrc/utils/tensor_apply.cpp",
"torch/csrc/utils/tensor_dtypes.cpp",

View File

@ -0,0 +1,69 @@
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_symnode.h>
namespace pybind11 {
namespace detail {
bool type_caster<c10::SymInt>::load(py::handle src, bool) {
if (torch::is_symint(src)) {
value = c10::SymInt(static_cast<c10::SymNode>(
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
return true;
}
auto raw_obj = src.ptr();
if (THPUtils_checkIndex(raw_obj)) {
value = c10::SymInt{THPUtils_unpackIndex(raw_obj)};
return true;
}
return false;
}
py::handle type_caster<c10::SymInt>::cast(
c10::SymInt si,
return_value_policy /* policy */,
handle /* parent */) {
if (si.is_symbolic()) {
// TODO: generalize this to work with C++ backed class
auto* py_node =
dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
TORCH_INTERNAL_ASSERT(py_node);
return torch::get_symint_class()(py_node->getPyObj()).release();
} else {
return py::cast(si.as_int_unchecked()).release();
}
}
bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
if (torch::is_symfloat(src)) {
value = c10::SymFloat(static_cast<c10::SymNode>(
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
return true;
}
auto raw_obj = src.ptr();
if (THPUtils_checkDouble(raw_obj)) {
value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)};
return true;
}
return false;
}
py::handle type_caster<c10::SymFloat>::cast(
c10::SymFloat si,
return_value_policy /* policy */,
handle /* parent */) {
if (si.is_symbolic()) {
// TODO: generalize this to work with C++ backed class
auto* py_node =
dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
TORCH_INTERNAL_ASSERT(py_node);
return torch::get_symfloat_class()(py_node->getPyObj()).release();
} else {
return py::cast(si.as_float_unchecked()).release();
}
}
} // namespace detail
} // namespace pybind11

View File

@ -187,6 +187,30 @@ struct type_caster<c10::DispatchKey>
}
};
template <>
struct type_caster<c10::SymInt> {
public:
PYBIND11_TYPE_CASTER(c10::SymInt, _("SymInt"));
bool load(py::handle src, bool);
static py::handle cast(
c10::SymInt si,
return_value_policy /* policy */,
handle /* parent */);
};
template <>
struct type_caster<c10::SymFloat> {
public:
PYBIND11_TYPE_CASTER(c10::SymFloat, _("SymFloat"));
bool load(py::handle src, bool);
static py::handle cast(
c10::SymFloat si,
return_value_policy /* policy */,
handle /* parent */);
};
// Pybind11 bindings for our optional and variant types.
// http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
template <typename T>

View File

@ -79,82 +79,6 @@
#include <string>
#include <vector>
namespace pybind11 {
namespace detail {
template <>
struct type_caster<c10::SymInt> {
public:
PYBIND11_TYPE_CASTER(c10::SymInt, _("SymInt"));
bool load(py::handle src, bool) {
if (torch::is_symint(src)) {
value = c10::SymInt(static_cast<c10::SymNode>(
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(
src.attr("node"))));
return true;
}
auto raw_obj = src.ptr();
if (THPUtils_checkIndex(raw_obj)) {
value = c10::SymInt{THPUtils_unpackIndex(raw_obj)};
return true;
}
return false;
}
static py::handle cast(
c10::SymInt si,
return_value_policy /* policy */,
handle /* parent */) {
if (si.is_symbolic()) {
// TODO: generalize this to work with C++ backed class
auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
si.toSymNodeImpl().get());
TORCH_INTERNAL_ASSERT(py_node);
return torch::get_symint_class()(py_node->getPyObj()).release();
} else {
return py::cast(si.as_int_unchecked()).release();
}
}
};
template <>
struct type_caster<c10::SymFloat> {
public:
PYBIND11_TYPE_CASTER(c10::SymFloat, _("SymFloat"));
bool load(py::handle src, bool) {
if (torch::is_symfloat(src)) {
value = c10::SymFloat(static_cast<c10::SymNode>(
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(
src.attr("node"))));
return true;
}
auto raw_obj = src.ptr();
if (THPUtils_checkDouble(raw_obj)) {
value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)};
return true;
}
return false;
}
static py::handle cast(
c10::SymFloat si,
return_value_policy /* policy */,
handle /* parent */) {
if (si.is_symbolic()) {
// TODO: generalize this to work with C++ backed class
auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
si.toSymNodeImpl().get());
TORCH_INTERNAL_ASSERT(py_node);
return torch::get_symfloat_class()(py_node->getPyObj()).release();
} else {
return py::cast(si.as_float_unchecked()).release();
}
}
};
} // namespace detail
} // namespace pybind11
inline bool THPUtils_checkScalar(PyObject* obj) {
#ifdef USE_NUMPY
if (torch::utils::is_numpy_scalar(obj)) {