mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33732 move and forward instead of copy Benchmarks: A microbenchmark calling the add operation on two tensors in a tight loop shows a 5% improvement in performance. No visible change for a model like resnet that does more work in its kernels. ghstack-source-id: 99161486 Test Plan: benchmarks Differential Revision: D20082642 fbshipit-source-id: eeac59686f8621dd5eaa85d61e6d219bba48c847
206 lines
7.8 KiB
C++
206 lines
7.8 KiB
C++
|
|
#pragma once
|
|
|
|
#include <ATen/core/function_schema.h>
|
|
#include <ATen/core/ivalue.h>
|
|
#include <ATen/core/jit_type.h>
|
|
#include <ATen/core/op_registration/op_registration.h>
|
|
#include <ATen/core/stack.h>
|
|
#include <c10/util/C++17.h>
|
|
#include <c10/util/Metaprogramming.h>
|
|
#include <c10/util/TypeList.h>
|
|
#include <c10/util/TypeTraits.h>
|
|
#include <torch/csrc/jit/api/custom_class.h>
|
|
#include <torch/csrc/jit/runtime/operator.h>
|
|
#include <torch/csrc/jit/api/compilation_unit.h>
|
|
#include <torch/csrc/jit/frontend/tracer.h>
|
|
#include <torch/csrc/utils/variadic.h>
|
|
#include <torch/custom_class_detail.h>
|
|
#include <iostream>
|
|
#include <sstream>
|
|
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
template <class... Types>
|
|
detail::types<void, Types...> init() {
|
|
return detail::types<void, Types...>{};
|
|
}
|
|
|
|
// To bind custom classes into TorchScript, use an API very similar to
// pybind11's. The entry point is the class `torch::jit::class_<T>`:
// - Constructing `torch::jit::class_<Foo>("Foo")` registers `Foo` in Python and
// TorchScript, and puts it under `torch.classes.Foo` in Python.
// - `torch::jit::class_<Foo>(...).def("method1", &Foo::method1)` does some
// template metaprogramming to introspect the function types and registers the
// operator for use in TorchScript.
// - `torch::jit::class_<Foo>(...).def(torch::jit::init<int64_t, int64_t>())`
// registers the `Foo(int, int)` constructor.
// - `torch::jit::class_<Foo>(...).def_pickle(get_state, set_state)` registers
// `__getstate__`/`__setstate__` so instances can be serialized.
// See test/custom_operator/classes.cpp and
// test/custom_operator/test_custom_classes.py for example usages.
|
|
|
|
template <class CurClass>
|
|
class class_ {
|
|
static_assert(std::is_base_of<CustomClassHolder, CurClass>::value,
|
|
"torch::jit::class_<T> requires T to inherit from CustomClassHolder");
|
|
|
|
std::string className;
|
|
std::string qualClassName;
|
|
ClassTypePtr classTypePtr;
|
|
|
|
const std::string parentModule = "classes";
|
|
const std::string topModule = "__torch__.torch";
|
|
|
|
public:
|
|
class_(std::string className_) : className(std::move(className_)) {
|
|
qualClassName = topModule + "." + parentModule + "." + className;
|
|
|
|
// We currently represent custom classes as torchscript classes with a
|
|
// capsule attribute
|
|
classTypePtr =
|
|
ClassType::create(c10::QualifiedName(qualClassName), classCU());
|
|
classTypePtr->addAttribute("capsule", CapsuleType::get());
|
|
|
|
c10::getCustomClassTypeMap().insert({typeid(c10::intrusive_ptr<CurClass>).name(),
|
|
c10::StrongTypePtr(classCU(), classTypePtr)});
|
|
c10::getCustomClassTypeMap().insert({typeid(c10::tagged_capsule<CurClass>).name(),
|
|
c10::StrongTypePtr(classCU(), classTypePtr)});
|
|
|
|
classCU()->register_type(classTypePtr);
|
|
}
|
|
|
|
template <typename... Types>
|
|
class_& def(detail::types<void, Types...>) { // Used in combination with
|
|
// torch::jit::init<...>()
|
|
auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
|
|
auto classObj = c10::make_intrusive<CurClass>(args...);
|
|
auto genericPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(std::move(classObj));
|
|
auto capsule = IValue(std::move(genericPtr));
|
|
auto object = std::move(self.ivalue).toObject();
|
|
object->setSlot(0, std::move(capsule));
|
|
};
|
|
|
|
defineMethod("__init__", std::move(func));
|
|
return *this;
|
|
}
|
|
template <typename Func>
|
|
class_& def(std::string name, Func f) {
|
|
auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
|
|
defineMethod(std::move(name), std::move(wrapped_f));
|
|
return *this;
|
|
}
|
|
|
|
// Pickle
|
|
template <typename GetStateFn, typename SetStateFn>
|
|
class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
|
|
static_assert(
|
|
c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
|
|
c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
|
|
"torch::jit::pickle_ currently only supports lambdas as "
|
|
"__getstate__ and __setstate__ arguments.");
|
|
def("__getstate__", std::forward<GetStateFn>(get_state));
|
|
|
|
// __setstate__ needs to be registered with some custom handling:
|
|
// We need to wrap the invocation of of the user-provided function
|
|
// such that we take the return value (i.e. c10::intrusive_ptr<CurrClass>)
|
|
// and assign it to the `capsule` attribute.
|
|
using SetStateTraits =
|
|
c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
|
|
using SetStateArg = typename c10::guts::typelist::head_t<
|
|
typename SetStateTraits::parameter_types>;
|
|
auto setstate_wrapper = [set_state = std::move(set_state)](
|
|
c10::tagged_capsule<CurClass> self,
|
|
SetStateArg&& arg) {
|
|
c10::intrusive_ptr<CurClass> classObj =
|
|
at::guts::invoke(set_state, std::forward<SetStateArg>(arg));
|
|
auto genericPtr =
|
|
c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(
|
|
classObj);
|
|
auto capsule = IValue(genericPtr);
|
|
auto object = self.ivalue.toObject();
|
|
object->setSlot(0, capsule);
|
|
};
|
|
defineMethod(
|
|
"__setstate__",
|
|
detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
|
|
std::move(setstate_wrapper)));
|
|
|
|
// type validation
|
|
auto getstate_schema = classTypePtr->getMethod("__getstate__")->getSchema();
|
|
auto format_getstate_schema = [&getstate_schema]() {
|
|
std::stringstream ss;
|
|
ss << getstate_schema;
|
|
return ss.str();
|
|
};
|
|
TORCH_CHECK(
|
|
getstate_schema.arguments().size() == 1,
|
|
"__getstate__ should take exactly one argument: self. Got: ",
|
|
format_getstate_schema());
|
|
auto first_arg_type = getstate_schema.arguments().at(0).type();
|
|
TORCH_CHECK(
|
|
*first_arg_type == *classTypePtr,
|
|
"self argument of __getstate__ must be the custom class type. Got ",
|
|
first_arg_type->python_str());
|
|
TORCH_CHECK(
|
|
getstate_schema.returns().size() == 1,
|
|
"__getstate__ should return exactly one value for serialization. Got: ",
|
|
format_getstate_schema());
|
|
auto ser_type = getstate_schema.returns().at(0).type();
|
|
auto setstate_schema = classTypePtr->getMethod("__setstate__")->getSchema();
|
|
auto arg_type = setstate_schema.arguments().at(1).type();
|
|
TORCH_CHECK(
|
|
(*arg_type == *ser_type),
|
|
"__setstate__'s argument should be the same type as the "
|
|
"return value of __getstate__. Got ",
|
|
arg_type->python_str(),
|
|
" but expected ",
|
|
ser_type->python_str());
|
|
|
|
return *this;
|
|
}
|
|
|
|
private:
|
|
template <typename Func>
|
|
void defineMethod(std::string name, Func func) {
|
|
auto graph = std::make_shared<Graph>();
|
|
auto qualFuncName = className + "::" + name;
|
|
ensure_c10_registerer_defined();
|
|
registeredOps().push_back(
|
|
torch::RegisterOperators().op(qualFuncName, std::move(func)));
|
|
auto func_symbol = c10::Symbol::fromQualString(qualFuncName);
|
|
auto ops = torch::jit::getAllOperatorsFor(func_symbol);
|
|
TORCH_CHECK(ops.size() == 1);
|
|
auto &schema = ops[0]->schema();
|
|
|
|
for (const auto& arg : schema.arguments()) {
|
|
graph->addInput()->setType(arg.type());
|
|
}
|
|
|
|
auto opCall = graph->insertNode(graph->create(
|
|
func_symbol, graph->inputs(), schema.returns().size()));
|
|
Value* res;
|
|
if (schema.returns().size() > 1) {
|
|
const auto& returns = schema.returns();
|
|
size_t op_invocation_idx = 0;
|
|
for (const auto& ret : returns) {
|
|
opCall->output(op_invocation_idx++)->setType(ret.type());
|
|
}
|
|
res = graph->insertNode(graph->createTuple(opCall->outputs()))->output();
|
|
} else if (schema.returns().size() == 1) {
|
|
const auto& returns = schema.returns();
|
|
res = opCall->output()->setType(returns[0].type());
|
|
} else {
|
|
res = graph->insertConstant(IValue())->setType(NoneType::get());
|
|
}
|
|
graph->registerOutput(res);
|
|
|
|
auto method = classCU()->create_function(qualClassName + "." + name, graph);
|
|
classTypePtr->addMethod(method);
|
|
}
|
|
};
|
|
|
|
} // namespace jit
|
|
|
|
} // namespace torch
|