#include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace py = pybind11; namespace torch { namespace impl { namespace dispatch { torch::Library::Kind parseKind(const std::string& k) { static std::unordered_map kind_map = { {"DEF", torch::Library::DEF}, {"IMPL", torch::Library::IMPL}, {"FRAGMENT", torch::Library::FRAGMENT}, }; auto it = kind_map.find(k); TORCH_CHECK(it != kind_map.end(), "could not parse ", k); return it->second; } c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) { static std::unordered_map key_map = { {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE}, {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA}, {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION}, {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default }; auto it = key_map.find(k); TORCH_CHECK(it != key_map.end(), "could not parse ", k); return it->second; } template inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) { auto mb_key = std::string(key) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(key)); if (mb_key) { return torch::dispatch(*mb_key, std::forward(raw_f)); } else { torch::CppFunction f(std::forward(raw_f)); return f; } } class PythonKernelHolder : public c10::OperatorKernel { c10::SafePyObject func_; public: PythonKernelHolder(py::object func) : func_(func.release().ptr(), getPyInterpreter()) {} void operator()( const c10::OperatorHandle& op, c10::DispatchKeySet keyset, torch::jit::Stack* stack) { auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); py::gil_scoped_acquire g; auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); auto obj = py::reinterpret_steal(PyObject_Call( func_.ptr(getPyInterpreter()), args_kwargs.first.ptr(), args_kwargs.second.ptr())); if (!obj) { throw python_error(); } pushPyOutToStack(op, stack, obj, "PythonKernelHolder"); } }; void initDispatchBindings(PyObject* module) { auto m = py::handle(module).cast(); py::class_(m, "_DispatchOperatorHandle") .def("schema", &c10::OperatorHandle::schema); // TODO: figure out how to do chaining py::class_(m, "_DispatchModule") .def( "def_", [](py::object self, const char* schema, const char* alias) { self.cast().def( torch::schema(schema, parseAliasAnalysisKind(alias))); return self; }, "", py::arg("schema"), py::arg("alias") = "") // Simulated "legacy" def where alias analysis kind is not set. // Ordinarily this can only be exercised from RegisterOperators() API // but I am not going to bind that here .def( "def_legacy", [](py::object self, const char* schema) { self.cast().def(torch::jit::parseSchema(schema)); return self; }, "", py::arg("schema")) // We can't conveniently turn Python functions into valid functions // in the dispatcher. So instead we provide a bunch of precanned // functions for testing purposes. You're NOT intended to actually // call these functions; they're just here so we can actually register // something // // Mangling scheme: args_rets. One character per. // t = Tensor .def( "def_name_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) { self.cast().def( name, dispatch_str(dispatch, [](const at::Tensor& a) { return a; }).debug(debug)); return self; }, "", py::arg("name"), py::arg("dispatch") = "", py::arg("debug") = "default_def_name_t_t") .def( "def_schema_t_t", [](py::object self, const char* schema, const char* dispatch, const char* alias, const char* debug) { self.cast().def( torch::schema(schema, parseAliasAnalysisKind(alias)), dispatch_str(dispatch, [](const at::Tensor& a) { return a; }).debug(debug)); return self; }, "", py::arg("name"), py::arg("dispatch") = "", py::arg("alias") = "", py::arg("debug") = "default_def_schema_t_t") // TODO: maybe consider deduplicating the definitions here, it's getting // pretty long .def( "impl_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) { self.cast().impl( name, dispatch_str(dispatch, [](const at::Tensor& a) { return a; }).debug(debug)); return self; }, "", py::arg("name"), py::arg("dispatch") = "", py::arg("debug") = "impl_t_t") .def( "impl_tt_t", [](py::object self, const char* name, const char* dispatch, const char* debug) { self.cast().impl( name, dispatch_str( dispatch, [](const at::Tensor& a, const at::Tensor& b) { return a; }) .debug(debug)); return self; }, "", py::arg("name"), py::arg("dispatch") = "", py::arg("debug") = "") .def( "impl", [](py::object self, const char* name, const char* dispatch, py::object func) { HANDLE_TH_ERRORS self.cast().impl( name, dispatch_str( dispatch, CppFunction::makeFromBoxedFunctor( std::make_unique( std::move(func))))); END_HANDLE_TH_ERRORS_PYBIND }, "", py::arg("name"), py::arg("dispatch"), py::arg("func")) .def( "define", [](py::object self, const char* schema, const char* alias_analysis) { self.cast().def( torch::schema(schema, parseAliasAnalysisKind(alias_analysis))); return torch::schema(schema, parseAliasAnalysisKind(alias_analysis)) .name(); }, "", py::arg("schema"), py::arg("alias_analysis") = "") .def( "fallback_fallthrough", [](py::object self, const char* dispatch) { self.cast().fallback( dispatch_str(dispatch, CppFunction::makeFallthrough())); return self; }, "", py::arg("dispatch") = ""); m.def( "_dispatch_library", [](const char* kind, std::string name, const char* dispatch, const char* file, uint32_t linenum) { HANDLE_TH_ERRORS return std::make_unique( parseKind(kind), std::move(name), std::string(dispatch) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch)), "/dev/null", // temporary workaround linenum); END_HANDLE_TH_ERRORS_PYBIND }, "", py::arg("kind"), py::arg("name"), py::arg("dispatch"), py::arg("file") = "/dev/null", py::arg("linenum") = 0); m.def("_dispatch_dump", [](const char* name) -> std::string { auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); if (!op) { return ""; } else { return op->dumpState(); } }); m.def("_dispatch_dump_table", [](const char* name) -> std::string { auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); if (!op) { return ""; } else { return op->dumpComputedTable(); } }); m.def("_dispatch_check_invariants", [](const char* name) { auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); if (!op) { } else { return op->checkInvariants(); } }); m.def("_dispatch_check_all_invariants", []() { c10::Dispatcher::singleton().checkInvariants(); }); m.def("_dispatch_has_kernel", [](const char* name) -> bool { auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); return static_cast(op); }); m.def( // Returns whether or not a direct kernel registration exists // for this pair. "_dispatch_has_kernel_for_dispatch_key", [](const char* name, c10::DispatchKey dispatch) -> bool { auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); TORCH_CHECK(op, "operator ", name, " does not exist"); return op->hasKernelForDispatchKey(dispatch); }); m.def( "_dispatch_has_kernel_for_any_dispatch_key", [](const char* name, c10::DispatchKeySet ks) -> bool { auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); TORCH_CHECK(op, "operator ", name, " does not exist"); return op->hasKernelForAnyDispatchKey(ks); }); m.def( // Returns whether or not there is an entry in the runtime computed // dispatch table, for this pair. For example, if // "op" has a `CompositeImplicitAutograd` kernel, Then // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will return // true for all backends that are part of the alias set for // CompositeImplicitAutograd. "_dispatch_has_computed_kernel_for_dispatch_key", [](const char* name, const char* dispatch) -> bool { auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); TORCH_CHECK(op, "operator ", name, " does not exist"); return op->hasComputedKernelForDispatchKey( c10::parseDispatchKey(dispatch)); }); m.def("_dispatch_find_dangling_impls", []() -> std::vector { auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls(); std::vector states; states.reserve(danglingImpls.size()); for (auto& danglingImpl : danglingImpls) { states.push_back(danglingImpl.dumpState()); } return states; }); m.def( "_dispatch_tls_set_dispatch_key_excluded", [](c10::DispatchKey dispatch_key, bool desired_state) { c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state); }); m.def( "_dispatch_tls_is_dispatch_key_excluded", [](c10::DispatchKey dispatch_key) { return c10::impl::tls_is_dispatch_key_excluded(dispatch_key); }); m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) { return at::isTensorSubclassLike(tensor); }); m.def("_dispatch_key_name", [](c10::DispatchKey k) { return c10::toString(k); }); m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; }); m.def("_dispatch_num_backends", []() { return c10::num_backends; }); #define DEF_ONE(n) .value(#n, c10::DispatchKey::n) py::enum_(m, "DispatchKey") DEF_ONE(Undefined) DEF_ONE(CompositeExplicitAutogradNonFunctional) DEF_ONE(CompositeExplicitAutograd) DEF_ONE(CompositeImplicitAutogradNestedTensor) DEF_ONE(CompositeImplicitAutograd) DEF_ONE(AutogradOther) DEF_ONE(Autograd) DEF_ONE(BackendSelect) DEF_ONE(ADInplaceOrView) DEF_ONE(PythonTLSSnapshot) DEF_ONE(Python) #define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n) #define DEF_MULTIPLE(fullname, prefix) \ DEF_SINGLE(, fullname) \ DEF_SINGLE(, StartOf##fullname##Backends) \ C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \ DEF_SINGLE(, EndOf##fullname##Backends) C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE) #undef DEF_MULTIPLE #undef DEF_SINGLE ; py::class_(m, "DispatchKeySet") .def(py::init()) .def("__or__", &c10::DispatchKeySet::operator|) .def("__sub__", &c10::DispatchKeySet::operator-) .def("__and__", &c10::DispatchKeySet::operator&) .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId) .def("has", &c10::DispatchKeySet::has) .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); }); m.attr("_dispatch_autogradother_backends") = py::cast(c10::autogradother_backends); m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) { return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t); }); m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) { return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t); }); m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) { return c10::toString(keyset); }); m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) { return c10::getBackendKeySetFromAutograd(k); }); m.def("_dispatch_keys", [](const at::Tensor& tensor) { auto* impl = tensor.unsafeGetTensorImpl(); return impl->key_set(); }); m.def("_dispatch_tls_local_include_set", []() { return c10::impl::tls_local_dispatch_key_set().included_; }); m.def("_dispatch_tls_local_exclude_set", []() { return c10::impl::tls_local_dispatch_key_set().excluded_; }); m.def( "_dispatch_is_included_in_alias", [](c10::DispatchKey a, c10::DispatchKey b) { return c10::isIncludedInAlias(a, b); }); py::class_(m, "ExcludeDispatchKeyGuard") .def(py::init()); py::class_(m, "_AutoDispatchBelowAutograd") .def(py::init<>()); // Prints out the name of every operator that has a kernel registered to the // Dispatcher under [dispatch_key]. If no arguments are specified, it'll print // out the name of every operator that the Dispatcher knows of. This can be // useful to answer questions like "list all operators that do not have a CPU // kernel". m.def( "_dispatch_print_registrations_for_dispatch_key", [](const char* dispatch_key = "") { auto k = std::string(dispatch_key) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch_key)); auto op_names = c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); for (auto& op : op_names) { std::cout << op << std::endl; } }, py::arg("dispatch_key") = static_cast("")); m.def( "_dispatch_get_registrations_for_dispatch_key", [](const char* dispatch_key = "") { auto k = std::string(dispatch_key) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch_key)); auto op_names = c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); std::vector names; names.reserve(op_names.size()); for (auto& op : op_names) { names.push_back( op.name + (op.overload_name == "" ? "" : "." + op.overload_name)); } return names; }, py::arg("dispatch_key") = static_cast("")); m.def("_are_functorch_transforms_active", []() { auto include_set = c10::impl::tls_local_dispatch_key_set().included_; return ( include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) || include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode)); }); } } // namespace dispatch } // namespace impl } // namespace torch