mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/88011 Approved by: https://github.com/weiwangmeta, https://github.com/albanD
70 lines
2.0 KiB
C++
70 lines
2.0 KiB
C++
#include <torch/csrc/utils/pybind.h>
|
|
#include <torch/csrc/utils/python_arg_parser.h>
|
|
#include <torch/csrc/utils/python_symnode.h>
|
|
|
|
namespace pybind11 {
|
|
namespace detail {
|
|
|
|
// Converts a Python object into a c10::SymInt, storing the result in `value`.
//
// Two kinds of input are accepted:
//   * objects for which torch::is_symint() is true: their `node` attribute is
//     wrapped in a torch::impl::PythonSymNodeImpl, producing a symbolic SymInt;
//   * objects accepted by THPUtils_checkIndex (presumably Python ints and
//     objects implementing __index__ -- confirm in python_numbers.h), which
//     become a concrete SymInt holding THPUtils_unpackIndex(obj).
// Anything else is rejected by returning false, which tells pybind11 this
// caster could not perform the conversion.
bool type_caster<c10::SymInt>::load(py::handle src, bool) {
  if (torch::is_symint(src)) {
    // Bind through a named SymNode so the derived-to-base intrusive_ptr
    // conversion happens before the SymInt constructor is chosen.
    c10::SymNode symbolic_node =
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"));
    value = c10::SymInt(std::move(symbolic_node));
    return true;
  }

  PyObject* const obj = src.ptr();
  if (!THPUtils_checkIndex(obj)) {
    return false;
  }
  value = c10::SymInt{THPUtils_unpackIndex(obj)};
  return true;
}
|
|
|
|
// Converts a c10::SymInt into a new Python object and returns an owned handle
// (the reference is released to the caller, as pybind11 casters require).
//
// * Concrete (non-symbolic) values become an ordinary Python int via
//   py::cast(si.as_int_unchecked()).
// * Symbolic values must be backed by a torch::impl::PythonSymNodeImpl; its
//   underlying Python node (getPyObj()) is passed to the class returned by
//   torch::get_symint_class() to build the Python-side SymInt.
//
// The policy and parent arguments are unused.
py::handle type_caster<c10::SymInt>::cast(
    c10::SymInt si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (!si.is_symbolic()) {
    return py::cast(si.as_int_unchecked()).release();
  }

  // TODO: generalize this to work with C++ backed class
  // toSymNodeImpl() returns a temporary owning pointer, but `si` keeps its own
  // reference to the node, so the raw pointer stays valid for this function.
  auto* py_node =
      dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
  // Only Python-backed nodes can be converted here. Fail with an explanatory
  // message instead of a bare internal-assert on this known-unsupported path.
  TORCH_INTERNAL_ASSERT(
      py_node,
      "Cannot convert a symbolic SymInt to Python: its SymNode is not a "
      "torch::impl::PythonSymNodeImpl (C++-backed SymNodes are not supported yet)");
  return torch::get_symint_class()(py_node->getPyObj()).release();
}
|
|
|
|
// Converts a Python object into a c10::SymFloat, storing the result in `value`.
//
// Two kinds of input are accepted:
//   * objects for which torch::is_symfloat() is true: their `node` attribute is
//     wrapped in a torch::impl::PythonSymNodeImpl, producing a symbolic
//     SymFloat;
//   * objects accepted by THPUtils_checkDouble (presumably Python floats and
//     ints -- confirm in python_numbers.h), which become a concrete SymFloat
//     holding THPUtils_unpackDouble(obj).
// Anything else is rejected by returning false, which tells pybind11 this
// caster could not perform the conversion.
bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
  if (torch::is_symfloat(src)) {
    // Bind through a named SymNode so the derived-to-base intrusive_ptr
    // conversion happens before the SymFloat constructor is chosen.
    c10::SymNode symbolic_node =
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"));
    value = c10::SymFloat(std::move(symbolic_node));
    return true;
  }

  PyObject* const obj = src.ptr();
  if (!THPUtils_checkDouble(obj)) {
    return false;
  }
  value = c10::SymFloat{THPUtils_unpackDouble(obj)};
  return true;
}
|
|
|
|
// Converts a c10::SymFloat into a new Python object and returns an owned
// handle (the reference is released to the caller, as pybind11 casters
// require).
//
// * Concrete (non-symbolic) values become an ordinary Python float via
//   py::cast(si.as_float_unchecked()).
// * Symbolic values must be backed by a torch::impl::PythonSymNodeImpl; its
//   underlying Python node (getPyObj()) is passed to the class returned by
//   torch::get_symfloat_class() to build the Python-side SymFloat.
//
// The policy and parent arguments are unused.
py::handle type_caster<c10::SymFloat>::cast(
    c10::SymFloat si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (!si.is_symbolic()) {
    return py::cast(si.as_float_unchecked()).release();
  }

  // TODO: generalize this to work with C++ backed class
  // toSymNodeImpl() returns a temporary owning pointer, but `si` keeps its own
  // reference to the node, so the raw pointer stays valid for this function.
  auto* py_node =
      dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
  // Only Python-backed nodes can be converted here. Fail with an explanatory
  // message instead of a bare internal-assert on this known-unsupported path.
  TORCH_INTERNAL_ASSERT(
      py_node,
      "Cannot convert a symbolic SymFloat to Python: its SymNode is not a "
      "torch::impl::PythonSymNodeImpl (C++-backed SymNodes are not supported yet)");
  return torch::get_symfloat_class()(py_node->getPyObj()).release();
}
|
|
|
|
} // namespace detail
|
|
} // namespace pybind11
|