#pragma once #include "torch/csrc/python_headers.h" #include #include #include #include "torch/csrc/DynamicTypes.h" #include "torch/csrc/autograd/python_variable.h" #include "torch/csrc/utils/python_tuples.h" #include "torch/csrc/utils/python_numbers.h" #include namespace py = pybind11; namespace pybind11 { namespace detail { // torch.autograd.Variable <-> at::Tensor conversions (without unwrapping) template <> struct type_caster { public: PYBIND11_TYPE_CASTER(at::Tensor, _("at::Tensor")); bool load(handle src, bool) { PyObject* obj = src.ptr(); if (THPVariable_Check(obj)) { value = reinterpret_cast(obj)->cdata; return true; } return false; } static handle cast(at::Tensor src, return_value_policy /* policy */, handle /* parent */) { if (!src.is_variable()) { throw std::runtime_error( "Expected tensor's dynamic type to be Variable, not Tensor"); } return handle(THPVariable_Wrap(torch::autograd::Variable(src))); } }; template<> struct type_caster { public: PYBIND11_TYPE_CASTER(torch::autograd::Variable, _("torch::autograd::Variable")); bool load(handle src, bool) { PyObject *source = src.ptr(); if (THPVariable_Check(source)) { value = ((THPVariable*)source)->cdata; return true; } else { return false; } } static handle cast(torch::autograd::Variable src, return_value_policy /* policy */, handle /* parent */) { return handle(THPVariable_Wrap(src)); } }; template<> struct type_caster { public: PYBIND11_TYPE_CASTER(at::IntList, _("at::IntList")); bool load(handle src, bool) { PyObject *source = src.ptr(); auto tuple = PyTuple_Check(source); if (tuple || PyList_Check(source)) { auto size = tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source); v_value.resize(size); for (int idx = 0; idx < size; idx++) { PyObject* obj = tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx); if (THPVariable_Check(obj)) { v_value[idx] = THPVariable_Unpack(obj).item(); } else if (PyLong_Check(obj)) { // use THPUtils_unpackLong after it is safe to include python_numbers.h v_value[idx] = THPUtils_unpackLong(obj); } else { return false; } } value = v_value; return true; } return false; } static handle cast(at::IntList src, return_value_policy /* policy */, handle /* parent */) { return handle(THPUtils_packInt64Array(src.size(), src.data())); } private: std::vector v_value; }; // http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers template struct type_caster> : optional_caster> {}; }} // namespace pybind11::detail