#pragma once #include #include #include #include namespace torch { namespace jit { namespace { // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration) inline c10::FunctionSchema getIsAvailableSchema() { c10::Argument self("self", c10::AnyType::get()); c10::Argument available("available", c10::BoolType::get()); c10::FunctionSchema preprocessor_schema( "is_available", /*overload_name=*/"", /*arguments=*/{self}, /*returns=*/{available}); return preprocessor_schema; } constexpr static auto kBackendsNamespace = "__backends__"; // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration) inline c10::FunctionSchema getCompileSchema() { c10::Argument self("self", c10::AnyType::get()); c10::Argument mod("processed", c10::AnyType::get()); auto any_dict_ty = c10::DictType::create(c10::StringType::get(), c10::AnyType::get()); c10::Argument method_compile_spec("method_compile_spec", any_dict_ty); c10::Argument handles("handles", any_dict_ty); c10::FunctionSchema compile_schema( "compile", /*overload_name=*/"", /*arguments=*/{self, mod, method_compile_spec}, /*returns=*/{handles}); return compile_schema; } // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration) inline c10::FunctionSchema getExecuteSchema() { auto any_list_ty = c10::ListType::create(c10::AnyType::get()); c10::Argument self("self", c10::AnyType::get()); c10::Argument handle("handle", c10::AnyType::get()); c10::Argument input("input", any_list_ty); c10::Argument output("output", any_list_ty); return c10::FunctionSchema( "execute", /*overload_name=*/"", /*arguments=*/{self, handle, input}, /*returns=*/{output}); } template std::function getIsAvailableFunc() { return [](Stack& stack) { auto self = pop(stack).toCustomClass(); auto ret = self->is_available(); push(stack, ret); }; } template std::function getCompileFunc() { return [](Stack& stack) { auto method_compile_spec = pop(stack).toGenericDict(); auto processed = pop(stack); auto self = pop(stack).toCustomClass(); auto ret = self->compile(processed, method_compile_spec); push(stack, ret); }; } template std::function getExecuteFunc() { return [](Stack& stack) { auto args = pop(stack); auto handle = pop(stack); auto self = pop(stack); auto backend = self.toCustomClass(); auto res = backend->execute(handle, args.toList()); push(stack, res); }; } } // namespace // Static registration API for backends. template class backend { static_assert( std::is_base_of::value, "torch::jit::backend requires T to inherit from PyTorchBackendInterface"); std::string backend_name_; public: // Registers a new backend with /p name, and the given /p preprocess // function. backend(const std::string& name) : backend_name_(name) { static auto cls = torch::class_(kBackendsNamespace, name) .def(torch::init<>()) ._def_unboxed( "is_available", getIsAvailableFunc(), getIsAvailableSchema()) ._def_unboxed( "compile", getCompileFunc(), getCompileSchema()) ._def_unboxed( "execute", getExecuteFunc(), getExecuteSchema()); } }; } // namespace jit } // namespace torch