mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix pybind11 problems with c10::SymInt unregistered (#88011)
Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/88011 Approved by: https://github.com/weiwangmeta, https://github.com/albanD
This commit is contained in:
parent
e667c00656
commit
d3c01c722d
|
|
@ -960,6 +960,7 @@ libtorch_python_core_sources = [
|
|||
"torch/csrc/utils/python_arg_parser.cpp",
|
||||
"torch/csrc/utils/python_dispatch.cpp",
|
||||
"torch/csrc/utils/python_symnode.cpp",
|
||||
"torch/csrc/utils/pybind.cpp",
|
||||
"torch/csrc/utils/structseq.cpp",
|
||||
"torch/csrc/utils/tensor_apply.cpp",
|
||||
"torch/csrc/utils/tensor_dtypes.cpp",
|
||||
|
|
|
|||
69
torch/csrc/utils/pybind.cpp
Normal file
69
torch/csrc/utils/pybind.cpp
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
#include <torch/csrc/utils/python_symnode.h>
|
||||
|
||||
namespace pybind11 {
|
||||
namespace detail {
|
||||
|
||||
bool type_caster<c10::SymInt>::load(py::handle src, bool) {
|
||||
if (torch::is_symint(src)) {
|
||||
value = c10::SymInt(static_cast<c10::SymNode>(
|
||||
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
|
||||
return true;
|
||||
}
|
||||
|
||||
auto raw_obj = src.ptr();
|
||||
if (THPUtils_checkIndex(raw_obj)) {
|
||||
value = c10::SymInt{THPUtils_unpackIndex(raw_obj)};
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
py::handle type_caster<c10::SymInt>::cast(
|
||||
c10::SymInt si,
|
||||
return_value_policy /* policy */,
|
||||
handle /* parent */) {
|
||||
if (si.is_symbolic()) {
|
||||
// TODO: generalize this to work with C++ backed class
|
||||
auto* py_node =
|
||||
dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
|
||||
TORCH_INTERNAL_ASSERT(py_node);
|
||||
return torch::get_symint_class()(py_node->getPyObj()).release();
|
||||
} else {
|
||||
return py::cast(si.as_int_unchecked()).release();
|
||||
}
|
||||
}
|
||||
|
||||
bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
|
||||
if (torch::is_symfloat(src)) {
|
||||
value = c10::SymFloat(static_cast<c10::SymNode>(
|
||||
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
|
||||
return true;
|
||||
}
|
||||
|
||||
auto raw_obj = src.ptr();
|
||||
if (THPUtils_checkDouble(raw_obj)) {
|
||||
value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)};
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
py::handle type_caster<c10::SymFloat>::cast(
|
||||
c10::SymFloat si,
|
||||
return_value_policy /* policy */,
|
||||
handle /* parent */) {
|
||||
if (si.is_symbolic()) {
|
||||
// TODO: generalize this to work with C++ backed class
|
||||
auto* py_node =
|
||||
dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
|
||||
TORCH_INTERNAL_ASSERT(py_node);
|
||||
return torch::get_symfloat_class()(py_node->getPyObj()).release();
|
||||
} else {
|
||||
return py::cast(si.as_float_unchecked()).release();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
|
|
@ -187,6 +187,30 @@ struct type_caster<c10::DispatchKey>
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_caster<c10::SymInt> {
|
||||
public:
|
||||
PYBIND11_TYPE_CASTER(c10::SymInt, _("SymInt"));
|
||||
bool load(py::handle src, bool);
|
||||
|
||||
static py::handle cast(
|
||||
c10::SymInt si,
|
||||
return_value_policy /* policy */,
|
||||
handle /* parent */);
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_caster<c10::SymFloat> {
|
||||
public:
|
||||
PYBIND11_TYPE_CASTER(c10::SymFloat, _("SymFloat"));
|
||||
bool load(py::handle src, bool);
|
||||
|
||||
static py::handle cast(
|
||||
c10::SymFloat si,
|
||||
return_value_policy /* policy */,
|
||||
handle /* parent */);
|
||||
};
|
||||
|
||||
// Pybind11 bindings for our optional and variant types.
|
||||
// http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -79,82 +79,6 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace pybind11 {
|
||||
namespace detail {
|
||||
template <>
|
||||
struct type_caster<c10::SymInt> {
|
||||
public:
|
||||
PYBIND11_TYPE_CASTER(c10::SymInt, _("SymInt"));
|
||||
bool load(py::handle src, bool) {
|
||||
if (torch::is_symint(src)) {
|
||||
value = c10::SymInt(static_cast<c10::SymNode>(
|
||||
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(
|
||||
src.attr("node"))));
|
||||
return true;
|
||||
}
|
||||
|
||||
auto raw_obj = src.ptr();
|
||||
if (THPUtils_checkIndex(raw_obj)) {
|
||||
value = c10::SymInt{THPUtils_unpackIndex(raw_obj)};
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static py::handle cast(
|
||||
c10::SymInt si,
|
||||
return_value_policy /* policy */,
|
||||
handle /* parent */) {
|
||||
if (si.is_symbolic()) {
|
||||
// TODO: generalize this to work with C++ backed class
|
||||
auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
|
||||
si.toSymNodeImpl().get());
|
||||
TORCH_INTERNAL_ASSERT(py_node);
|
||||
return torch::get_symint_class()(py_node->getPyObj()).release();
|
||||
} else {
|
||||
return py::cast(si.as_int_unchecked()).release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_caster<c10::SymFloat> {
|
||||
public:
|
||||
PYBIND11_TYPE_CASTER(c10::SymFloat, _("SymFloat"));
|
||||
bool load(py::handle src, bool) {
|
||||
if (torch::is_symfloat(src)) {
|
||||
value = c10::SymFloat(static_cast<c10::SymNode>(
|
||||
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(
|
||||
src.attr("node"))));
|
||||
return true;
|
||||
}
|
||||
|
||||
auto raw_obj = src.ptr();
|
||||
if (THPUtils_checkDouble(raw_obj)) {
|
||||
value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)};
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static py::handle cast(
|
||||
c10::SymFloat si,
|
||||
return_value_policy /* policy */,
|
||||
handle /* parent */) {
|
||||
if (si.is_symbolic()) {
|
||||
// TODO: generalize this to work with C++ backed class
|
||||
auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
|
||||
si.toSymNodeImpl().get());
|
||||
TORCH_INTERNAL_ASSERT(py_node);
|
||||
return torch::get_symfloat_class()(py_node->getPyObj()).release();
|
||||
} else {
|
||||
return py::cast(si.as_float_unchecked()).release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
|
||||
inline bool THPUtils_checkScalar(PyObject* obj) {
|
||||
#ifdef USE_NUMPY
|
||||
if (torch::utils::is_numpy_scalar(obj)) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user