mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34848 Test Plan: Imported from OSS Differential Revision: D20480514 Pulled By: jamesr66a fbshipit-source-id: 1c595faf34e00aab0a6202a8902426bd310551c3
206 lines
7.8 KiB
C++
206 lines
7.8 KiB
C++
#pragma once
|
|
|
|
#include <ATen/core/builtin_function.h>
|
|
#include <ATen/core/function_schema.h>
|
|
#include <ATen/core/ivalue.h>
|
|
#include <ATen/core/jit_type.h>
|
|
#include <ATen/core/op_registration/infer_schema.h>
|
|
#include <ATen/core/op_registration/op_registration.h>
|
|
#include <ATen/core/stack.h>
|
|
#include <c10/util/C++17.h>
|
|
#include <c10/util/Metaprogramming.h>
|
|
#include <c10/util/TypeList.h>
|
|
#include <c10/util/TypeTraits.h>
|
|
#include <torch/custom_class_detail.h>
|
|
#include <iostream>
|
|
#include <sstream>
|
|
|
|
namespace torch {
|
|
|
|
TORCH_API at::ClassTypePtr getCustomClass(const std::string& name);
|
|
|
|
TORCH_API bool isCustomClass(const c10::IValue& v);
|
|
|
|
template <class... Types>
|
|
detail::types<void, Types...> init() {
|
|
return detail::types<void, Types...>{};
|
|
}
|
|
|
|
// To bind custom classes into Torchscript, use an API very similar to Pybind's.
|
|
// Currently exposes one class `torch::class_<T>` and 2 methods.
|
|
// - Constructing `torch::class_<Foo>` registers `Foo` in Python and
|
|
// Torchscript, and puts it under `torch.classes.Foo` in Python.
|
|
// - torch::class_<Foo>.def("method1", &Foo::method1) does some template
|
|
// metaprogramming to introspect the function types and register the operator
|
|
// for use in Torchscript.
|
|
// - torch::class_<Foo>.def(torch::init<int64_t, int64_t>()) registers
|
|
// the Foo(int, int) constructor.
|
|
// see test/custom_operator/classes.cpp and
|
|
// test/custom_operator/test_custom_classes.py for example usages
|
|
|
|
template <class CurClass>
|
|
class class_ {
|
|
static_assert(std::is_base_of<CustomClassHolder, CurClass>::value,
|
|
"torch::class_<T> requires T to inherit from CustomClassHolder");
|
|
|
|
std::string className;
|
|
std::string qualClassName;
|
|
at::ClassTypePtr classTypePtr;
|
|
|
|
const std::string parentModule = "classes";
|
|
const std::string topModule = "__torch__.torch";
|
|
|
|
public:
|
|
class_(std::string className_) : className(std::move(className_)) {
|
|
qualClassName = topModule + "." + parentModule + "." + className;
|
|
|
|
// We currently represent custom classes as torchscript classes with a
|
|
// capsule attribute
|
|
classTypePtr = at::ClassType::create(
|
|
c10::QualifiedName(qualClassName),
|
|
std::weak_ptr<jit::CompilationUnit>());
|
|
classTypePtr->addAttribute("capsule", at::CapsuleType::get());
|
|
|
|
c10::getCustomClassTypeMap().insert(
|
|
{typeid(c10::intrusive_ptr<CurClass>).name(), classTypePtr});
|
|
c10::getCustomClassTypeMap().insert(
|
|
{typeid(c10::tagged_capsule<CurClass>).name(), classTypePtr});
|
|
|
|
registerCustomClass(classTypePtr);
|
|
}
|
|
|
|
template <typename... Types>
|
|
class_& def(detail::types<void, Types...>) { // Used in combination with
|
|
// torch::init<...>()
|
|
auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
|
|
auto classObj = c10::make_intrusive<CurClass>(args...);
|
|
auto object = self.ivalue.toObject();
|
|
object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
|
|
};
|
|
|
|
defineMethod("__init__", std::move(func));
|
|
return *this;
|
|
}
|
|
template <typename Func>
|
|
class_& def(std::string name, Func f) {
|
|
auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
|
|
defineMethod(std::move(name), std::move(wrapped_f));
|
|
return *this;
|
|
}
|
|
|
|
// Pickle
|
|
template <typename GetStateFn, typename SetStateFn>
|
|
class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
|
|
static_assert(
|
|
c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
|
|
c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
|
|
"def_pickle() currently only supports lambdas as "
|
|
"__getstate__ and __setstate__ arguments.");
|
|
def("__getstate__", std::forward<GetStateFn>(get_state));
|
|
|
|
// __setstate__ needs to be registered with some custom handling:
|
|
// We need to wrap the invocation of of the user-provided function
|
|
// such that we take the return value (i.e. c10::intrusive_ptr<CurrClass>)
|
|
// and assign it to the `capsule` attribute.
|
|
using SetStateTraits =
|
|
c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
|
|
using SetStateArg = typename c10::guts::typelist::head_t<
|
|
typename SetStateTraits::parameter_types>;
|
|
auto setstate_wrapper = [set_state = std::move(set_state)](
|
|
c10::tagged_capsule<CurClass> self,
|
|
SetStateArg&& arg) {
|
|
c10::intrusive_ptr<CurClass> classObj =
|
|
at::guts::invoke(set_state, std::forward<SetStateArg>(arg));
|
|
auto object = self.ivalue.toObject();
|
|
object->setSlot(0, c10::IValue::make_capsule(classObj));
|
|
};
|
|
defineMethod(
|
|
"__setstate__",
|
|
detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
|
|
std::move(setstate_wrapper)));
|
|
|
|
// type validation
|
|
auto getstate_schema = classTypePtr->getMethod("__getstate__")->getSchema();
|
|
auto format_getstate_schema = [&getstate_schema]() {
|
|
std::stringstream ss;
|
|
ss << getstate_schema;
|
|
return ss.str();
|
|
};
|
|
TORCH_CHECK(
|
|
getstate_schema.arguments().size() == 1,
|
|
"__getstate__ should take exactly one argument: self. Got: ",
|
|
format_getstate_schema());
|
|
auto first_arg_type = getstate_schema.arguments().at(0).type();
|
|
TORCH_CHECK(
|
|
*first_arg_type == *classTypePtr,
|
|
"self argument of __getstate__ must be the custom class type. Got ",
|
|
first_arg_type->python_str());
|
|
TORCH_CHECK(
|
|
getstate_schema.returns().size() == 1,
|
|
"__getstate__ should return exactly one value for serialization. Got: ",
|
|
format_getstate_schema());
|
|
auto ser_type = getstate_schema.returns().at(0).type();
|
|
auto setstate_schema = classTypePtr->getMethod("__setstate__")->getSchema();
|
|
auto arg_type = setstate_schema.arguments().at(1).type();
|
|
TORCH_CHECK(
|
|
(*arg_type == *ser_type),
|
|
"__setstate__'s argument should be the same type as the "
|
|
"return value of __getstate__. Got ",
|
|
arg_type->python_str(),
|
|
" but expected ",
|
|
ser_type->python_str());
|
|
|
|
return *this;
|
|
}
|
|
|
|
private:
|
|
template <typename Func>
|
|
void defineMethod(std::string name, Func func) {
|
|
auto qualMethodName = qualClassName + "." + name;
|
|
auto schema = c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
|
|
|
|
auto wrapped_func = [func = std::move(func)](jit::Stack& stack) mutable -> void {
|
|
// TODO: we need to figure out how to profile calls to custom functions
|
|
// like this! Currently can't do it because the profiler stuff is in
|
|
// libtorch and not ATen
|
|
using RetType =
|
|
typename c10::guts::infer_function_traits_t<Func>::return_type;
|
|
detail::BoxedProxy<RetType, Func>()(stack, func);
|
|
};
|
|
auto method = std::make_shared<jit::BuiltinOpFunction>(
|
|
qualMethodName, std::move(schema), std::move(wrapped_func));
|
|
|
|
// Register the method here to keep the Method alive.
|
|
// ClassTypes do not hold ownership of their methods (normally it
|
|
// those are held by the CompilationUnit), so we need a proxy for
|
|
// that behavior here.
|
|
registerCustomClassMethod(method);
|
|
classTypePtr->addMethod(method.get());
|
|
}
|
|
};
|
|
|
|
template <typename CurClass, typename... CtorArgs>
|
|
c10::IValue make_custom_class(CtorArgs&&... args) {
|
|
if (!c10::isCustomClassRegistered<c10::intrusive_ptr<CurClass>>()) {
|
|
throw c10::Error(
|
|
"Trying to instantiate a class that isn't a registered custom class.",
|
|
"");
|
|
}
|
|
auto userClassInstance = c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...);
|
|
return c10::IValue(std::move(userClassInstance));
|
|
}
|
|
|
|
// jit namespace for backward-compatibility
|
|
// We previously defined everything in torch::jit but moved it out to
|
|
// better reflect that these features are not limited only to TorchScript
|
|
namespace jit {
|
|
|
|
using ::torch::getCustomClass;
|
|
using ::torch::isCustomClass;
|
|
using ::torch::init;
|
|
using ::torch::class_;
|
|
|
|
} // namespace jit
|
|
|
|
} // namespace torch
|