#include #include #include namespace pybind11 { namespace detail { bool type_caster::load(py::handle src, bool) { if (torch::is_symint(src)) { value = c10::SymInt(static_cast( c10::make_intrusive(src.attr("node")))); return true; } auto raw_obj = src.ptr(); if (THPUtils_checkIndex(raw_obj)) { value = c10::SymInt{THPUtils_unpackIndex(raw_obj)}; return true; } return false; } py::handle type_caster::cast( c10::SymInt si, return_value_policy /* policy */, handle /* parent */) { if (si.is_symbolic()) { // TODO: generalize this to work with C++ backed class auto* py_node = dynamic_cast(si.toSymNodeImpl().get()); TORCH_INTERNAL_ASSERT(py_node); return torch::get_symint_class()(py_node->getPyObj()).release(); } else { return py::cast(si.as_int_unchecked()).release(); } } bool type_caster::load(py::handle src, bool) { if (torch::is_symfloat(src)) { value = c10::SymFloat(static_cast( c10::make_intrusive(src.attr("node")))); return true; } auto raw_obj = src.ptr(); if (THPUtils_checkDouble(raw_obj)) { value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)}; return true; } return false; } py::handle type_caster::cast( c10::SymFloat si, return_value_policy /* policy */, handle /* parent */) { if (si.is_symbolic()) { // TODO: generalize this to work with C++ backed class auto* py_node = dynamic_cast(si.toSymNodeImpl().get()); TORCH_INTERNAL_ASSERT(py_node); return torch::get_symfloat_class()(py_node->getPyObj()).release(); } else { return py::cast(si.as_float_unchecked()).release(); } } } // namespace detail } // namespace pybind11