#include #include #include #include #include #include #include #include namespace torch::jit { namespace { using OperatorMap = std::unordered_map>>; struct OperatorRegistry { private: std::mutex lock; OperatorMap operators; // list of operators whose schema have not yet been parsed, and must // be registered before any call to lookup an operator std::vector> to_register; // Those two maps are used to implement lookupByLiteral, which is needed for // the n->match(...) calls. Basically, every function schema is assigned a // unique string you can use to match it. However, parsing those strings or // comparing and hashing them character by character would be very slow, so we // use a trick here! Every string literal in your program is guaranteed to // have static storage duration and so its address won't change at runtime. // This allows us to memoize answers for every pointer, which is done by the // operators_by_sig_literal map. Still, this map is initially empty, and so we // still need to do the complete string matching at the first time, which is // implemented by performing a lookup in the operators_by_sig map. std::unordered_map> operators_by_sig; std::unordered_map> operators_by_sig_literal; // Remember all registered operator names to check that they aren't // registered a second time. Registering an op multiple times is // fragile because it might depend on static initialization order // which one is picked at runtime. #ifdef C10_MOBILE std::unordered_set registered_operator_names; #endif // XXX - caller must be holding lock void registerPendingOperators() { for (const auto& op : to_register) { Symbol sym = Symbol::fromQualString(op->schema().name()); operators[sym].push_back(op); operators_by_sig[canonicalSchemaString(op->schema())] = op; } to_register.clear(); } public: void registerOperator(Operator&& op) { std::lock_guard guard(lock); #ifdef C10_MOBILE TORCH_INTERNAL_ASSERT( 0 == registered_operator_names.count(op.schema().operator_name()), "Tried to register operator \"", op.schema(), "\" to JIT but the operator name was already registered before. Please add or change the overload name."); registered_operator_names.insert(op.schema().operator_name()); #endif to_register.push_back(std::make_shared(std::move(op))); } void deregisterOperator(const FunctionSchema& schema) { Symbol sym = Symbol::fromQualString(schema.name()); auto sig = canonicalSchemaString(schema); std::lock_guard guard(lock); #ifdef C10_MOBILE TORCH_INTERNAL_ASSERT( 1 == registered_operator_names.count(schema.operator_name()), "Tried to remove operator ", schema, " from JIT but it wasn't found."); registered_operator_names.erase(schema.operator_name()); #endif // Try removing from pending operators list first auto pending_it = to_register.begin(); while (pending_it != to_register.end() && (*pending_it)->schema() != schema) ++pending_it; if (pending_it != to_register.end()) { to_register.erase(pending_it); return; } // Remove operator from signature map auto sig_it = operators_by_sig.find(sig); if (sig_it == operators_by_sig.end()) { return; } operators_by_sig.erase(sig_it); // Remove operator from symbol map auto op_it = operators.find(sym); TORCH_CHECK( op_it != operators.end(), "operator with signature ", sig, " is missing from symbol registry"); auto& op_vec = op_it->second; auto it = op_vec.begin(); while (it != op_vec.end() && (*it)->schema() != schema) ++it; if (it != op_vec.end()) { op_vec.erase(it); } if (op_vec.empty()) { operators.erase(op_it); } } const std::shared_ptr& lookupByLiteral(const char* name) { std::lock_guard guard(lock); registerPendingOperators(); auto it = operators_by_sig_literal.find(name); if (it == operators_by_sig_literal.end()) { auto op_ptr_it = operators_by_sig.find(canonicalSchemaString(parseSchema(name))); // Handy debugging code that dumps all operators we know about on mismatch #if 0 if (op_ptr_it == operators_by_sig.end()) { for (auto & entry : operators_by_sig) { std::cout << entry.first << std::endl; } } #endif TORCH_CHECK( op_ptr_it != operators_by_sig.end(), "Couldn't find an operator for ", name, ". Do you have to update a set of hardcoded JIT ops?"); it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second); } return it->second; } const std::vector>& getOperators(Symbol name) { std::lock_guard guard(lock); registerPendingOperators(); static std::vector> empty; auto it = operators.find(name); if (it != operators.end()) return it->second; return empty; } std::vector findSimilarOperators(Symbol input_op) { std::lock_guard guard(lock); registerPendingOperators(); using EntryPair = std::pair; auto cmp = [](const EntryPair& lhs, const EntryPair& rhs) { return lhs.first > rhs.first; }; std::priority_queue, decltype(cmp)> rankings(cmp); static constexpr size_t MAX_EDIT_DIST = 2u; for (const auto& op : operators) { auto edit_dist = ComputeEditDistance( input_op.toQualString(), op.first.toQualString(), MAX_EDIT_DIST); if (edit_dist <= MAX_EDIT_DIST) { rankings.emplace(edit_dist, op.first); } } std::vector ret; while (!rankings.empty()) { ret.push_back(rankings.top().second); rankings.pop(); } return ret; } const std::vector> getAllOperators() { std::lock_guard guard(lock); registerPendingOperators(); std::vector> values; values.clear(); for (auto& kv : operators) { values.insert(values.end(), kv.second.begin(), kv.second.end()); } return values; } }; OperatorRegistry& getRegistry() { static OperatorRegistry r; return r; } bool printerHasSpecialCaseFor(Symbol sym) { using namespace at; // WARNING: by adding a value to this set, you are asserting // that you have also added special handling of this symbol to // the python_print.cpp. Not adding handling will cause import and export // of modules with this new operator to fail. This is only required // for operators without schema. Prefer registering your operator with // schema to editing this list here. These cases should only be things // that require special handling because they do not fit normal schema const static std::unordered_set handled = { prim::Constant, prim::Uninitialized, prim::fork, prim::awaitable, prim::ListConstruct, prim::DictConstruct, prim::ListUnpack, prim::Print, prim::PythonOp, prim::TupleConstruct, prim::TupleIndex, prim::TupleSlice, prim::TupleUnpack, prim::CreateObject, prim::GetAttr, prim::SetAttr, prim::CallFunction, prim::isinstance, prim::unchecked_cast, prim::tolist, prim::rpc_async, prim::rpc_sync, prim::rpc_remote}; // WARNING: by adding a value to this set, you are asserting that your // primitive is only ever added during optimization and does not need // to be correctly printed for export (a process that happens before // optimization passes run) const static std::unordered_set unneeded = { c10::onnx::Reshape, // only used in onnx c10::onnx::Shape, // only used in onnx prim::AutogradZero, // temporarily inserted by autograd prim::AutogradAnyNonZero, // temporarily inserted by autograd prim::AutogradAllNonZero, // temporarily inserted by autograd prim::AutogradAllZero, // temporarily inserted by autograd prim::AutogradAdd, // temporarily inserted by autograd prim::ConstantChunk, // optimization pass adds it prim::DifferentiableGraph, // optimization pass adds it, prim::FunctionalGraph, // optimization pass adds it, prim::ReductionSizes, // optimization pass (fuser) adds it prim::BroadcastSizes, // optimization pass (fuser) adds it prim::ChunkSizes, // optimization pass (fuser) adds it prim::Drop, // used in interpreter only prim::FusedConcat, // optimization pass adds it prim::FusionGroup, // optimization pass adds it prim::CudaFusionGroup, // optimization pass adds it prim::CudaFusionGuard, // optimization pass adds it prim::TensorExprGroup, // optimization pass adds it prim::TensorExprDynamicGroup, // optimization pass adds it prim::StaticSubgraph, // optimization pass adds it prim::ConstantMKLDNNTensor, // optimization pass adds it prim::BroadcastMKLDNNTensors, // optimization pass adds it prim::oneDNNFusionGroup, // optimization pass adds it prim::oneDNNFusionGuard, // optimization pass adds it prim::StaticRuntimeCopyOuts, // used in SR only prim::Load, // used in interpreter only prim::MMTreeReduce, // used as an optimization prim::MMBatchSide, // used as an optimization prim::Store, // used in interpreter only prim::profile, // used in interpreter only prim::profile_ivalue, // used in interpreter only prim::TypeCheck, // used in interpreter only prim::RequiresGradCheck, // used in interpreter only prim::FallbackGraph, // converted into prim::CallFunction }; // These namespaces are required to have Python printers unless // otherwise noted in unneeded. const static std::unordered_set required_namespaces = { c10::namespaces::prim, c10::namespaces::aten, c10::namespaces::onnx, }; return handled.count(sym) || unneeded.count(sym) || !required_namespaces.count(sym.ns()); } } // anonymous namespace bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { using namespace at; // WARNING: by adding a case to this list, you are asserting that you have // added a case for the unschematized node in AliasDb::analyze const static std::unordered_set handled = { prim::If, prim::Loop, prim::FusionGroup, prim::CudaFusionGroup, prim::oneDNNFusionGroup, prim::DifferentiableGraph, prim::TensorExprGroup, prim::TensorExprDynamicGroup, prim::StaticSubgraph, prim::FunctionalGraph, prim::Constant, prim::Uninitialized, prim::DictConstruct, prim::ListConstruct, prim::TupleConstruct, prim::AutogradZero, prim::FusedConcat, prim::GradOf, prim::MMTreeReduce, prim::MMBatchSide, prim::BroadcastSizes, prim::ChunkSizes, prim::Closure, prim::TupleUnpack, prim::TupleIndex, prim::TupleSlice, prim::ListUnpack, prim::PythonOp, prim::ConstantChunk, prim::BroadcastingChunk, prim::MKLDNNGroup, prim::ConstantMKLDNNTensor, prim::BroadcastMKLDNNTensors, prim::fork, prim::awaitable, prim::awaitable_nowait, prim::awaitable_wait, prim::CreateObject, prim::AutogradAdd, prim::GetAttr, prim::SetAttr, prim::profile, prim::profile_ivalue, prim::TypeCheck, prim::RequiresGradCheck, prim::Print, prim::CallFunction, prim::CallMethod, aten::wait, prim::isinstance, prim::unchecked_cast, prim::tolist, prim::rpc_async, prim::rpc_sync, prim::rpc_remote, prim::Enter, prim::Exit, prim::FallbackGraph, }; // Operators that should not be used by alias analysis const static std::unordered_set purposefully_not_handled = { prim::Load, prim::Store, prim::Drop, at::onnx::Reshape, at::onnx::Shape, prim::AutogradAdd, }; return handled.count(symbol) || purposefully_not_handled.count(symbol); } void registerOperator(Operator&& op) { if (op.schema().is_varret()) { Symbol s = Symbol::fromQualString(op.schema().name()); if (!printerHasSpecialCaseFor(s)) { AT_ERROR( "Missing special case in python printer for non-schematized" " operator ", op.schema().name(), ". File a bug to add a case for this operator.\n"); } if (aliasAnalysisHasSpecialCaseFor(s) && op.aliasAnalysisKind() == AliasAnalysisKind::CONSERVATIVE) { AT_ERROR( "Conflict in special casing in alias analysis for non-schematized" " operator ", op.schema().name(), ". File a bug to add a case for this operator.\n"); } if (aliasAnalysisHasSpecialCaseFor(s) && op.aliasAnalysisKind() == AliasAnalysisKind::FROM_SCHEMA) { AT_ERROR( "The operator ", op.schema().name(), " is special cased and cannot use explicit alias analysis."); } } getRegistry().registerOperator(std::move(op)); } void deregisterOperator(const FunctionSchema& schema) { getRegistry().deregisterOperator(schema); } const std::vector> getAllOperators() { return getRegistry().getAllOperators(); } const std::vector>& getAllOperatorsFor(Symbol name) { return getRegistry().getOperators(name); } std::vector> getAllSortedOperatorsFor(Symbol name) { const auto& unsortedOps = getAllOperatorsFor(name); // Depending on the order of registration, aten or jit ops may be // registered first. This sorting is helpful in cases where // deterministic (i.e. not dependent on build config) behavior is // desired; e.g. torch.ops.aten.* uses this function, and tries to // find the "first" op that matches input args. Without the sorting, // the "first" op may change depending on registration order. std::vector> sortedOps; sortedOps.reserve(unsortedOps.size()); std::copy_if( unsortedOps.begin(), unsortedOps.end(), std::back_inserter(sortedOps), [](const std::shared_ptr& op) { return op->isC10Op(); }); std::copy_if( unsortedOps.begin(), unsortedOps.end(), std::back_inserter(sortedOps), [](const std::shared_ptr& op) { return !op->isC10Op(); }); return sortedOps; } std::shared_ptr findOperatorFor(const c10::OperatorName& full_name) { for (const auto& op : getRegistry().getOperators(Symbol::fromQualString(full_name.name))) { if (op->schema().overload_name() == full_name.overload_name) { return op; } } return nullptr; } std::vector findSimilarOperators(Symbol input_op) { return getRegistry().findSimilarOperators(input_op); } std::shared_ptr getOperatorForLiteral(const char* signature) { return getRegistry().lookupByLiteral(signature); } std::string canonicalSchemaString(const FunctionSchema& schema) { std::string out = schema.name(); out.push_back('('); bool seen_kwarg_only = false; for (const auto i : c10::irange(schema.arguments().size())) { if (i > 0) { out += ", "; } if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) { out += "*, "; seen_kwarg_only = true; } const auto& arg = schema.arguments()[i]; out += arg.type()->str(); out.push_back(' '); out += arg.name(); } out += ") -> "; if (schema.returns().size() == 1) { out += schema.returns().at(0).type()->str(); } else if (schema.returns().size() > 1) { out.push_back('('); for (const auto i : c10::irange(schema.returns().size())) { if (i > 0) { out += ", "; } out += schema.returns()[i].type()->str(); } out.push_back(')'); } return out; } } // namespace torch::jit