diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 7d85211e9c2..5794226b14a 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include @@ -373,7 +373,7 @@ Therefore, for the moment, this is all copy pasted in from VariableTypeEverythin Explicit registration for out-of-place ops *****************************************/ TORCH_LIBRARY_IMPL(_, Autocast, m) { - m.fallback(torch::CppFunction::makeFallthrough()); + m.fallback(c10::CppFunction::makeFallthrough()); } TORCH_LIBRARY_IMPL(aten, Autocast, m) { diff --git a/aten/src/ATen/core/BackendSelectFallbackKernel.cpp b/aten/src/ATen/core/BackendSelectFallbackKernel.cpp index a00b5fdc64a..87155ec46b8 100644 --- a/aten/src/ATen/core/BackendSelectFallbackKernel.cpp +++ b/aten/src/ATen/core/BackendSelectFallbackKernel.cpp @@ -1,5 +1,9 @@ -#include +#include + +namespace { TORCH_LIBRARY_IMPL(_, BackendSelect, m) { - m.fallback(torch::CppFunction::makeFallthrough()); + m.fallback(c10::CppFunction::makeFallthrough()); +} + } diff --git a/aten/src/ATen/core/VariableFallbackKernel.cpp b/aten/src/ATen/core/VariableFallbackKernel.cpp index 30f173fb883..45a3b10ebe6 100644 --- a/aten/src/ATen/core/VariableFallbackKernel.cpp +++ b/aten/src/ATen/core/VariableFallbackKernel.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include /* * This file implements a variable fallback kernel for custom operators. @@ -65,9 +65,9 @@ TORCH_LIBRARY_IMPL(_, Autograd, m) { // // We can remove this `fallthrough` kernel when all kernels support boxed // call. 
- m.fallback(torch::CppFunction::makeFallthrough()); + m.fallback(c10::CppFunction::makeFallthrough()); #else - m.fallback(torch::CppFunction::makeFromBoxedFunction<&variable_fallback_kernel>()); + m.fallback(c10::CppFunction::makeFromBoxedFunction<&variable_fallback_kernel>()); #endif } diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index c46ec1f2013..81e50e229aa 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include namespace c10 { diff --git a/aten/src/ATen/core/library.cpp b/aten/src/ATen/core/library.cpp deleted file mode 100644 index 5ec2a48aec3..00000000000 --- a/aten/src/ATen/core/library.cpp +++ /dev/null @@ -1,232 +0,0 @@ -#include - -namespace torch { - -namespace { - // TODO: Consider representing debug info as a struct instead so you - // don't have to allocate strings all the time - std::string debugString(std::string debug, const char* file, uint32_t line) { -#ifdef STRIP_ERROR_MESSAGES - return ""; -#else - if (debug.empty()) { - return c10::str("registered at ", file, ":", line); - } else { - return debug; - } -#endif - } - - std::ostream& operator<<(std::ostream& os, Library::Kind kind) { - switch (kind) { - case Library::DEF: - os << "TORCH_LIBRARY"; - break; - case Library::IMPL: - os << "TORCH_LIBRARY_IMPL"; - break; - case Library::FRAGMENT: - os << "TORCH_LIBRARY_FRAGMENT"; - break; - } - return os; - } -} - -CppFunction::CppFunction(c10::KernelFunction func, std::unique_ptr schema) - : func_(std::move(func)) - , schema_(std::move(schema)) - , debug_() - {} - -#define ERROR_CONTEXT "(Error occurred while processing ", kind_, " block at ", file_, ":", line_, ")" - -Library::Library(Kind kind, std::string ns, c10::optional k, const char* file, uint32_t line) - : kind_(kind) - , ns_(ns == "_" ? 
c10::nullopt : c10::make_optional(std::move(ns))) - , dispatch_key_((!k.has_value() || *k == c10::DispatchKey::CatchAll) ? c10::nullopt : k) - , file_(file) - , line_(line) - { - switch (kind_) { - case DEF: - // Only DEFs require library uniqueness; fragments - // don't register a library - registrars_.emplace_back( - c10::Dispatcher::singleton().registerLibrary( - *ns_, debugString("", file_, line_) - ) - ); - // fallthrough - case FRAGMENT: - TORCH_CHECK( - ns_.has_value(), - kind_, ": cannot define ", kind_, " with the wildcard namespace _ " - "(every ", kind_, " defines operators for a distinct namespace!)" - "Did you mean to use TORCH_LIBRARY_IMPL instead? " - ERROR_CONTEXT - ); - TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT); - break; - case IMPL: - // Nothing to do, everything is OK - break; - } - } - -// TODO: Error if an operator is def'ed multiple times. Right now we just -// merge everything - -#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): " -Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name) & { - TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT, - DEF_PRELUDE, - "Cannot define an operator inside of a ", kind_, " block. " - "All def()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. ", - ERROR_CONTEXT - ); - TORCH_INTERNAL_ASSERT(ns_.has_value(), ERROR_CONTEXT); - TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT); - auto ns_opt = schema.getNamespace(); - if (ns_opt.has_value()) { - // Note [Redundancy in registration code is OK] - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // In an earlier version of this code, I made it an error to explicitly - // specify the namespace, even when the namespaces match. 
I've decided - // to relax this constraint because sometimes we code generate registrations - // and you cannot conveniently tell what the enclosing context will be; - // in these cases, it is simpler (and less error prone) to place all - // of the information in the registration site, which will be cross-checked - // in the end in any case (and if it turns out you DON'T have the right - // information at the site, as is the case with backend specific - // per-op registrations, you will get the right behavior!) - TORCH_CHECK(false, - *ns_opt == *ns_, - "Explicitly provided namespace (", *ns_opt, ") in schema string " - "does not match namespace of enclsing ", kind_, " block (", *ns_, "). " - "Move this definition to the (unique) TORCH_LIBRARY block corresponding to this namespace " - "(and consider deleting the namespace from your schema string.) ", - ERROR_CONTEXT - ); - } else { - bool b = schema.setNamespaceIfNotSet(ns_->c_str()); - TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT); - } - if (out_name) { - *out_name = schema.operator_name(); // copy! - } - registrars_.emplace_back( - c10::Dispatcher::singleton().registerDef( - std::move(schema), - debugString("", file_, line_) - ) - ); - return *this; -} -#undef DEF_PRELUDE - -Library& Library::_def(c10::either&& name_or_schema, CppFunction&& f) & { - c10::FunctionSchema schema = [&] { - if (name_or_schema.is_right()) { - return std::move(name_or_schema).right(); - } else { - // it's a name; use the inferred schema - c10::OperatorName name = std::move(name_or_schema).left(); - TORCH_CHECK(f.schema_, - "def(\"", name, "\"): " - "Full schema string was not specified, and we couldn't infer schema either. ", - "Please explicitly provide a schema string. 
", - ERROR_CONTEXT - ); - c10::FunctionSchema s = f.schema_->cloneWithName(std::move(name.name), std::move(name.overload_name)); - s.setAliasAnalysis(c10::AliasAnalysisKind::CONSERVATIVE); - return s; - } - }(); - c10::OperatorName name("", ""); // Get the namespaced name for the impl call - // First define the schema... - _def(std::move(schema), &name); - // Then register the implementation... - auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_; - registrars_.emplace_back( - c10::Dispatcher::singleton().registerImpl( - std::move(name), - dispatch_key, - std::move(f.func_), - std::move(f.schema_), - debugString(std::move(f.debug_), file_, line_) - ) - ); - return *this; -} - -#define IMPL_PRELUDE "impl(\"", name_str, "\", ...): " -Library& Library::_impl(const char* name_str, CppFunction&& f) & { - auto name = torch::jit::parseName(name_str); - auto ns_opt = name.getNamespace(); - // This is kind of similar to the checking in def(), but the error - // messages are a little different for this call site - if (ns_opt.has_value()) { - // See Note [Redundancy in registration code is OK] - TORCH_CHECK(*ns_opt == *ns_, - IMPL_PRELUDE, - "Explicitly provided namespace (", *ns_opt, ") in operator name " - "does not match namespace of enclosing ", kind_, " block (", *ns_, "). " - "Move this definition to the ", kind_, " block corresponding to this namespace " - "(and consider deleting the namespace from your schema string.) ", - ERROR_CONTEXT - ); - } else { - bool b = name.setNamespaceIfNotSet(ns_->c_str()); - TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT); - } - // See Note [Redundancy in registration code is OK] - TORCH_CHECK(!(f.dispatch_key_.has_value() && - dispatch_key_.has_value() && - *f.dispatch_key_ != *dispatch_key_), - IMPL_PRELUDE, - "Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent " - "with the dispatch key of the enclosing ", kind_, " block (", *dispatch_key_, "). 
" - "Please declare a separate ", kind_, " block for this dispatch key and " - "move your impl() there. " - ERROR_CONTEXT - ); - auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_; - registrars_.emplace_back( - c10::Dispatcher::singleton().registerImpl( - std::move(name), - dispatch_key, - std::move(f.func_), - std::move(f.schema_), - debugString(std::move(f.debug_), file_, line_) - ) - ); - return *this; -} -#undef IMPL_PRELUDE - -Library& Library::_fallback(CppFunction&& f) & { - TORCH_CHECK(kind_ == IMPL, - "fallback(...): Cannot define an operator inside of a ", kind_, " block. " - "Did you mean to call this function inside a TORCH_LIBRARY_IMPL block? ", - ERROR_CONTEXT); - auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_; - TORCH_INTERNAL_ASSERT(dispatch_key.has_value(), ERROR_CONTEXT); - TORCH_CHECK(!ns_.has_value(), - "fallback(...): Fallback functions which apply to only a single namespace ", - "(you specified ", *ns_, ") are not supported. 
If you intended to apply ", - "this fallback function globally, please define a separate block:\n\n", - " TORCH_LIBRARY_IMPL(_, ", *dispatch_key, ", m) { m.fallback(...); }\n\n", - ERROR_CONTEXT); - registrars_.emplace_back( - c10::Dispatcher::singleton().registerFallback( - *dispatch_key, - std::move(f.func_), - debugString(std::move(f.debug_), file_, line_) - ) - ); - return *this; -} - - -} // namespace torch diff --git a/aten/src/ATen/core/op_registration/op_registration.cpp b/aten/src/ATen/core/op_registration/op_registration.cpp index 75ec26a75bb..815bd7d6886 100644 --- a/aten/src/ATen/core/op_registration/op_registration.cpp +++ b/aten/src/ATen/core/op_registration/op_registration.cpp @@ -7,6 +7,37 @@ namespace c10 { +namespace { + // TODO: Consider representing debug info as a struct instead so you + // don't have to allocate strings all the time + std::string debugString(std::string debug, const char* file, uint32_t line) { +#ifdef STRIP_ERROR_MESSAGES + return ""; +#else + if (debug.empty()) { + return c10::str("registered at ", file, ":", line); + } else { + return debug; + } +#endif + } + + std::ostream& operator<<(std::ostream& os, Library::Kind kind) { + switch (kind) { + case Library::DEF: + os << "TORCH_LIBRARY"; + break; + case Library::IMPL: + os << "TORCH_LIBRARY_IMPL"; + break; + case Library::FRAGMENT: + os << "TORCH_LIBRARY_FRAGMENT"; + break; + } + return os; + } +} + static_assert(std::is_nothrow_move_constructible>::value, ""); static_assert(std::is_nothrow_move_assignable>::value, ""); @@ -109,4 +140,200 @@ RegisterOperators::~RegisterOperators() = default; RegisterOperators::RegisterOperators(RegisterOperators&&) noexcept = default; RegisterOperators& RegisterOperators::operator=(RegisterOperators&&) noexcept = default; -} // namespace c10 + +CppFunction::CppFunction(KernelFunction func, std::unique_ptr schema) + : func_(std::move(func)) + , schema_(std::move(schema)) + , debug_() + {} + +#define ERROR_CONTEXT "(Error occurred while 
processing ", kind_, " block at ", file_, ":", line_, ")" + +Library::Library(Kind kind, std::string ns, c10::optional k, const char* file, uint32_t line) + : kind_(kind) + , ns_(ns == "_" ? c10::nullopt : c10::make_optional(std::move(ns))) + , dispatch_key_((!k.has_value() || *k == DispatchKey::CatchAll) ? c10::nullopt : k) + , file_(file) + , line_(line) + { + switch (kind_) { + case DEF: + // Only DEFs require library uniqueness; fragments + // don't register a library + registrars_.emplace_back( + Dispatcher::singleton().registerLibrary( + *ns_, debugString("", file_, line_) + ) + ); + // fallthrough + case FRAGMENT: + TORCH_CHECK( + ns_.has_value(), + kind_, ": cannot define ", kind_, " with the wildcard namespace _ " + "(every ", kind_, " defines operators for a distinct namespace!)" + "Did you mean to use TORCH_LIBRARY_IMPL instead? " + ERROR_CONTEXT + ); + TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT); + break; + case IMPL: + // Nothing to do, everything is OK + break; + } + } + +// TODO: Error if an operator is def'ed multiple times. Right now we just +// merge everything + +#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): " +Library& Library::_def(FunctionSchema&& schema, OperatorName* out_name) & { + TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT, + DEF_PRELUDE, + "Cannot define an operator inside of a ", kind_, " block. " + "All def()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. ", + ERROR_CONTEXT + ); + TORCH_INTERNAL_ASSERT(ns_.has_value(), ERROR_CONTEXT); + TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT); + auto ns_opt = schema.getNamespace(); + if (ns_opt.has_value()) { + // Note [Redundancy in registration code is OK] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // In an earlier version of this code, I made it an error to explicitly + // specify the namespace, even when the namespaces match. 
I've decided + // to relax this constraint because sometimes we code generate registrations + // and you cannot conveniently tell what the enclosing context will be; + // in these cases, it is simpler (and less error prone) to place all + // of the information in the registration site, which will be cross-checked + // in the end in any case (and if it turns out you DON'T have the right + // information at the site, as is the case with backend specific + // per-op registrations, you will get the right behavior!) + TORCH_CHECK( + *ns_opt == *ns_, + "Explicitly provided namespace (", *ns_opt, ") in schema string " + "does not match namespace of enclosing ", kind_, " block (", *ns_, "). " + "Move this definition to the (unique) TORCH_LIBRARY block corresponding to this namespace " + "(and consider deleting the namespace from your schema string.) ", + ERROR_CONTEXT + ); + } else { + bool b = schema.setNamespaceIfNotSet(ns_->c_str()); + TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT); + } + if (out_name) { + *out_name = schema.operator_name(); // copy! + } + registrars_.emplace_back( + Dispatcher::singleton().registerDef( + std::move(schema), + debugString("", file_, line_) + ) + ); + return *this; +} +#undef DEF_PRELUDE + +Library& Library::_def(c10::either&& name_or_schema, CppFunction&& f) & { + FunctionSchema schema = [&] { + if (name_or_schema.is_right()) { + return std::move(name_or_schema).right(); + } else { + // it's a name; use the inferred schema + OperatorName name = std::move(name_or_schema).left(); + TORCH_CHECK(f.schema_, + "def(\"", name, "\"): " + "Full schema string was not specified, and we couldn't infer schema either. ", + "Please explicitly provide a schema string. 
", + ERROR_CONTEXT + ); + FunctionSchema s = f.schema_->cloneWithName(std::move(name.name), std::move(name.overload_name)); + s.setAliasAnalysis(c10::AliasAnalysisKind::CONSERVATIVE); + return s; + } + }(); + OperatorName name("", ""); // Get the namespaced name for the impl call + // First define the schema... + _def(std::move(schema), &name); + // Then register the implementation... + auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_; + registrars_.emplace_back( + Dispatcher::singleton().registerImpl( + std::move(name), + dispatch_key, + std::move(f.func_), + std::move(f.schema_), + debugString(std::move(f.debug_), file_, line_) + ) + ); + return *this; +} + +#define IMPL_PRELUDE "impl(\"", name_str, "\", ...): " +Library& Library::_impl(const char* name_str, CppFunction&& f) & { + auto name = torch::jit::parseName(name_str); + auto ns_opt = name.getNamespace(); + // This is kind of similar to the checking in def(), but the error + // messages are a little different for this call site + if (ns_opt.has_value()) { + // See Note [Redundancy in registration code is OK] + TORCH_CHECK(*ns_opt == *ns_, + IMPL_PRELUDE, + "Explicitly provided namespace (", *ns_opt, ") in operator name " + "does not match namespace of enclosing ", kind_, " block (", *ns_, "). " + "Move this definition to the ", kind_, " block corresponding to this namespace " + "(and consider deleting the namespace from your schema string.) ", + ERROR_CONTEXT + ); + } else { + bool b = name.setNamespaceIfNotSet(ns_->c_str()); + TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT); + } + // See Note [Redundancy in registration code is OK] + TORCH_CHECK(!(f.dispatch_key_.has_value() && + dispatch_key_.has_value() && + *f.dispatch_key_ != *dispatch_key_), + IMPL_PRELUDE, + "Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent " + "with the dispatch key of the enclosing ", kind_, " block (", *dispatch_key_, "). 
" + "Please declare a separate ", kind_, " block for this dispatch key and " + "move your impl() there. " + ERROR_CONTEXT + ); + auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_; + registrars_.emplace_back( + Dispatcher::singleton().registerImpl( + std::move(name), + dispatch_key, + std::move(f.func_), + std::move(f.schema_), + debugString(std::move(f.debug_), file_, line_) + ) + ); + return *this; +} +#undef IMPL_PRELUDE + +Library& Library::_fallback(CppFunction&& f) & { + TORCH_CHECK(kind_ == IMPL, + "fallback(...): Cannot define an operator inside of a ", kind_, " block. " + "Did you mean to call this function inside a TORCH_LIBRARY_IMPL block? ", + ERROR_CONTEXT); + auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_; + TORCH_INTERNAL_ASSERT(dispatch_key.has_value(), ERROR_CONTEXT); + TORCH_CHECK(!ns_.has_value(), + "fallback(...): Fallback functions which apply to only a single namespace ", + "(you specified ", *ns_, ") are not supported. 
If you intended to apply ", + "this fallback function globally, please define a separate block:\n\n", + " TORCH_LIBRARY_IMPL(_, ", *dispatch_key, ", m) { m.fallback(...); }\n\n", + ERROR_CONTEXT); + registrars_.emplace_back( + Dispatcher::singleton().registerFallback( + *dispatch_key, + std::move(f.func_), + debugString(std::move(f.debug_), file_, line_) + ) + ); + return *this; +} + +} diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h index 82c4ea39a44..73cd4d82b1f 100644 --- a/aten/src/ATen/core/op_registration/op_registration.h +++ b/aten/src/ATen/core/op_registration/op_registration.h @@ -594,9 +594,407 @@ private: std::vector registrars_; }; +// -------------------------------------------------------------------------- +// +// New style API +// +// -------------------------------------------------------------------------- +// +// The basic concept behind the new style API is to be as similar to pybind11's +// API as possible. +// +// A quick tour of a few usage examples: +// +// // Define a library whose operators live in the namespace 'aten'. +// // You must define all of the operators for this library in +// // this namespace. +// TORCH_LIBRARY(aten, m) { +// // Define a schema for an operator, but provide no implementation +// m.def("mul(Tensor self, Tensor other) -> Tensor"); +// +// // Define a operator with exactly one implementation for all backends. +// m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl); +// +// // Provide an implementation for a defined operator (you can +// // provide multiple; one per backend). We'll take care of calling +// // the correct implementation depending on if we get a CPU +// // tensor or a CUDA tensor +// m.impl("mul", torch::kCPU, &mul_cpu_impl); +// m.impl("mul", torch::kCUDA, &mul_cuda_impl); +// } +// +// // Define implementations for operators for a non-standard backend, +// // e.g., XLA (valid values are entries of DispatchKey). 
These +// // operator names are not namespaced; you can define implementations +// // for any namespace. +// TORCH_LIBRARY_IMPL(aten, XLA, m) { +// m.impl("mul", &mul_xla_impl); +// } + + +// Represents a C++ function that implements an operator. Most users won't +// interact directly with this class, except via error messages: the +// constructors this function define the set of permissible "function"-like +// things you can bind via the interface. +// +// This class erases the type of the passed in function, but durably records +// the type via an inferred schema for the function. +// +// TODO: This is morally the same thing as KernelRegistrationConfig, but it's +// opaque to the user. +class CAFFE2_API CppFunction final { +public: + // This overload accepts function pointers, e.g., CppFunction(&add_impl) + template + explicit CppFunction(Func* f, std::enable_if_t::value, std::nullptr_t> = nullptr) + : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)) + // TODO: Don't go through WrapRuntimeKernelFunctor + , schema_(detail::inferFunctionSchemaFromFunctor>>()) + , debug_() + {} + + // This overload accepts lambdas, e.g., CppFunction([](const Tensor& self) { ... }) + template + explicit CppFunction(Lambda&& f, std::enable_if_t>::value, std::nullptr_t> = nullptr) + : func_(c10::KernelFunction::makeFromUnboxedLambda(std::forward(f))) + // TODO: Don't go through WrapRuntimeKernelFunctor + , schema_(detail::inferFunctionSchemaFromFunctor>>()) + , debug_() + {} + + // This static factory lets you create CppFunctions that (1) don't have boxing + // wrappers (because we don't support it yet) and (2) don't have schema + // inference (because some ops don't support it). + // + // TODO: Eliminate the necessity for this function entirely. 
+ template + static CppFunction makeUnboxedOnly(Func* f) { + return CppFunction( + c10::KernelFunction::makeFromUnboxedOnlyRuntimeFunction(f), + /* schema */ nullptr + ); + } + + // TODO: more user friendly API + static CppFunction makeFallthrough() { + return CppFunction( + c10::KernelFunction::makeFallthrough(), + /* schema */ nullptr + ); + } + + // TODO: more user friendly API + template + static CppFunction makeFromBoxedFunction() { + return CppFunction( + c10::KernelFunction::makeFromBoxedFunction(), + /* schema */ nullptr + ); + } + + CppFunction&& debug(std::string d) && { + debug_ = std::move(d); + return std::move(*this); + } + +private: + c10::optional dispatch_key_; + c10::KernelFunction func_; + std::unique_ptr schema_; + std::string debug_; + + // The "setter" for dispatch_key_ + template + friend CppFunction dispatch(c10::DispatchKey, Func&&); + + // The only class which actually pulls out values from CppFunction (does so + // destructively, felt too lazy to write accessors that I don't even + // want users to use) + friend class Library; + + CppFunction(KernelFunction func, std::unique_ptr schema); +}; + +// Create a CppFunction which is associated with a specific dispatch key. 
+// CppFunctions that are tagged with a DispatchKey don't get invoked /unless/ +// the dispatcher determines that the DispatchKey is the best choice for +// a function +template +inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) { + CppFunction f(std::forward(raw_f)); + if (k == c10::DispatchKey::CatchAll) { + f.dispatch_key_ = c10::nullopt; + } else { + f.dispatch_key_ = k; + } + return f; +} + +// Convenience overload of dispatch which accepts DeviceType +template +inline CppFunction dispatch(DeviceType type, Func&& raw_f) { + auto deviceTypeToDispatchKey = [](DeviceType t){ + switch (t) { + // This list is synchronized with the k-constants in c10/core/DeviceType.h + case DeviceType::CPU: + return c10::DispatchKey::CPU; + case DeviceType::CUDA: + return c10::DispatchKey::CUDA; + case DeviceType::XLA: + return c10::DispatchKey::XLA; + case DeviceType::HIP: + return c10::DispatchKey::HIP; + case DeviceType::MSNPU: + return c10::DispatchKey::MSNPU; + default: + TORCH_CHECK(false, + "Device type ", t, " cannot be overloaded at dispatch time, " + "please file a bug report explaining what you were trying to do."); + } + }; + return dispatch(deviceTypeToDispatchKey(type), std::forward(raw_f)); +} + +inline FunctionSchema schema(const char* str, AliasAnalysisKind k) { + FunctionSchema s = torch::jit::parseSchema(str); + s.setAliasAnalysis(k); + return s; +} +inline FunctionSchema schema(const char* s) { + return schema(s, AliasAnalysisKind::FROM_SCHEMA); +} +inline FunctionSchema&& schema(FunctionSchema&& s) { return std::move(s); } + +namespace detail { + + inline c10::either constructSchemaOrName(FunctionSchema&& s) { + return c10::make_right(std::move(s)); + } + inline c10::either constructSchemaOrName(OperatorName&& n) { + return c10::make_left(std::move(n)); + } + inline c10::either constructSchemaOrName(const char* str) { + auto s = torch::jit::parseSchemaOrName(str); + if (s.is_right()) { + s.right().setAliasAnalysis(AliasAnalysisKind::FROM_SCHEMA); + 
} + return s; + } + +} + +namespace detail { + class TorchLibraryInit; +} + +// This is the "handle" by which functions defined in TORCH_LIBRARY +// and TORCH_LIBRARY_IMPL can define operators and override implementations +// at certain backends. +// +// Conventionally, you get access to it using those two macros: +// +// TORCH_LIBRARY(torchvision, m) { +// // m is a c10::Library +// m.def("roi_align", ...); +// ... +// } +// +// TORCH_LIBRARY_IMPL(aten, XLA, m) { +// // m is a c10::Library +// m.impl("add", ...); +// ... +// } +// +// In some cases, you need to define something that applies to all namespaces, +// not just one namespace (usually a fallback). In that case, use the reserved +// namespace _, e.g., +// +// TORCH_LIBRARY_IMPL(_, XLA, m) { +// m.fallback(xla_fallback); +// } +// +class CAFFE2_API Library final { +public: + // Which type of macro produced this Library + enum Kind { + DEF, // from TORCH_LIBRARY (no qualifier) + IMPL, + FRAGMENT, + }; + + // Use TORCH_LIBRARY/TORCH_LIBRARY_IMPL instead of these constructors directly + Library(Kind kind, std::string ns, c10::optional k, const char* file, uint32_t line); + + Library(const Library&) = delete; + Library& operator=(const Library&) = delete; + Library(Library&&) = default; + Library& operator=(Library&&) = default; + + // Some notes about the API design here. 
We had the following constraints: + // + // - We need to support multiple "types" of arguments for schema and + // functions (e.g., unnamed lambda types, regular functions, const char*, + // fully instantiated schemas) + // - We don't want to write exponentially many overloads + // - We don't want to rely on implicit conversion to a common type, + // because the C++ compiler will only be willing to do a single + // implicit conversion (reducing the set of valid types which you + // can invoke with); also error messages are worse when an implicit + // conversion is not selected (as the compiler will not explain + // why it didn't select an implicit conversion; this is different + // from overloads where it will explain each candidate overload and + // why it didn't apply) + // + // To solve all of these constraints at the same time, we use a trick taken + // from the pybind11 library: template over the argument in the user visible + // API, and inside of the templated function explicitly call an overloaded + // function to resolve the argument to a real type. You get the good error + // messages from overloads, but at the same time you only need to write the + // overload for any given argument type once. + + // Declare an operator with a schema, but don't provide any implementations + // for it. You're expected to then provide implementations using the + // impl() method. + template + Library& def(Schema&& raw_schema) & { + FunctionSchema s = schema(std::forward(raw_schema)); + return _def(std::move(s)); + } + + // Convenience method to define an operator for a schema and then register + // an implementation for it. def(n, f) is almost equivalent to def(n).impl(f), + // except that if n is not a schema, then the schema is inferred from the + // static type of f. 
+ template + Library& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) & { + CppFunction f(std::forward(raw_f)); + auto name_or_schema = detail::constructSchemaOrName(std::forward(raw_name_or_schema)); + return _def(std::move(name_or_schema), std::move(f)); + } + + // Register an implementation for an operator. You may register multiple + // implementations for a single operator at different dispatch keys + // (see torch::dispatch). Implementations must have a corresponding + // declaration (from def), otherwise they are invalid. + template + Library& impl(const char* name, Func&& raw_f) & { + CppFunction f(std::forward(raw_f)); + return _impl(name, std::move(f)); + } + + // Convenience overload for directly specifying the dispatch key. Dispatch + // can validly be either DeviceType or DispatchKey; check torch::dispatch for + // the canonical list of accepted overloads. + template + Library& impl(const char* name, Dispatch&& key, Func&& raw_f) & { + return impl(name, dispatch(std::forward(key), std::forward(raw_f))); + } + + // Convenience overload for unboxed only kernels. These are quite common + // but will be eventually eliminated; this function makes it easy to grep for + // them. + // + // TODO: Remove this overload once the makeUnboxedOnly incidence rate + // goes way down + template + Library& impl_UNBOXED(const char* name, Func* raw_f) & { + return impl(name, CppFunction::makeUnboxedOnly(raw_f)); + } + + // Register a fallback implementation for all operators which will be used + // if there is not a specific implementation for an operator available. 
+ // Providing a DispatchKey is MANDATORY for fallback at the moment; e.g., + // only call this from TORCH_LIBRARY_IMPL + template + Library& fallback(Func&& raw_f) & { + CppFunction f((std::forward(raw_f))); + return _fallback(std::move(f)); + } + +private: + Kind kind_; + c10::optional ns_; + c10::optional dispatch_key_; + const char* file_; + uint32_t line_; + + std::vector registrars_; + + friend detail::TorchLibraryInit; + + // Non-user visible actual implementations of functions. These aren't + // public because we only implement & qualifier and not && qualifier + Library& _def(FunctionSchema&& schema, OperatorName* out_name = nullptr) &; + Library& _def(c10::either&&, CppFunction&& f) &; + Library& _impl(const char* name, CppFunction&& f) &; + Library& _fallback(CppFunction&& f) &; +}; + +namespace detail { + +class TorchLibraryInit final { +private: + using InitFn = void(Library&); + Library lib_; +public: + TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional k, const char* file, uint32_t line) + : lib_(kind, ns, k, file, line) { + fn(lib_); + } +}; + +} // namespace detail + } // namespace c10 +// NB: The EXACT NAMING of the initializer functions (e.g., +// TORCH_LIBRARY_init_aten) matters for the code analyzer; +// see the regexes at tools/code_analyzer/run_analyzer.sh + +#define TORCH_LIBRARY(ns, m) \ + static void TORCH_LIBRARY_init_ ## ns (c10::Library&); \ + static c10::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \ + c10::Library::DEF, \ + &TORCH_LIBRARY_init_ ## ns, \ + #ns, c10::nullopt, __FILE__, __LINE__ \ + ); \ + void TORCH_LIBRARY_init_ ## ns (c10::Library& m) + +// This macro is a version of TORCH_LIBRARY that doesn't enforce that there +// is only one library (it is a "fragment"). This should ONLY be used +// with PerOpRegistration (as its name suggests). 
+#define TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(ns, m) \ + static void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (c10::Library&); \ + static c10::detail::TorchLibraryInit TORCH_LIBRARY_FRAGMENT_static_init_ ## ns ## _ ## k ( \ + c10::Library::FRAGMENT, \ + &TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k, \ + #ns, c10::nullopt, __FILE__, __LINE__ \ + ); \ + void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (c10::Library& m) + +#define TORCH_LIBRARY_IMPL(ns, k, m) \ + static void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (c10::Library&); \ + static c10::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_ ## ns ## _ ## k ( \ + c10::Library::IMPL, \ + & TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k, \ + #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__ \ + ); \ + void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (c10::Library& m) + +// These are variants of the macros above which are to be used for testing (they +// don't setup the static initializer, so you can control the visibility of +// the allocated library yourself). +// +// DO NOT use these in production code, they are NOT understood by the +// code analyzer and will be incorrectly analyzed in those situations. 
+#define MAKE_TORCH_LIBRARY(ns) Library(Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__) +#define MAKE_TORCH_LIBRARY_IMPL(ns, k) Library(Library::IMPL, #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__) + namespace torch { // Old-style API using RegisterOperators = c10::RegisterOperators; + + // New-style API + using c10::dispatch; + using c10::schema; } diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 83986da6fde..7200fd29e1e 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -12,7 +12,6 @@ #include #include -#include #include #include @@ -22,10 +21,7 @@ using c10::OperatorHandle; using c10::Dispatcher; using c10::IValue; using c10::DispatchKey; - -using torch::Library; -using torch::CppFunction; - +using c10::Library; using at::Tensor; namespace { @@ -1446,7 +1442,7 @@ TEST(NewOperatorRegistrationTest, dispatchMultiple) { TEST(NewOperatorRegistrationTest, fallback) { auto m = MAKE_TORCH_LIBRARY_IMPL(_, CPU); - m.fallback(CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); + m.fallback(c10::CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()"); @@ -1499,9 +1495,9 @@ TEST(NewOperatorRegistrationTest, CppFunction) { m.def("fn2", dummy_fn); m.def("fn3", [](const Tensor& x) { return x; }); // These require explicit schema - m.def("fn4(Tensor x) -> Tensor", CppFunction::makeFallthrough()); - m.def("fn5(Tensor x) -> Tensor", CppFunction::makeUnboxedOnly(dummy_fn)); - m.def("fn6(Tensor x) -> Tensor", CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); + m.def("fn4(Tensor x) -> Tensor", c10::CppFunction::makeFallthrough()); + m.def("fn5(Tensor x) -> Tensor", c10::CppFunction::makeUnboxedOnly(dummy_fn)); + m.def("fn6(Tensor x) -> Tensor", 
c10::CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); } // Some internal tests that have to be done from C++ diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index c3205b11f38..8757f1e552f 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -120,7 +120,7 @@ m.def("${unqual_schema_string}"); # TORCH_LIBRARY macro invocation DEFAULT_UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\ m.impl("${unqual_operator_name_with_overload}", - torch::CppFunction::makeUnboxedOnly(TypeDefault::${type_wrapper_name})); + CppFunction::makeUnboxedOnly(TypeDefault::${type_wrapper_name})); """) DEFAULT_FUNCTION_REGISTRATION = CodeTemplate("""\ @@ -137,7 +137,7 @@ m.impl("${unqual_operator_name_with_overload}", &TypeDefault::${type_wrapper_nam BACKEND_UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\ m.impl("${unqual_operator_name_with_overload}", torch::dispatch(DispatchKey::${Backend}, - torch::CppFunction::makeUnboxedOnly(${Type}::${type_wrapper_name})) + CppFunction::makeUnboxedOnly(${Type}::${type_wrapper_name})) ); """) diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index b9544c8f6b0..312805ba005 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -8,7 +8,7 @@ #include #include #include -#include +#include namespace { diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp index 92e04616151..dc31d11d877 100644 --- a/aten/src/ATen/native/Resize.cpp +++ b/aten/src/ATen/native/Resize.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include namespace at { namespace native { diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index 6546bf30a76..0c64487376c 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include namespace at { diff 
--git a/aten/src/ATen/native/cuda/Resize.cu b/aten/src/ATen/native/cuda/Resize.cu index 4afa320443d..75ae44533c1 100644 --- a/aten/src/ATen/native/cuda/Resize.cu +++ b/aten/src/ATen/native/cuda/Resize.cu @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index 7feef9c9dfe..a0fb36782ce 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -12,7 +12,7 @@ #include #include #include -#include +#include namespace at { namespace native { diff --git a/aten/src/ATen/native/quantized/README.md b/aten/src/ATen/native/quantized/README.md index 8364fb478ba..be8e3cb55b4 100644 --- a/aten/src/ATen/native/quantized/README.md +++ b/aten/src/ATen/native/quantized/README.md @@ -116,7 +116,7 @@ The final file `ATen/native/quantized/cpu/qxand.cpp` would look as follows #include #include // Need that for the `native_functions.yaml` #include -#include +#include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qadd.cpp b/aten/src/ATen/native/quantized/cpu/qadd.cpp index 84c7b992d41..56bd4a4f69d 100644 --- a/aten/src/ATen/native/quantized/cpu/qadd.cpp +++ b/aten/src/ATen/native/quantized/cpu/qadd.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp b/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp index f25dcbed612..1d9e5664a64 100644 --- a/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp +++ b/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qclamp.cpp b/aten/src/ATen/native/quantized/cpu/qclamp.cpp index c177da3c4a8..b3ec06e3680 100644 --- a/aten/src/ATen/native/quantized/cpu/qclamp.cpp +++ b/aten/src/ATen/native/quantized/cpu/qclamp.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include 
#include #include diff --git a/aten/src/ATen/native/quantized/cpu/qconcat.cpp b/aten/src/ATen/native/quantized/cpu/qconcat.cpp index 535a7a64575..eacc68ca7c9 100644 --- a/aten/src/ATen/native/quantized/cpu/qconcat.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconcat.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index d6368402ba2..90d85634e51 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index b900030a39d..50fe6841c54 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp index 435473adb4d..7e785bac414 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qelu.cpp b/aten/src/ATen/native/quantized/cpu/qelu.cpp index 24c264f2063..a44c3436c5a 100644 --- a/aten/src/ATen/native/quantized/cpu/qelu.cpp +++ b/aten/src/ATen/native/quantized/cpu/qelu.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp b/aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp index 7cbf5acbe80..c958f3d80e4 100644 --- a/aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp +++ b/aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp @@ -1,6 +1,6 
@@ #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qhardswish.cpp b/aten/src/ATen/native/quantized/cpu/qhardswish.cpp index defa141da19..48c12e0c779 100644 --- a/aten/src/ATen/native/quantized/cpu/qhardswish.cpp +++ b/aten/src/ATen/native/quantized/cpu/qhardswish.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index fdd398c9d7d..79ba998092c 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index 4f60907731d..9d838db598f 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index 2af0625f05d..3e593317880 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp index caeb4a328a2..9549f827330 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qmul.cpp b/aten/src/ATen/native/quantized/cpu/qmul.cpp index b4a527cc19c..c0b3f5d9119 100644 --- 
a/aten/src/ATen/native/quantized/cpu/qmul.cpp +++ b/aten/src/ATen/native/quantized/cpu/qmul.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qpool.cpp b/aten/src/ATen/native/quantized/cpu/qpool.cpp index e6f4c0425a1..4a635d5de8e 100644 --- a/aten/src/ATen/native/quantized/cpu/qpool.cpp +++ b/aten/src/ATen/native/quantized/cpu/qpool.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp index 801739520cf..3f5a3c2b2da 100644 --- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp +++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp b/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp index 8bcd72a6f4c..9d365ba1ce9 100644 --- a/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp +++ b/aten/src/ATen/native/quantized/cpu/qsigmoid.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qsort.cpp b/aten/src/ATen/native/quantized/cpu/qsort.cpp index 75e6f520f23..c30ed78c40d 100644 --- a/aten/src/ATen/native/quantized/cpu/qsort.cpp +++ b/aten/src/ATen/native/quantized/cpu/qsort.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/qtanh.cpp b/aten/src/ATen/native/quantized/cpu/qtanh.cpp index 4c420c023a9..f83cac428e8 100644 --- a/aten/src/ATen/native/quantized/cpu/qtanh.cpp +++ b/aten/src/ATen/native/quantized/cpu/qtanh.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/native/quantized/cpu/tensor_operators.cpp b/aten/src/ATen/native/quantized/cpu/tensor_operators.cpp index 3f75cb35da2..c25b06352e6 100644 --- 
a/aten/src/ATen/native/quantized/cpu/tensor_operators.cpp +++ b/aten/src/ATen/native/quantized/cpu/tensor_operators.cpp @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include namespace at { diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 97a8e674cef..cddad1ebab8 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -1,4 +1,4 @@ -#include +#include TORCH_LIBRARY(quantized, m) { m.def("add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc"); diff --git a/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp b/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp index d57a1ba9271..55f273accca 100644 --- a/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp +++ b/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp @@ -1,6 +1,6 @@ #ifdef USE_XNNPACK -#include +#include #include #include #include diff --git a/aten/src/ATen/templates/BackendSelectRegister.cpp b/aten/src/ATen/templates/BackendSelectRegister.cpp index b99d347dc74..0acde1b9f12 100644 --- a/aten/src/ATen/templates/BackendSelectRegister.cpp +++ b/aten/src/ATen/templates/BackendSelectRegister.cpp @@ -6,7 +6,7 @@ #include #include -#include +#include #include namespace at { diff --git a/aten/src/ATen/templates/PerOpRegistration.cpp b/aten/src/ATen/templates/PerOpRegistration.cpp index 53ed0ceb3bc..ab3b09a46f9 100644 --- a/aten/src/ATen/templates/PerOpRegistration.cpp +++ b/aten/src/ATen/templates/PerOpRegistration.cpp @@ -1,7 +1,7 @@ // ${generated_comment} #include -#include +#include #include $extra_headers diff --git a/aten/src/ATen/templates/SchemaRegister.cpp b/aten/src/ATen/templates/SchemaRegister.cpp index f48e732f476..dad29f53393 100644 --- a/aten/src/ATen/templates/SchemaRegister.cpp +++ b/aten/src/ATen/templates/SchemaRegister.cpp @@ -1,7 +1,7 @@ // ${generated_comment} #include -#include +#include using namespace at; diff --git 
a/aten/src/ATen/templates/SparseTypeDerived.cpp b/aten/src/ATen/templates/SparseTypeDerived.cpp index 0035c1db379..8447658d177 100644 --- a/aten/src/ATen/templates/SparseTypeDerived.cpp +++ b/aten/src/ATen/templates/SparseTypeDerived.cpp @@ -17,7 +17,7 @@ #include #include #include -#include +#include #include #include diff --git a/aten/src/ATen/templates/TypeDefault.cpp b/aten/src/ATen/templates/TypeDefault.cpp index e0518ac10e1..d701312561e 100644 --- a/aten/src/ATen/templates/TypeDefault.cpp +++ b/aten/src/ATen/templates/TypeDefault.cpp @@ -13,7 +13,7 @@ #include #include #include -#include +#include namespace { static const char* named_tensors_unsupported_error = diff --git a/aten/src/ATen/templates/TypeDerived.cpp b/aten/src/ATen/templates/TypeDerived.cpp index 658d20a587f..2834f7a5a29 100644 --- a/aten/src/ATen/templates/TypeDerived.cpp +++ b/aten/src/ATen/templates/TypeDerived.cpp @@ -27,7 +27,7 @@ $storage_tensor_headers #include #include -#include +#include $extra_cuda_headers $legacy_th_headers diff --git a/aten/src/ATen/test/backend_fallback_test.cpp b/aten/src/ATen/test/backend_fallback_test.cpp index 806efc49ec3..3904bf043a0 100644 --- a/aten/src/ATen/test/backend_fallback_test.cpp +++ b/aten/src/ATen/test/backend_fallback_test.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include using namespace at; @@ -110,7 +110,7 @@ void generic_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Stack* TEST(BackendFallbackTest, TestBackendFallbackWithMode) { auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode); - m.fallback(torch::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>()); + m.fallback(CppFunction::makeFromBoxedFunction<&generic_mode_fallback>()); c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericMode); @@ -122,7 +122,7 @@ TEST(BackendFallbackTest, TestBackendFallbackWithMode) { TEST(BackendFallbackTest, TestBackendFallbackWithWrapper) { auto m = MAKE_TORCH_LIBRARY_IMPL(_, 
TESTING_ONLY_GenericWrapper); - m.fallback(torch::CppFunction::makeFromBoxedFunction<&generic_wrapper_fallback>()); + m.fallback(CppFunction::makeFromBoxedFunction<&generic_wrapper_fallback>()); override_call_count = 0; Tensor a = at::detail::make_tensor(ones({5, 5}, kDouble)); @@ -132,10 +132,10 @@ TEST(BackendFallbackTest, TestBackendFallbackWithWrapper) { TEST(BackendFallbackTest, TestFallthroughBackendFallback) { auto m = MAKE_TORCH_LIBRARY_IMPL(aten, TESTING_ONLY_GenericMode); - m.impl("mul.Tensor", torch::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>()); + m.impl("mul.Tensor", c10::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>()); auto gm = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode); - gm.fallback(torch::CppFunction::makeFallthrough()); + gm.fallback(c10::CppFunction::makeFallthrough()); c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericMode); diff --git a/aten/src/ATen/test/cpu_rng_test.cpp b/aten/src/ATen/test/cpu_rng_test.cpp index 05f1804142d..79ee055681b 100644 --- a/aten/src/ATen/test/cpu_rng_test.cpp +++ b/aten/src/ATen/test/cpu_rng_test.cpp @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp index ff0c4f7fb12..822a1293bbb 100644 --- a/aten/src/ATen/test/extension_backend_test.cpp +++ b/aten/src/ATen/test/extension_backend_test.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include diff --git a/aten/src/ATen/test/rng_test.h b/aten/src/ATen/test/rng_test.h index 1f053024b2d..2f68fbec946 100644 --- a/aten/src/ATen/test/rng_test.h +++ b/aten/src/ATen/test/rng_test.h @@ -2,7 +2,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 063f425e7a9..ceb3f42bc19 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -699,12 +699,7 @@ endif() 
install(DIRECTORY "${TORCH_SRC_DIR}/csrc" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch FILES_MATCHING PATTERN "*.h") - install(FILES - "${TORCH_SRC_DIR}/script.h" - "${TORCH_SRC_DIR}/extension.h" - "${TORCH_SRC_DIR}/custom_class.h" - "${TORCH_SRC_DIR}/library.h" - "${TORCH_SRC_DIR}/custom_class_detail.h" + install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h" "${TORCH_SRC_DIR}/custom_class_detail.h" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch) diff --git a/docs/cpp/source/Doxyfile b/docs/cpp/source/Doxyfile index cb472bc1d87..a5720ee1e76 100644 --- a/docs/cpp/source/Doxyfile +++ b/docs/cpp/source/Doxyfile @@ -64,7 +64,6 @@ INPUT = ../../../aten/src/ATen/ATen.h \ ../../../torch/csrc/jit/runtime/custom_operator.h \ ../../../torch/csrc/jit/serialization/import.h \ ../../../torch/csrc/jit/api/module.h \ - ../../../torch/library.h \ ../../../torch/custom_class.h # Don't include .cpp files! FILE_PATTERNS = *.h diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp index ebecd7d04c4..fbd356576b8 100644 --- a/test/cpp/jit/test_custom_class.cpp +++ b/test/cpp/jit/test_custom_class.cpp @@ -66,71 +66,78 @@ struct PickleTester : torch::CustomClassHolder { std::vector vals; }; +static auto test = torch::class_("_TorchScriptTesting", "_Foo") + .def(torch::init()) + // .def(torch::init<>()) + .def("info", &Foo::info) + .def("increment", &Foo::increment) + .def("add", &Foo::add) + .def("combine", &Foo::combine); + +static auto testStack = + torch::class_>( + "_TorchScriptTesting", + "_StackString") + .def(torch::init>()) + .def("push", &MyStackClass::push) + .def("pop", &MyStackClass::pop) + .def("clone", &MyStackClass::clone) + .def("merge", &MyStackClass::merge) + .def_pickle( + [](const c10::intrusive_ptr>& self) { + return self->stack_; + }, + [](std::vector state) { // __setstate__ + return c10::make_intrusive>( + std::vector{"i", "was", "deserialized"}); + }) + .def("return_a_tuple", 
&MyStackClass::return_a_tuple) + .def( + "top", + [](const c10::intrusive_ptr>& self) + -> std::string { return self->stack_.back(); }); +// clang-format off + // The following will fail with a static assert telling you you have to + // take an intrusive_ptr as the first argument. + // .def("foo", [](int64_t a) -> int64_t{ return 3;}); +// clang-format on + +static auto testPickle = + torch::class_("_TorchScriptTesting", "_PickleTester") + .def(torch::init>()) + .def_pickle( + [](c10::intrusive_ptr self) { // __getstate__ + return std::vector{1, 3, 3, 7}; + }, + [](std::vector state) { // __setstate__ + return c10::make_intrusive(std::move(state)); + }) + .def( + "top", + [](const c10::intrusive_ptr& self) { + return self->vals.back(); + }) + .def("pop", [](const c10::intrusive_ptr& self) { + auto val = self->vals.back(); + self->vals.pop_back(); + return val; + }); + at::Tensor take_an_instance(const c10::intrusive_ptr& instance) { return torch::zeros({instance->vals.back(), 4}); } -TORCH_LIBRARY(_TorchScriptTesting, m) { - m.class_("_Foo") - .def(torch::init()) - // .def(torch::init<>()) - .def("info", &Foo::info) - .def("increment", &Foo::increment) - .def("add", &Foo::add) - .def("combine", &Foo::combine); - - m.class_>("_StackString") - .def(torch::init>()) - .def("push", &MyStackClass::push) - .def("pop", &MyStackClass::pop) - .def("clone", &MyStackClass::clone) - .def("merge", &MyStackClass::merge) - .def_pickle( - [](const c10::intrusive_ptr>& self) { - return self->stack_; - }, - [](std::vector state) { // __setstate__ - return c10::make_intrusive>( - std::vector{"i", "was", "deserialized"}); - }) - .def("return_a_tuple", &MyStackClass::return_a_tuple) - .def( - "top", - [](const c10::intrusive_ptr>& self) - -> std::string { return self->stack_.back(); }); - // clang-format off - // The following will fail with a static assert telling you you have to - // take an intrusive_ptr as the first argument. 
- // .def("foo", [](int64_t a) -> int64_t{ return 3;}); - // clang-format on - - m.class_("_PickleTester") - .def(torch::init>()) - .def_pickle( - [](c10::intrusive_ptr self) { // __getstate__ - return std::vector{1, 3, 3, 7}; - }, - [](std::vector state) { // __setstate__ - return c10::make_intrusive(std::move(state)); - }) - .def( - "top", - [](const c10::intrusive_ptr& self) { - return self->vals.back(); - }) - .def("pop", [](const c10::intrusive_ptr& self) { - auto val = self->vals.back(); - self->vals.pop_back(); - return val; - }); - - m.def( - "take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y", - take_an_instance); - // test that schema inference is ok too - m.def("take_an_instance_inferred", take_an_instance); +torch::RegisterOperators& register_take_instance() { + static auto instance_registry = torch::RegisterOperators().op( + torch::RegisterOperators::options() + .schema( + "_TorchScriptTesting::take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y") + .catchAllKernel()); + return instance_registry; } +static auto& ensure_take_instance_registered = register_take_instance(); + } // namespace void testTorchbindIValueAPI() { diff --git a/test/cpp_extensions/msnpu_extension.cpp b/test/cpp_extensions/msnpu_extension.cpp index dde613bde07..d5521cde9cb 100644 --- a/test/cpp_extensions/msnpu_extension.cpp +++ b/test/cpp_extensions/msnpu_extension.cpp @@ -1,5 +1,6 @@ #include -#include + +#include using namespace at; diff --git a/test/mobile/op_deps/simple_ops.cpp b/test/mobile/op_deps/simple_ops.cpp index f417d17d75f..6e596b2077d 100644 --- a/test/mobile/op_deps/simple_ops.cpp +++ b/test/mobile/op_deps/simple_ops.cpp @@ -75,7 +75,7 @@ namespace { // cares about the name TORCH_LIBRARY(_test, m) { m.def("AA(Tensor self) -> Tensor"); - m.impl("AA", torch::CppFunction::makeUnboxedOnly(AA_op)); + m.impl("AA", CppFunction::makeUnboxedOnly(AA_op)); m.def("BB(Tensor self) -> Tensor"); 
m.impl("BB", &BB_op); @@ -93,7 +93,7 @@ TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(_test, m) { TORCH_LIBRARY_IMPL(_test, CPU, m) { m.impl_UNBOXED("EE", EE_op); - m.impl("FF", torch::CppFunction::makeUnboxedOnly(FF_op)); + m.impl("FF", CppFunction::makeUnboxedOnly(FF_op)); m.impl("GG", [] (Tensor a) -> Tensor { return call_FF_op(a); diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index 8cc05351c46..6360055cceb 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -1,7 +1,7 @@ #include "torch/csrc/autograd/VariableTypeUtils.h" #include -#include +#include // ${generated_comment} diff --git a/tools/code_analyzer/run_analyzer.sh b/tools/code_analyzer/run_analyzer.sh index d9e97070c72..64b6127c700 100755 --- a/tools/code_analyzer/run_analyzer.sh +++ b/tools/code_analyzer/run_analyzer.sh @@ -15,7 +15,7 @@ echo "Analyze: ${INPUT}" # to operate, so for safety we match a more expansive set. 
"${ANALYZER_BIN}" \ -op_schema_pattern="^(_aten|_prim|aten|quantized|profiler|_test)::[a-zA-Z0-9_.]+(\(.*)?$" \ - -op_register_pattern="c10::RegisterOperators::(op|checkSchemaAndRegisterOp_)|c10::Module::(_?def|_?impl|impl_UNBOXED)|torch::Library::(_?def|_?impl|_?impl_UNBOXED)" \ + -op_register_pattern="c10::RegisterOperators::(op|checkSchemaAndRegisterOp_)|c10::Module::(_?def|_?impl|impl_UNBOXED)|c10::Library::(_?def|_?impl|_?impl_UNBOXED)" \ -op_invoke_pattern="c10::Dispatcher::findSchema|callOp" \ -root_symbol_pattern="torch::jit::[^(]" \ -torch_library_init_pattern="^.*TORCH_LIBRARY_init_([^(]+)(\(.*)?$" \ diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp index 7550468d651..4a5bb81c03c 100644 --- a/torch/csrc/jit/mobile/function.cpp +++ b/torch/csrc/jit/mobile/function.cpp @@ -1,9 +1,9 @@ #include "function.h" +#include #include #include #include #include -#include #include "interpreter.h" namespace torch { diff --git a/torch/csrc/jit/mobile/register_mobile_autograd.cpp b/torch/csrc/jit/mobile/register_mobile_autograd.cpp index 018cb7c2b20..4e1cabe16a1 100644 --- a/torch/csrc/jit/mobile/register_mobile_autograd.cpp +++ b/torch/csrc/jit/mobile/register_mobile_autograd.cpp @@ -1,8 +1,8 @@ #include #include +#include #include #include -#include using Stack = std::vector; using at::Scalar; @@ -104,12 +104,12 @@ void log_softmax_kernel(const c10::OperatorHandle& op, Stack* stack) { TORCH_LIBRARY_IMPL(_aten, Autograd, m) { m.impl("add.Scalar", torch::autograd::VariableType::add_Scalar); m.impl("mul.Tensor", torch::autograd::VariableType::mul_Tensor); - m.impl("conv2d", torch::CppFunction::makeFromBoxedFunction()); + m.impl("conv2d", CppFunction::makeFromBoxedFunction()); m.impl("dropout", VariableType::dropout); m.impl("feature_dropout", VariableType::feature_dropout); m.impl( "log_softmax.int", - torch::CppFunction::makeFromBoxedFunction()); + CppFunction::makeFromBoxedFunction()); m.impl( "max_pool2d", [](const Tensor& self, @@ 
-127,7 +127,7 @@ TORCH_LIBRARY_IMPL(_aten, Autograd, m) { ceil_mode); }); m.impl("relu", VariableType::relu); - m.impl("view", torch::CppFunction::makeFromBoxedFunction()); + m.impl("view", CppFunction::makeFromBoxedFunction()); m.impl("t", VariableType::t); m.impl("addmm", VariableType::addmm); } diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index fc587a0690c..366d4c69547 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include #include #include @@ -45,12 +45,12 @@ c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) { template -inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) { +inline c10::CppFunction dispatch_str(const char* key, Func&& raw_f) { auto mb_key = parseDispatchKey(key); if (mb_key) { - return torch::dispatch(*mb_key, std::forward(raw_f)); + return c10::dispatch(*mb_key, std::move(raw_f)); } else { - torch::CppFunction f(std::forward(raw_f)); + c10::CppFunction f(std::forward(raw_f)); return f; } } @@ -62,16 +62,16 @@ void initDispatchBindings(PyObject* module) { .def("schema", &c10::OperatorHandle::schema); // TODO: figure out how to do chaining - py::class_(m, "_DispatchModule") + py::class_(m, "_DispatchModule") .def("def_", [](py::object self, const char* schema, const char* alias) { - self.cast().def(torch::schema(schema, parseAliasAnalysisKind(alias))); + self.cast().def(torch::schema(schema, parseAliasAnalysisKind(alias))); return self; }, "", py::arg("schema"), py::arg("alias") = "") // Simulated "legacy" def where alias analysis kind is not set. 
// Ordinarily this can only be exercised from RegisterOperators() API // but I am not going to bind that here .def("def_legacy", [](py::object self, const char* schema) { - self.cast().def(torch::jit::parseSchema(schema)); + self.cast().def(torch::jit::parseSchema(schema)); return self; }, "", py::arg("schema")) // We can't conveniently turn Python functions into valid functions @@ -83,7 +83,7 @@ void initDispatchBindings(PyObject* module) { // Mangling scheme: args_rets. One character per. // t = Tensor .def("def_name_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) { - self.cast().def( + self.cast().def( name, dispatch_str(dispatch, [](const at::Tensor& a) { return a; @@ -94,7 +94,7 @@ void initDispatchBindings(PyObject* module) { py::arg("dispatch") = "", py::arg("debug") = "default_def_name_t_t") .def("def_schema_t_t", [](py::object self, const char* schema, const char* dispatch, const char* alias, const char* debug) { - self.cast().def( + self.cast().def( torch::schema(schema, parseAliasAnalysisKind(alias)), dispatch_str(dispatch, [](const at::Tensor& a) { return a; @@ -108,7 +108,7 @@ void initDispatchBindings(PyObject* module) { // TODO: maybe consider deduplicating the definitions here, it's getting // pretty long .def("impl_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) { - self.cast().impl( + self.cast().impl( name, dispatch_str(dispatch, [](const at::Tensor& a) { return a; @@ -119,7 +119,7 @@ void initDispatchBindings(PyObject* module) { py::arg("dispatch") = "", py::arg("debug") = "impl_t_t") .def("impl_tt_t", [](py::object self, const char* name, const char* dispatch, const char* debug) { - self.cast().impl( + self.cast().impl( name, dispatch_str(dispatch, [](const at::Tensor& a, const at::Tensor& b) { return a; @@ -133,9 +133,9 @@ void initDispatchBindings(PyObject* module) { // This is a wee bit dodgy right now, but the "underlying" API is much // easier to test than the high 
level (using TORCH_LIBRARY, e.g.) if (name.empty()) { - return std::make_unique(torch::Library::FRAGMENT, "_", c10::DispatchKey::CatchAll, "/dev/null", 0); + return std::make_unique(c10::Library::FRAGMENT, "_", c10::DispatchKey::CatchAll, "/dev/null", 0); } else { - return std::make_unique(torch::Library::FRAGMENT, name, c10::nullopt, "/dev/null", 0); + return std::make_unique(c10::Library::FRAGMENT, name, c10::nullopt, "/dev/null", 0); } }); diff --git a/torch/custom_class.h b/torch/custom_class.h index 8feda5a2067..54684b53f34 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -5,12 +5,12 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include #include @@ -270,14 +270,4 @@ using ::torch::class_; } // namespace jit -template -inline class_ Library::class_(const std::string& className) { - TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT, - "class_(\"", className, "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. " - "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. " - "(Error occurred at ", file_, ":", line_, ")"); - TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_); - return torch::class_(*ns_, className); -} - -} +} // namespace torch diff --git a/torch/library.h b/torch/library.h deleted file mode 100644 index 862c6d5ae97..00000000000 --- a/torch/library.h +++ /dev/null @@ -1,406 +0,0 @@ -#pragma once - -#include -#include -#include -#if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD) -#include -#endif - -// Just for inferFunctionSchemaFromFunctor -#include - -namespace torch { - -template -class class_; - -// A quick tour of a few usage examples: -// -// // Define a library whose operators live in the namespace 'aten'. -// // You must define all of the operators for this library in -// // this namespace. 
-// TORCH_LIBRARY(aten, m) { -// // Define a schema for an operator, but provide no implementation -// m.def("mul(Tensor self, Tensor other) -> Tensor"); -// -// // Define a operator with exactly one implementation for all backends. -// m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl); -// -// // Provide an implementation for a defined operator (you can -// // provide multiple; one per backend). We'll take care of calling -// // the correct implementation depending on if we get a CPU -// // tensor or a CUDA tensor -// m.impl("mul", torch::kCPU, &mul_cpu_impl); -// m.impl("mul", torch::kCUDA, &mul_cuda_impl); -// } -// -// // Define implementations for operators for a non-standard backend, -// // e.g., XLA (valid values are entries of DispatchKey). These -// // operator names are not namespaced; you can define implementations -// // for any namespace. -// TORCH_LIBRARY_IMPL(aten, XLA, m) { -// m.impl("mul", &mul_xla_impl); -// } - - -// Represents a C++ function that implements an operator. Most users won't -// interact directly with this class, except via error messages: the -// constructors this function define the set of permissible "function"-like -// things you can bind via the interface. -// -// This class erases the type of the passed in function, but durably records -// the type via an inferred schema for the function. -// -// TODO: This is morally the same thing as KernelRegistrationConfig, but it's -// opaque to the user. -class CAFFE2_API CppFunction final { -public: - // This overload accepts function pointers, e.g., CppFunction(&add_impl) - template - explicit CppFunction(Func* f, std::enable_if_t::value, std::nullptr_t> = nullptr) - : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)) - // TODO: Don't go through WrapRuntimeKernelFunctor - , schema_(c10::detail::inferFunctionSchemaFromFunctor>>()) - , debug_() - {} - - // This overload accepts lambdas, e.g., CppFunction([](const Tensor& self) { ... 
}) - template - explicit CppFunction(Lambda&& f, std::enable_if_t>::value, std::nullptr_t> = nullptr) - : func_(c10::KernelFunction::makeFromUnboxedLambda(std::forward(f))) - // TODO: Don't go through WrapRuntimeKernelFunctor - , schema_(c10::detail::inferFunctionSchemaFromFunctor>>()) - , debug_() - {} - - // This static factory lets you create CppFunctions that (1) don't have boxing - // wrappers (because we don't support it yet) and (2) don't have schema - // inference (because some ops don't support it). - // - // TODO: Eliminate the necessity for this function entirely. - template - static CppFunction makeUnboxedOnly(Func* f) { - return CppFunction( - c10::KernelFunction::makeFromUnboxedOnlyRuntimeFunction(f), - /* schema */ nullptr - ); - } - - // TODO: more user friendly API - static CppFunction makeFallthrough() { - return CppFunction( - c10::KernelFunction::makeFallthrough(), - /* schema */ nullptr - ); - } - - // TODO: more user friendly API - template - static CppFunction makeFromBoxedFunction() { - return CppFunction( - c10::KernelFunction::makeFromBoxedFunction(), - /* schema */ nullptr - ); - } - - CppFunction&& debug(std::string d) && { - debug_ = std::move(d); - return std::move(*this); - } - -private: - c10::optional dispatch_key_; - c10::KernelFunction func_; - std::unique_ptr schema_; - std::string debug_; - - // The "setter" for dispatch_key_ - template - friend CppFunction dispatch(c10::DispatchKey, Func&&); - - // The only class which actually pulls out values from CppFunction (does so - // destructively, felt too lazy to write accessors that I don't even - // want users to use) - friend class Library; - - CppFunction(c10::KernelFunction func, std::unique_ptr schema); -}; - -// Create a CppFunction which is associated with a specific dispatch key. 
-// CppFunctions that are tagged with a DispatchKey don't get invoked /unless/ -// the dispatcher determines that the DispatchKey is the best choice for -// a function -template -inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) { - CppFunction f(std::forward(raw_f)); - if (k == c10::DispatchKey::CatchAll) { - f.dispatch_key_ = c10::nullopt; - } else { - f.dispatch_key_ = k; - } - return f; -} - -// Convenience overload of dispatch which accepts DeviceType -template -inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) { - auto deviceTypeToDispatchKey = [](c10::DeviceType t){ - switch (t) { - // This list is synchronized with the k-constants in c10/core/DeviceType.h - case c10::DeviceType::CPU: - return c10::DispatchKey::CPU; - case c10::DeviceType::CUDA: - return c10::DispatchKey::CUDA; - case c10::DeviceType::XLA: - return c10::DispatchKey::XLA; - case c10::DeviceType::HIP: - return c10::DispatchKey::HIP; - case c10::DeviceType::MSNPU: - return c10::DispatchKey::MSNPU; - default: - TORCH_CHECK(false, - "Device type ", t, " cannot be overloaded at dispatch time, " - "please file a bug report explaining what you were trying to do."); - } - }; - return dispatch(deviceTypeToDispatchKey(type), std::forward(raw_f)); -} - -inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k) { - c10::FunctionSchema s = torch::jit::parseSchema(str); - s.setAliasAnalysis(k); - return s; -} -inline c10::FunctionSchema schema(const char* s) { - return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA); -} -inline c10::FunctionSchema&& schema(c10::FunctionSchema&& s) { return std::move(s); } - -namespace detail { - - inline c10::either constructSchemaOrName(c10::FunctionSchema&& s) { - return c10::make_right(std::move(s)); - } - inline c10::either constructSchemaOrName(c10::OperatorName&& n) { - return c10::make_left(std::move(n)); - } - inline c10::either constructSchemaOrName(const char* str) { - auto s = torch::jit::parseSchemaOrName(str); - if 
(s.is_right()) { - s.right().setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - } - return s; - } - - class TorchLibraryInit; - -} - -// This is the "handle" by which functions defined in TORCH_LIBRARY -// and TORCH_LIBRARY_IMPL can define operators and override implementations -// at certain backends. -// -// Conventionally, you get access to it using those two macros: -// -// TORCH_LIBRARY(torchvision, m) { -// // m is a torch::Library -// m.def("roi_align", ...); -// ... -// } -// -// TORCH_LIBRARY_IMPL(aten, XLA, m) { -// // m is a torch::Library -// m.impl("add", ...); -// ... -// } -// -// In some cases, you need to define something that applies to all namespaces, -// not just one namespace (usually a fallback). In that case, use the reserved -// namespace _, e.g., -// -// TORCH_LIBRARY_IMPL(_, XLA, m) { -// m.fallback(xla_fallback); -// } -// -class CAFFE2_API Library final { -public: - // Which type of macro produced this Library - enum Kind { - DEF, // from TORCH_LIBRARY (no qualifier) - IMPL, - FRAGMENT, - }; - - // Use TORCH_LIBRARY/TORCH_LIBRARY_IMPL instead of these constructors directly - Library(Kind kind, std::string ns, c10::optional k, const char* file, uint32_t line); - - Library(const Library&) = delete; - Library& operator=(const Library&) = delete; - Library(Library&&) = default; - Library& operator=(Library&&) = default; - - // Some notes about the API design here. 
We had the following constraints: - // - // - We need to support multiple "types" of arguments for schema and - // functions (e.g., unnamed lambda types, regular functions, const char*, - // fully instantiated schemas) - // - We don't want to write exponentially many overloads - // - We don't want to rely on implicit conversion to a common type, - // because the C++ compiler will only be willing to do a single - // implicit conversion (reducing the set of valid types which you - // can invoke with); also error messages are worse when an implicit - // conversion is not selected (as the compiler will not explain - // why it didn't select an implicit conversion; this is different - // from overloads where it will explain each candidate overload and - // why it didn't apply) - // - // To solve all of these constraints at the same time, we use a trick taken - // from the pybind11 library: template over the argument in the user visible - // API, and inside of the templated function explicitly call an overloaded - // function to resolve the argument to a real type. You get the good error - // messages from overloads, but at the same time you only need to write the - // overload for any given argument type once. - - // Declare an operator with a schema, but don't provide any implementations - // for it. You're expected to then provide implementations using the - // impl() method. - template - Library& def(Schema&& raw_schema) & { - c10::FunctionSchema s = schema(std::forward(raw_schema)); - return _def(std::move(s)); - } - - // Convenience method to define an operator for a schema and then register - // an implementation for it. def(n, f) is almost equivalent to def(n).impl(f), - // except that if n is not a schema, then the schema is inferred from the - // static type of f. 
- template - Library& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) & { - CppFunction f(std::forward(raw_f)); - auto name_or_schema = detail::constructSchemaOrName(std::forward(raw_name_or_schema)); - return _def(std::move(name_or_schema), std::move(f)); - } - - // Register an implementation for an operator. You may register multiple - // implementations for a single operator at different dispatch keys - // (see torch::dispatch). Implementations must have a corresponding - // declaration (from def), otherwise they are invalid. - template - Library& impl(const char* name, Func&& raw_f) & { - CppFunction f(std::forward(raw_f)); - return _impl(name, std::move(f)); - } - - // Convenience overload for directly specifying the dispatch key. Dispatch - // can validly be either DeviceType or DispatchKey; check torch::dispatch for - // the canonical list of accepted overloads. - template - Library& impl(const char* name, Dispatch&& key, Func&& raw_f) & { - return impl(name, dispatch(std::forward(key), std::forward(raw_f))); - } - - // Convenience overload for unboxed only kernels. These are quite common - // but will be eventually eliminated; this function makes it easy to grep for - // them. - // - // TODO: Remove this overload once the makeUnboxedOnly incidence rate - // goes way down - template - Library& impl_UNBOXED(const char* name, Func* raw_f) & { - return impl(name, CppFunction::makeUnboxedOnly(raw_f)); - } - - // Register a fallback implementation for all operators which will be used - // if there is not a specific implementation for an operator available. 
- // Providing a DispatchKey is MANDATORY for fallback at the moment; e.g., - // only call this from TORCH_LIBRARY_IMPL - template - Library& fallback(Func&& raw_f) & { - CppFunction f((std::forward(raw_f))); - return _fallback(std::move(f)); - } - - template - inline class_ class_(const std::string& className); - -private: - Kind kind_; - c10::optional ns_; - c10::optional dispatch_key_; - const char* file_; - uint32_t line_; - - std::vector registrars_; - - friend class detail::TorchLibraryInit; - - // Non-user visible actual implementations of functions. These aren't - // public because we only implement & qualifier and not && qualifier - Library& _def(c10::FunctionSchema&& schema, c10::OperatorName* out_name = nullptr) &; - Library& _def(c10::either&&, CppFunction&& f) &; - Library& _impl(const char* name, CppFunction&& f) &; - Library& _fallback(CppFunction&& f) &; -}; - -namespace detail { - -class TorchLibraryInit final { -private: - using InitFn = void(Library&); - Library lib_; -public: - TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional k, const char* file, uint32_t line) - : lib_(kind, ns, k, file, line) { - fn(lib_); - } -}; - -} // namespace detail - -} // namespace torch - -// NB: The EXACT NAMING of the initializer functions (e.g., -// TORCH_LIBRARY_init_aten) matters for the code analyzer; -// see the regexes at tools/code_analyzer/run_analyzer.sh - -#define TORCH_LIBRARY(ns, m) \ - static void TORCH_LIBRARY_init_ ## ns (torch::Library&); \ - static torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \ - torch::Library::DEF, \ - &TORCH_LIBRARY_init_ ## ns, \ - #ns, c10::nullopt, __FILE__, __LINE__ \ - ); \ - void TORCH_LIBRARY_init_ ## ns (torch::Library& m) - -// This macro is a version of TORCH_LIBRARY that doesn't enforce that there -// is only one library (it is a "fragment"). This should ONLY be used -// with PerOpRegistration (as its name suggests). 
-#define TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(ns, m) \ - static void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (torch::Library&); \ - static torch::detail::TorchLibraryInit TORCH_LIBRARY_FRAGMENT_static_init_ ## ns ## _ ## k ( \ - torch::Library::FRAGMENT, \ - &TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k, \ - #ns, c10::nullopt, __FILE__, __LINE__ \ - ); \ - void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (torch::Library& m) - -#define TORCH_LIBRARY_IMPL(ns, k, m) \ - static void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (torch::Library&); \ - static torch::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_ ## ns ## _ ## k ( \ - torch::Library::IMPL, \ - & TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k, \ - #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__ \ - ); \ - void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (torch::Library& m) - -// These are variants of the macros above which are to be used for testing (they -// don't setup the static initializer, so you can control the visibility of -// the allocated library yourself). -// -// DO NOT use these in production code, they are NOT understood by the -// code analyzer and will be incorrectly analyzed in those situations. -#define MAKE_TORCH_LIBRARY(ns) torch::Library(torch::Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__) -#define MAKE_TORCH_LIBRARY_IMPL(ns, k) torch::Library(torch::Library::IMPL, #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__) - -#include