mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67340 Currently Torchbind classes arent selective. This makes is a rough granularity pass that will remove entire classes if they arent selected. If we need finer granularity in the future we can make individual methods within classes Selective though instrumenting that will be significantly more involved I think. On a linux build only __torch__.torch.classes._nnapi.Compilation remains unselective. I cant find where its registered :P (theres a couple Android only ones and presumably some metal only ones as well) Many of the classes registered in functions returned a reference to the class that was created. I talked with dreiss about it and we decided that this seemingly didnt serve any purpose, and leaving it like that would make the return value difficult (but possible) to create with selectivity. Since it seems useless anyway I just changed them to return an int so that they can still be called from a global scope, but not have any issues with the return type. ghstack-source-id: 141690776 Test Plan: CI, model unit tests, test models in prod apps Reviewed By: dhruvbird Differential Revision: D31092564 fbshipit-source-id: 657f7eb83490292436c15cf134ceca9b72fb9e1a
463 lines
18 KiB
C++
463 lines
18 KiB
C++
#pragma once
|
|
|
|
#include <ATen/core/stack.h>
|
|
#include <ATen/core/builtin_function.h>
|
|
#include <ATen/core/function_schema.h>
|
|
#include <ATen/core/ivalue.h>
|
|
#include <ATen/core/jit_type.h>
|
|
#include <ATen/core/op_registration/infer_schema.h>
|
|
#include <ATen/core/stack.h>
|
|
#include <c10/util/C++17.h>
|
|
#include <c10/util/Metaprogramming.h>
|
|
#include <c10/util/TypeList.h>
|
|
#include <c10/util/TypeTraits.h>
|
|
#include <torch/library.h>
|
|
#include <torch/custom_class_detail.h>
|
|
#include <iostream>
|
|
#include <sstream>
|
|
|
|
namespace torch {
|
|
|
|
/// This function is used in conjunction with `class_::def()` to register
|
|
/// a constructor for a given C++ class type. For example,
|
|
/// `torch::init<int, std::string>()` would register a two-argument constructor
|
|
/// taking an `int` and a `std::string` as argument.
|
|
template <class... Types>
|
|
detail::types<void, Types...> init() {
|
|
return detail::types<void, Types...>{};
|
|
}
|
|
|
|
template <typename Func, typename... ParameterTypeList>
|
|
struct InitLambda {
|
|
Func f;
|
|
};
|
|
|
|
template <typename Func>
|
|
decltype(auto) init(Func&& f) {
|
|
using InitTraits =
|
|
c10::guts::infer_function_traits_t<std::decay_t<Func>>;
|
|
using ParameterTypeList = typename InitTraits::parameter_types;
|
|
|
|
InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
|
|
return init;
|
|
}
|
|
|
|
/// Entry point for custom C++ class registration. To register a C++ class
|
|
/// in PyTorch, instantiate `torch::class_` with the desired class as the
|
|
/// template parameter. Typically, this instantiation should be done in
|
|
/// the initialization of a global variable, so that the class will be
|
|
/// made available on dynamic library loading without any additional API
|
|
/// calls needed. For example, to register a class named Foo, you might
|
|
/// create a global variable like so:
|
|
///
|
|
/// static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
|
|
/// .def("myMethod", &Foo::myMethod)
|
|
/// .def("lambdaMethod", [](const c10::intrusive_ptr<Foo>& self) {
|
|
/// // Do something with `self`
|
|
/// });
|
|
///
|
|
/// In addition to registering the class, this registration also chains
|
|
/// `def()` calls to register methods. `myMethod()` is registered with
|
|
/// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()`
|
|
/// is registered with a C++ lambda expression.
|
|
template <class CurClass>
|
|
class class_ : public ::torch::detail::class_base {
|
|
static_assert(std::is_base_of<CustomClassHolder, CurClass>::value,
|
|
"torch::class_<T> requires T to inherit from CustomClassHolder");
|
|
|
|
public:
|
|
/// This constructor actually registers the class type.
|
|
/// String argument `namespaceName` is an identifier for the
|
|
/// namespace you would like this class to appear in.
|
|
/// String argument `className` is the name you would like to
|
|
/// see this class exposed as in Python and TorchScript. For example, if
|
|
/// you pass `foo` as the namespace name and `Bar` as the className, the
|
|
/// class will appear as `torch.classes.foo.Bar` in Python and TorchScript
|
|
explicit class_(const std::string& namespaceName, const std::string& className, std::string doc_string = "")
|
|
: class_base(namespaceName, className, std::move(doc_string), typeid(c10::intrusive_ptr<CurClass>), typeid(c10::tagged_capsule<CurClass>)) {}
|
|
|
|
/// def() can be used in conjunction with `torch::init()` to register
|
|
/// a constructor for a given C++ class type. For example, passing
|
|
/// `torch::init<int, std::string>()` would register a two-argument constructor
|
|
/// taking an `int` and a `std::string` as argument.
|
|
template <typename... Types>
|
|
class_& def(
|
|
torch::detail::types<void, Types...>,
|
|
std::string doc_string = "",
|
|
std::initializer_list<arg> default_args = {}) { // Used in combination with
|
|
// torch::init<...>()
|
|
auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
|
|
auto classObj = c10::make_intrusive<CurClass>(args...);
|
|
auto object = self.ivalue.toObject();
|
|
object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
|
|
};
|
|
|
|
defineMethod(
|
|
"__init__",
|
|
std::move(func),
|
|
std::move(doc_string),
|
|
std::move(default_args));
|
|
return *this;
|
|
}
|
|
|
|
// Used in combination with torch::init([]lambda(){......})
|
|
template <typename Func, typename... ParameterTypes>
|
|
class_& def(
|
|
InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
|
|
std::string doc_string = "",
|
|
std::initializer_list<arg> default_args = {}) {
|
|
auto init_lambda_wrapper = [func = std::move(init.f)](
|
|
c10::tagged_capsule<CurClass> self,
|
|
ParameterTypes... arg) {
|
|
c10::intrusive_ptr<CurClass> classObj =
|
|
at::guts::invoke(func, std::forward<ParameterTypes>(arg)...);
|
|
auto object = self.ivalue.toObject();
|
|
object->setSlot(0, c10::IValue::make_capsule(classObj));
|
|
};
|
|
|
|
defineMethod(
|
|
"__init__",
|
|
std::move(init_lambda_wrapper),
|
|
std::move(doc_string),
|
|
std::move(default_args));
|
|
|
|
return *this;
|
|
}
|
|
|
|
/// This is the normal method registration API. `name` is the name that
|
|
/// the method will be made accessible by in Python and TorchScript.
|
|
/// `f` is a callable object that defines the method. Typically `f`
|
|
/// will either be a pointer to a method on `CurClass`, or a lambda
|
|
/// expression that takes a `c10::intrusive_ptr<CurClass>` as the first
|
|
/// argument (emulating a `this` argument in a C++ method.)
|
|
///
|
|
/// Examples:
|
|
///
|
|
/// // Exposes method `foo` on C++ class `Foo` as `call_foo()` in
|
|
/// // Python and TorchScript
|
|
/// .def("call_foo", &Foo::foo)
|
|
///
|
|
/// // Exposes the given lambda expression as method `call_lambda()`
|
|
/// // in Python and TorchScript.
|
|
/// .def("call_lambda", [](const c10::intrusive_ptr<Foo>& self) {
|
|
/// // do something
|
|
/// })
|
|
template <typename Func>
|
|
class_& def(
|
|
std::string name,
|
|
Func f,
|
|
std::string doc_string = "",
|
|
std::initializer_list<arg> default_args = {}) {
|
|
auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
|
|
defineMethod(
|
|
std::move(name),
|
|
std::move(wrapped_f),
|
|
std::move(doc_string),
|
|
std::move(default_args));
|
|
return *this;
|
|
}
|
|
|
|
/// Method registration API for static methods.
|
|
template <typename Func>
|
|
class_& def_static(std::string name, Func func, std::string doc_string = "") {
|
|
auto qualMethodName = qualClassName + "." + name;
|
|
auto schema =
|
|
c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
|
|
|
|
auto wrapped_func =
|
|
[func = std::move(func)](jit::Stack& stack) mutable -> void {
|
|
using RetType =
|
|
typename c10::guts::infer_function_traits_t<Func>::return_type;
|
|
detail::BoxedProxy<RetType, Func>()(stack, func);
|
|
};
|
|
auto method = std::make_unique<jit::BuiltinOpFunction>(
|
|
std::move(qualMethodName),
|
|
std::move(schema),
|
|
std::move(wrapped_func),
|
|
std::move(doc_string));
|
|
|
|
classTypePtr->addStaticMethod(method.get());
|
|
registerCustomClassMethod(std::move(method));
|
|
return *this;
|
|
}
|
|
|
|
/// Property registration API for properties with both getter and setter
|
|
/// functions.
|
|
template <typename GetterFunc, typename SetterFunc>
|
|
class_& def_property(
|
|
const std::string& name,
|
|
GetterFunc getter_func,
|
|
SetterFunc setter_func,
|
|
std::string doc_string = "") {
|
|
torch::jit::Function* getter;
|
|
torch::jit::Function* setter;
|
|
|
|
auto wrapped_getter =
|
|
detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
|
|
getter = defineMethod(name + "_getter", wrapped_getter, doc_string);
|
|
|
|
auto wrapped_setter =
|
|
detail::wrap_func<CurClass, SetterFunc>(std::move(setter_func));
|
|
setter = defineMethod(name + "_setter", wrapped_setter, doc_string);
|
|
|
|
classTypePtr->addProperty(name, getter, setter);
|
|
return *this;
|
|
}
|
|
|
|
/// Property registration API for properties with only getter function.
|
|
template <typename GetterFunc>
|
|
class_& def_property(
|
|
const std::string& name,
|
|
GetterFunc getter_func,
|
|
std::string doc_string = "") {
|
|
torch::jit::Function* getter;
|
|
|
|
auto wrapped_getter =
|
|
detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
|
|
getter = defineMethod(name + "_getter", wrapped_getter, doc_string);
|
|
|
|
classTypePtr->addProperty(name, getter, nullptr);
|
|
return *this;
|
|
}
|
|
|
|
/// Property registration API for properties with read-write access.
|
|
template <typename T>
|
|
class_& def_readwrite(const std::string& name, T CurClass::*field) {
|
|
auto getter_func = [field =
|
|
field](const c10::intrusive_ptr<CurClass>& self) {
|
|
return self.get()->*field;
|
|
};
|
|
|
|
auto setter_func = [field = field](
|
|
const c10::intrusive_ptr<CurClass>& self, T value) {
|
|
self.get()->*field = value;
|
|
};
|
|
|
|
return def_property(name, getter_func, setter_func);
|
|
}
|
|
|
|
/// Property registration API for properties with read-only access.
|
|
template <typename T>
|
|
class_& def_readonly(const std::string& name, T CurClass::*field) {
|
|
auto getter_func =
|
|
[field = std::move(field)](const c10::intrusive_ptr<CurClass>& self) {
|
|
return self.get()->*field;
|
|
};
|
|
|
|
return def_property(name, getter_func);
|
|
}
|
|
|
|
/// This is an unsafe method registration API added for adding custom JIT backend support via custom
|
|
/// C++ classes. It is not for general purpose use.
|
|
class_& _def_unboxed(std::string name, std::function<void(jit::Stack&)> func, c10::FunctionSchema schema, std::string doc_string = "") {
|
|
auto method = std::make_unique<jit::BuiltinOpFunction>(
|
|
qualClassName + "." + name, std::move(schema), std::move(func), std::move(doc_string));
|
|
classTypePtr->addMethod(method.get());
|
|
registerCustomClassMethod(std::move(method));
|
|
return *this;
|
|
}
|
|
|
|
/// def_pickle() is used to define exactly what state gets serialized
|
|
/// or deserialized for a given instance of a custom C++ class in
|
|
/// Python or TorchScript. This protocol is equivalent to the Pickle
|
|
/// concept of `__getstate__` and `__setstate__` from Python
|
|
/// (https://docs.python.org/2/library/pickle.html#object.__getstate__)
|
|
///
|
|
/// Currently, both the `get_state` and `set_state` callables must be
|
|
/// C++ lambda expressions. They should have the following signatures,
|
|
/// where `CurClass` is the class you're registering and `T1` is some object
|
|
/// that encapsulates the state of the object.
|
|
///
|
|
/// __getstate__(intrusive_ptr<CurClass>) -> T1
|
|
/// __setstate__(T2) -> intrusive_ptr<CurClass>
|
|
///
|
|
/// `T1` must be an object that is convertable to IValue by the same rules
|
|
/// for custom op/method registration.
|
|
///
|
|
/// For the common case, T1 == T2. T1 can also be a subtype of T2. An
|
|
/// example where it makes sense for T1 and T2 to differ is if __setstate__
|
|
/// handles legacy formats in a backwards compatible way.
|
|
///
|
|
/// Example:
|
|
///
|
|
/// .def_pickle(
|
|
/// // __getstate__
|
|
/// [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
|
|
/// return self->stack_;
|
|
/// },
|
|
/// [](std::vector<std::string> state) { // __setstate__
|
|
/// return c10::make_intrusive<MyStackClass<std::string>>(
|
|
/// std::vector<std::string>{"i", "was", "deserialized"});
|
|
/// })
|
|
template <typename GetStateFn, typename SetStateFn>
|
|
class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
|
|
static_assert(
|
|
c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
|
|
c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
|
|
"def_pickle() currently only supports lambdas as "
|
|
"__getstate__ and __setstate__ arguments.");
|
|
def("__getstate__", std::forward<GetStateFn>(get_state));
|
|
|
|
// __setstate__ needs to be registered with some custom handling:
|
|
// We need to wrap the invocation of of the user-provided function
|
|
// such that we take the return value (i.e. c10::intrusive_ptr<CurrClass>)
|
|
// and assign it to the `capsule` attribute.
|
|
using SetStateTraits =
|
|
c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
|
|
using SetStateArg = typename c10::guts::typelist::head_t<
|
|
typename SetStateTraits::parameter_types>;
|
|
auto setstate_wrapper = [set_state = std::move(set_state)](
|
|
c10::tagged_capsule<CurClass> self,
|
|
SetStateArg&& arg) {
|
|
c10::intrusive_ptr<CurClass> classObj =
|
|
at::guts::invoke(set_state, std::forward<SetStateArg>(arg));
|
|
auto object = self.ivalue.toObject();
|
|
object->setSlot(0, c10::IValue::make_capsule(classObj));
|
|
};
|
|
defineMethod(
|
|
"__setstate__",
|
|
detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
|
|
std::move(setstate_wrapper)));
|
|
|
|
// type validation
|
|
auto getstate_schema = classTypePtr->getMethod("__getstate__").getSchema();
|
|
auto format_getstate_schema = [&getstate_schema]() {
|
|
std::stringstream ss;
|
|
ss << getstate_schema;
|
|
return ss.str();
|
|
};
|
|
TORCH_CHECK(
|
|
getstate_schema.arguments().size() == 1,
|
|
"__getstate__ should take exactly one argument: self. Got: ",
|
|
format_getstate_schema());
|
|
auto first_arg_type = getstate_schema.arguments().at(0).type();
|
|
TORCH_CHECK(
|
|
*first_arg_type == *classTypePtr,
|
|
"self argument of __getstate__ must be the custom class type. Got ",
|
|
first_arg_type->repr_str());
|
|
TORCH_CHECK(
|
|
getstate_schema.returns().size() == 1,
|
|
"__getstate__ should return exactly one value for serialization. Got: ",
|
|
format_getstate_schema());
|
|
|
|
auto ser_type = getstate_schema.returns().at(0).type();
|
|
auto setstate_schema = classTypePtr->getMethod("__setstate__").getSchema();
|
|
auto arg_type = setstate_schema.arguments().at(1).type();
|
|
TORCH_CHECK(
|
|
ser_type->isSubtypeOf(*arg_type),
|
|
"__getstate__'s return type should be a subtype of "
|
|
"input argument of __setstate__. Got ",
|
|
ser_type->repr_str(),
|
|
" but expected ",
|
|
arg_type->repr_str());
|
|
|
|
return *this;
|
|
}
|
|
|
|
private:
|
|
template <typename Func>
|
|
torch::jit::Function* defineMethod(
|
|
std::string name,
|
|
Func func,
|
|
std::string doc_string = "",
|
|
std::initializer_list<arg> default_args = {}) {
|
|
auto qualMethodName = qualClassName + "." + name;
|
|
auto schema = c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
|
|
|
|
// If default values are provided for function arguments, there must be
|
|
// none (no default values) or default values for all function
|
|
// arguments, except for self. This is because argument names are not
|
|
// extracted by inferFunctionSchemaSingleReturn, and so there must be a
|
|
// torch::arg instance in default_args even for arguments that do not
|
|
// have an actual default value provided.
|
|
TORCH_CHECK(
|
|
default_args.size() == 0 ||
|
|
default_args.size() == schema.arguments().size() - 1,
|
|
"Default values must be specified for none or all arguments");
|
|
|
|
// If there are default args, copy the argument names and default values to the
|
|
// function schema.
|
|
if (default_args.size() > 0) {
|
|
schema = withNewArguments(schema, default_args);
|
|
}
|
|
|
|
auto wrapped_func =
|
|
[func = std::move(func)](jit::Stack& stack) mutable -> void {
|
|
// TODO: we need to figure out how to profile calls to custom functions
|
|
// like this! Currently can't do it because the profiler stuff is in
|
|
// libtorch and not ATen
|
|
using RetType =
|
|
typename c10::guts::infer_function_traits_t<Func>::return_type;
|
|
detail::BoxedProxy<RetType, Func>()(stack, func);
|
|
};
|
|
auto method = std::make_unique<jit::BuiltinOpFunction>(
|
|
qualMethodName, std::move(schema), std::move(wrapped_func), std::move(doc_string));
|
|
|
|
// Register the method here to keep the Method alive.
|
|
// ClassTypes do not hold ownership of their methods (normally it
|
|
// those are held by the CompilationUnit), so we need a proxy for
|
|
// that behavior here.
|
|
auto method_val = method.get();
|
|
classTypePtr->addMethod(method_val);
|
|
registerCustomClassMethod(std::move(method));
|
|
return method_val;
|
|
}
|
|
};
|
|
|
|
/// make_custom_class() is a convenient way to create an instance of a registered
|
|
/// custom class and wrap it in an IValue, for example when you want to pass the
|
|
/// object to TorchScript. Its syntax is equivalent to APIs like `std::make_shared<>`
|
|
/// or `c10::make_intrusive<>`.
|
|
///
|
|
/// For example, if you have a custom C++ class that can be constructed from an `int`
|
|
/// and `std::string`, you might use this API like so:
|
|
///
|
|
/// IValue custom_class_iv = torch::make_custom_class<MyClass>(3, "foobarbaz");
|
|
template <typename CurClass, typename... CtorArgs>
|
|
c10::IValue make_custom_class(CtorArgs&&... args) {
|
|
auto userClassInstance = c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...);
|
|
return c10::IValue(std::move(userClassInstance));
|
|
}
|
|
|
|
// jit namespace for backward-compatibility
|
|
// We previously defined everything in torch::jit but moved it out to
|
|
// better reflect that these features are not limited only to TorchScript
|
|
namespace jit {
|
|
|
|
using ::torch::getCustomClass;
|
|
using ::torch::isCustomClass;
|
|
using ::torch::init;
|
|
using ::torch::class_;
|
|
|
|
} // namespace jit
|
|
|
|
template <class CurClass>
|
|
inline class_<CurClass> Library::class_(const std::string& className) {
|
|
TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
|
|
"class_(\"", className, "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
|
|
"All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
|
|
"(Error occurred at ", file_, ":", line_, ")");
|
|
TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
|
|
return torch::class_<CurClass>(*ns_, className);
|
|
}
|
|
|
|
const std::unordered_set<std::string> getAllCustomClassesNames();
|
|
|
|
template <class CurClass>
|
|
inline class_<CurClass> Library::class_(detail::SelectiveStr<true> className) {
|
|
auto class_name = std::string(className.operator const char *());
|
|
TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
|
|
"class_(\"", class_name, "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
|
|
"All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
|
|
"(Error occurred at ", file_, ":", line_, ")");
|
|
TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
|
|
return torch::class_<CurClass>(*ns_, class_name);
|
|
}
|
|
|
|
template <class CurClass>
|
|
inline detail::ClassNotSelected Library::class_(detail::SelectiveStr<false>) {
|
|
return detail::ClassNotSelected();
|
|
}
|
|
|
|
} // namespace torch
|