pytorch/torch/csrc/jit/pybind.h
Zachary DeVito 2da43bf6f1 Make Symbol a true struct (#4717)
Previously, Symbol was just a uint32_t and we converted with symbolToString
and stringToSymbol. Now Symbol is a struct with a toString method, and
constructors from either BuiltinSymbols enums (e.g. kParam) or strings.

Symbol is convertible to a uint32_t to ensure it can still be used in
switch statement BuiltinSymbol case branches.
2018-01-17 21:49:28 -08:00

98 lines
3.0 KiB
C++

#pragma once
#include <Python.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/THP.h"
namespace py = pybind11;
namespace pybind11 { namespace detail {
template<> struct type_caster<torch::jit::tracer::TraceInput> {
public:
PYBIND11_TYPE_CASTER(torch::jit::tracer::TraceInput, _("torch::jit::tracer::TraceInput"));
bool load(handle src, bool) {
PyObject *source = src.ptr();
if (THPVariable_Check(source)) {
value = torch::jit::tracer::TraceInput(((THPVariable*)source)->cdata);
return true;
} else if (THPModule_isTensor(source)) {
value = torch::jit::tracer::TraceInput(torch::createTensor(source));
return true;
} else {
return false;
}
}
static handle cast(torch::jit::tracer::TraceInput src, return_value_policy /* policy */, handle /* parent */) {
if (src.variable.defined()) {
return handle(THPVariable_Wrap(src.variable));
} else {
return handle(torch::createPyObject(src.buffer));
}
}
};
template<> struct type_caster<torch::autograd::Variable> {
public:
PYBIND11_TYPE_CASTER(torch::autograd::Variable, _("torch::autograd::Variable"));
bool load(handle src, bool) {
PyObject *source = src.ptr();
if (THPVariable_Check(source)) {
value = ((THPVariable*)source)->cdata;
return true;
} else {
return false;
}
}
static handle cast(torch::autograd::Variable src, return_value_policy /* policy */, handle /* parent */) {
return handle(THPVariable_Wrap(src));
}
};
// pybind11 caster for torch::jit::Symbol, which is exposed to Python as a
// plain string.
//
// load(): converts the Python object to std::string and constructs a Symbol
//         from it; any std::exception raised along the way makes the
//         conversion fail (returns false) instead of propagating.
// cast(): returns a Python string built from Symbol::toString().
template <> struct type_caster<torch::jit::Symbol> {
public:
  PYBIND11_TYPE_CASTER(torch::jit::Symbol, _("Symbol"));

  bool load(handle src, bool /* convert */) {
    try {
      value = torch::jit::Symbol(py::cast<std::string>(src));
      return true;
    } catch (const std::exception&) {
      // Not convertible to a string (or not a valid symbol): reject quietly.
      return false;
    }
  }

  static handle cast(torch::jit::Symbol src, return_value_policy /* policy */, handle /* parent */) {
    const std::string text(src.toString());
    return py::cast(text, return_value_policy::copy).release();
  }
};
// pybind11 caster for torch::jit::AttributeKind.
//
// This is a one-way conversion: load() always fails, so an AttributeKind can
// never be constructed from a Python argument. cast() returns the kind's
// textual name (torch::jit::toString) as a Python string.
template <> struct type_caster<torch::jit::AttributeKind> {
public:
  PYBIND11_TYPE_CASTER(torch::jit::AttributeKind, _("AttributeKind"));

  // Python -> C++ conversion is intentionally unsupported.
  bool load(handle /* src */, bool /* convert */) {
    return false;
  }

  static handle cast(torch::jit::AttributeKind src, return_value_policy /* policy */, handle /* parent */) {
    const std::string kind_name(torch::jit::toString(src));
    return py::cast(kind_name, return_value_policy::copy).release();
  }
};
// See https://github.com/pybind/pybind11/issues/637
//
// Caster for std::vector<torch::jit::Node *>. It reuses pybind11's generic
// list_caster but always converts the elements with
// return_value_policy::reference, regardless of the policy requested by the
// caller, so the resulting Python objects refer to the existing Node
// instances rather than copying or taking ownership of them.
using ListCasterBase = pybind11::detail::list_caster<std::vector<torch::jit::Node *>, torch::jit::Node *>;
template<> struct type_caster<std::vector<torch::jit::Node *>> : ListCasterBase {
  // Converts a vector of Node pointers to a Python list. The incoming policy
  // is deliberately ignored and replaced with `reference`.
  static handle cast(const std::vector<torch::jit::Node *> &src, return_value_policy, handle parent) {
    return ListCasterBase::cast(src, return_value_policy::reference, parent);
  }

  // Pointer overload. A null pointer is converted to Python `None` (the same
  // convention used by the pointer overload that PYBIND11_TYPE_CASTER
  // generates) instead of being dereferenced.
  static handle cast(const std::vector<torch::jit::Node *> *src, return_value_policy pol, handle parent) {
    if (!src) {
      return pybind11::none().release();
    }
    return cast(*src, pol, parent);
  }
};
}} // namespace pybind11::detail