#pragma once #include "torch/csrc/python_headers.h" #include #include #include #include "torch/csrc/DynamicTypes.h" #include "torch/csrc/autograd/python_variable.h" #include namespace py = pybind11; namespace pybind11 { namespace detail { // torch.autograd.Variable <-> at::Tensor conversions (without unwrapping) template <> struct type_caster { public: PYBIND11_TYPE_CASTER(at::Tensor, _("at::Tensor")); bool load(handle src, bool) { PyObject* obj = src.ptr(); if (THPVariable_Check(obj)) { value = reinterpret_cast(obj)->cdata; return true; } return false; } static handle cast(at::Tensor src, return_value_policy /* policy */, handle /* parent */) { if (!torch::autograd::is_variable(src)) { throw std::runtime_error( "Expected tensor's dynamic type to be Variable, not Tensor"); } return handle(THPVariable_Wrap(torch::autograd::Variable(src))); } }; template<> struct type_caster { public: PYBIND11_TYPE_CASTER(torch::autograd::Variable, _("torch::autograd::Variable")); bool load(handle src, bool) { PyObject *source = src.ptr(); if (THPVariable_Check(source)) { value = ((THPVariable*)source)->cdata; return true; } else { return false; } } static handle cast(torch::autograd::Variable src, return_value_policy /* policy */, handle /* parent */) { return handle(THPVariable_Wrap(src)); } }; // http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers template struct type_caster> : optional_caster> {}; }} // namespace pybind11::detail