pytorch/torch/csrc/utils/pybind.h
2017-09-05 17:48:55 -04:00

32 lines
702 B
C++

#pragma once
#include <ATen/ATen.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "torch/csrc/DynamicTypes.h"
namespace py = pybind11;
namespace pybind11 { namespace detail {
// handle Tensor <-> at::Tensor conversions
template <> struct type_caster<at::Tensor> {
public:
PYBIND11_TYPE_CASTER(at::Tensor, _("at::Tensor"));
bool load(handle src, bool) {
try {
value = torch::createTensor(src.ptr());
} catch (std::exception& e) {
return false;
}
return true;
}
static handle cast(at::Tensor src, return_value_policy /* policy */, handle /* parent */) {
return handle(torch::createPyObject(src));
}
};
}} // namespace pybind11::detail