mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary:
I have some test code in there as well, along with a script "test_libtorch" to run it. You'll need to modify `test_libtorch` to point to where you have `pytorch` built. I currently require that `pybind11` is included as a subdirectory of the test, but added it to the `.gitignore` to make this reviewable.
Currently, something like this works:
```cpp
struct Foo {
int x, y;
Foo(): x(2), y(5){}
Foo(int x_, int y_) : x(x_), y(y_) {}
void display() {
cout<<"x: "<<x<<' '<<"y: "<<y<<endl;
}
int64_t add(int64_t z) {
return (x+y)*z;
}
};
static auto test = torch::jit::class_<Foo>("Foo")
.def(torch::jit::init<int64_t, int64_t>())
.def("display", &Foo::display)
.def("add", &Foo::add)
.def("combine", &Foo::combine);
```
with
```py
@torch.jit.script
def f(x):
val = torch._C.Foo(5, 3)
val.display()
print(val.add(3))
```
results in
```
x: 5 y: 3
24
```
Current issues:
- [x] The Python class created by TorchScript doesn't interact properly with the surrounding code.
```
@torch.jit.script
def f(x):
val = torch._C.Foo(5, 3)
return val
```
- [x] Doesn't properly take in non-pointer classes. Can't define this function signature in cpp (We don't want to support this I believe).
```cpp
void combine(Foo x) {
```
- [x] Has some issues with memory for blobs when constructing multiple objects (fix constant propagation pass to not treat capsules as the same object).
```py
@torch.jit.script
def f(x):
val = torch._C.Foo(5, 3)
val2 = torch._C.Foo(100, 0)
val.display()
print(val.add(3))
```
- [ ] Can't define multiple constructors (need to define overload string. Currently not possible since we don't support overloaded methods).
- [x] `init` is a little bit different syntax than `pybind`. `.init<...>()` instead of `.def(py::init<>())`
- [x] I couldn't figure out how to add some files into the build so they'd be copied to the `include/` directories, so I symlinked them manually.
- [ ] Currently, the conversion from Python into Torchscript doesn't work.
- [ ] Torchbind also currently requires Python/Pybind dependency. Fixing this would probably involve some kind of macro to bind into Python when possible.
- [ ] We pass back into Python by value, currently. There's no way of passing by reference.
- [x] Currently can only register one method with the same type signature. This is because we create a `static auto opRegistry`, and the function is templated on the type signature.
Somewhat blocked on https://github.com/pytorch/pytorch/pull/21177. We currently use some structures that will be refactored by his PR (namely `return_type_to_ivalue` and `ivalue_to_arg_type`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21098
Differential Revision: D16634872
Pulled By: Chillee
fbshipit-source-id: 1408bb89ea649c27d560df59e2cf9920467fe1de
208 lines
7.0 KiB
C++
208 lines
7.0 KiB
C++
|
|
#pragma once
|
|
|
|
#include <ATen/core/function_schema.h>
|
|
#include <ATen/core/ivalue.h>
|
|
#include <ATen/core/jit_type.h>
|
|
#include <ATen/core/op_registration/op_registration.h>
|
|
#include <ATen/core/stack.h>
|
|
#include <c10/util/C++17.h>
|
|
#include <c10/util/Metaprogramming.h>
|
|
#include <c10/util/TypeList.h>
|
|
#include <pybind11/pybind11.h>
|
|
#include <torch/csrc/jit/operator.h>
|
|
#include <torch/csrc/jit/pybind_utils.h>
|
|
#include <torch/csrc/jit/script/compilation_unit.h>
|
|
#include <torch/csrc/jit/tracer.h>
|
|
#include <torch/csrc/utils/variadic.h>
|
|
#include <iostream>
|
|
#include <sstream>
|
|
|
|
|
|
namespace py = pybind11;
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Holds the c10::RegisterOperators objects that class_::defineMethod creates,
// one per registered custom-class method (including `__init__`). Presumably
// these are kept here so the operator registrations stay alive for the
// lifetime of the program — TODO confirm against RegisterOperators' dtor.
// NOTE(review): this is a `static` variable defined in a header, so every
// translation unit that includes this header gets its own, separate vector.
static std::vector<c10::RegisterOperators> registeredOps;
|
|
|
|
namespace detail {
// `types<R, Args...>` is a tag describing a callable's return type `R` and its
// argument types `Args...`. `hasRet` is false only when `R` is `void`;
// class_::def_ forwards it to defineMethod to decide whether the generated
// TorchScript graph produces an output. `type` is a self-alias so that
// `args<Sig>::type` names the `types<...>` base directly.
template <class R, class...>
struct types {
  constexpr static bool hasRet = true;
  using type = types;
};

// Specialization for callables that return nothing.
template <class... Args>
struct types<void, Args...> {
  constexpr static bool hasRet = false;
  using type = types;
};

// Maps a pointer-to-member-function type onto `types<R, Args...>`, dropping
// the class type (the object is supplied separately as `self`).
template <class Sig>
struct args;

// Non-const member functions, e.g. `int64_t (Foo::*)(int64_t)`.
template <class R, class CurClass, class... Args>
struct args<R (CurClass::*)(Args...)> : types<R, Args...> {};

// Const-qualified member functions, e.g. `int64_t (Foo::*)() const`. Without
// this specialization, binding a const method via class_::def fails to compile.
template <class R, class CurClass, class... Args>
struct args<R (CurClass::*)(Args...) const> : types<R, Args...> {};

template <class Sig>
using args_t = typename args<Sig>::type;
} // namespace detail
|
|
template <class... Types>
|
|
detail::types<void, Types...> init() { return detail::types<void, Types...>{}; }
|
|
|
|
// To bind custom classes into Torchscript, use an API very similar to Pybind's.
|
|
// Currently exposes one class `torch::jit::class_<T>` and 2 methods.
|
|
// - Constructing `torch::jit::class_<Foo>` registers `Foo` in Python and
|
|
// Torchscript, and puts it under `torch.classes.Foo` in Python.
|
|
// - torch::jit::class_<Foo>.def("method1", &Foo::method1) does some template
|
|
// metaprogramming to introspect the function types and register the operator
|
|
// for use in Torchscript.
|
|
// - torch::jit::class_<Foo>.def(torch::jit::init<int64_t, int64_t>()) registers
|
|
// the Foo(int, int) constructor.
|
|
// see test/custom_operator/classes.cpp and
|
|
// test/custom_operator/test_custom_classes.py for example usages
|
|
|
|
template <class CurClass>
|
|
class class_ {
|
|
std::string className;
|
|
std::string qualClassName;
|
|
c10::optional<py::class_<CurClass>> pyClass = c10::nullopt;
|
|
std::shared_ptr<script::CompilationUnit> classCu = nullptr;
|
|
ClassTypePtr classTypePtr;
|
|
|
|
const std::string parentModule = "classes";
|
|
const std::string topModule = "__torch__.torch";
|
|
|
|
public:
|
|
class_(string className_) : className(std::move(className_)) {
|
|
// Currently we register everything as a python class just for convenience.
|
|
// We'll want to remove this at some point to get rid of the python
|
|
// dependency. It would require significant changes to class registration,
|
|
// (I think)?
|
|
qualClassName = topModule + "." + parentModule + "." + className;
|
|
|
|
auto obj = py::module::import("torch").attr(parentModule.c_str());
|
|
pyClass = py::class_<CurClass>(obj, className.c_str());
|
|
pyClass->attr("qualified_name") = py::str(qualClassName);
|
|
auto newClass =
|
|
py::module::import("torch.jit")
|
|
.attr("_add_script_class")(*pyClass, qualClassName.c_str());
|
|
|
|
auto castToPython = [](void* objPtr) -> PyObject* {
|
|
CurClass x = *static_cast<CurClass*>(objPtr);
|
|
auto py_object = py::cast(x);
|
|
PyObject* rawPyObj = py_object.release().ptr();
|
|
return rawPyObj;
|
|
};
|
|
getClassConverter()[qualClassName] = castToPython;
|
|
|
|
// We currently represent custom classes as torchscript classes with a
|
|
// capsule attribute
|
|
classCu = torch::jit::get_python_cu();
|
|
classTypePtr =
|
|
ClassType::create(c10::QualifiedName(qualClassName), classCu);
|
|
classTypePtr->addAttribute("capsule", CapsuleType::get());
|
|
|
|
c10::getCustomClassTypeMap().insert({typeid(c10::intrusive_ptr<CurClass>).name(),
|
|
StrongTypePtr(classCu, classTypePtr)});
|
|
c10::getCustomClassTypeMap().insert({typeid(c10::tagged_capsule<CurClass>).name(),
|
|
StrongTypePtr(classCu, classTypePtr)});
|
|
|
|
classCu->register_class(classTypePtr);
|
|
}
|
|
|
|
template <typename... Types>
|
|
class_& def(detail::types<void, Types...>) { // Used in combination with
|
|
// torch::jit::init<...>()
|
|
pyClass->def(py::init<Types...>());
|
|
|
|
auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
|
|
auto classObj = c10::make_intrusive<CurClass>(args...);
|
|
auto genericPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(classObj);
|
|
auto capsule = IValue(genericPtr);
|
|
auto object = self.ivalue.toObject();
|
|
object->setSlot(0, capsule);
|
|
};
|
|
|
|
defineMethod<void>("__init__", std::move(func), false);
|
|
return *this;
|
|
}
|
|
template <typename Func>
|
|
class_& def(string name, Func f) {
|
|
auto res = def_(name, f, detail::args_t<decltype(f)>{});
|
|
return *this;
|
|
}
|
|
|
|
private:
|
|
template <class T>
|
|
struct addInput {
|
|
static Value* call(std::shared_ptr<Graph> graph) {
|
|
return graph->addInput()->setType(getTypePtr<T>());
|
|
}
|
|
};
|
|
template <class Func, size_t... arg_indices>
|
|
std::vector<Value*> addInputs_(
|
|
Func f,
|
|
std::shared_ptr<Graph> graph,
|
|
guts::index_sequence<arg_indices...>) {
|
|
using argTypes =
|
|
typename guts::infer_function_traits_t<Func>::parameter_types;
|
|
std::vector<Value*> res = {
|
|
addInput<guts::typelist::element_t<arg_indices, argTypes>>::call(
|
|
graph)...};
|
|
return res;
|
|
}
|
|
template <class Func>
|
|
std::vector<Value*> addInputs(Func f, std::shared_ptr<Graph> graph) {
|
|
constexpr auto numArgs =
|
|
guts::infer_function_traits_t<Func>::number_of_parameters;
|
|
return addInputs_(f, graph, guts::make_index_sequence<numArgs>());
|
|
}
|
|
|
|
template <typename Last>
|
|
std::string type_name() {
|
|
return std::string(typeid(Last).name());
|
|
}
|
|
template <typename First, typename Second, typename... Rest>
|
|
std::string type_name() {
|
|
return type_name<First>() + "_" + type_name<Second, Rest...>();
|
|
}
|
|
|
|
template <class T>
|
|
void addType(Value* v) {
|
|
v->setType(getTypePtr<T>());
|
|
}
|
|
template<typename R, typename Func>
|
|
void defineMethod(std::string name, Func func, bool hasRet) {
|
|
auto graph = std::make_shared<Graph>();
|
|
auto qualFuncName = className + "::" + name;
|
|
registeredOps.push_back(
|
|
torch::RegisterOperators().op(qualFuncName, std::move(func)));
|
|
|
|
|
|
std::vector<Value*> inputs = addInputs(func, graph);
|
|
auto methodCall = graph->insertNode(graph->create(
|
|
Symbol::fromQualString(qualFuncName), inputs, hasRet));
|
|
Value* res;
|
|
if (hasRet) {
|
|
res = methodCall->output();
|
|
addType<R>(res);
|
|
} else {
|
|
res = graph->insertConstant(IValue())->setType(NoneType::get());
|
|
}
|
|
graph->registerOutput(res);
|
|
|
|
classCu->create_function(qualClassName + "." + name, graph);
|
|
}
|
|
template <typename Func, typename R, typename... Types>
|
|
class_& def_(string name, Func f, detail::types<R, Types...> funcInfo) {
|
|
pyClass->def(name.c_str(), f);
|
|
|
|
auto func = [f](c10::intrusive_ptr<CurClass> cur, Types... args) {
|
|
return guts::invoke(f, *cur, args...);
|
|
};
|
|
defineMethod<R>(name, std::move(func), funcInfo.hasRet);
|
|
return *this;
|
|
}
|
|
};
|
|
|
|
} // namespace jit
|
|
|
|
} // namespace torch
|