pytorch/torch/csrc/utils/python_dispatch.cpp
Edward Yang e29348f828 Switch to pybind11 style registration function API. (#36258)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36258

Previously, we had a &&-chaining style API.  There are some downsides to
this API:

- It's easy to forget the 'static' qualifier in front, leading to
  subtle ODR bugs.
- It is not compatible with torchbind class_ definitions, as these
  need multiple levels of chaining.  So in practice people end
  up having to define multiple static initializers, one per class.
- It's not like pybind11.
- There's no way to conveniently get the file and line number of
  the registration, as there is no macro point in the API.
- The old API doesn't really encourage people to put all of their
  definitions for a library in one place, and to give a custom
  namespace for it.  Similarly, the old API wasn't very DRY, because
  you had to keep repeating the namespace/dispatch key you
  were writing implementations for.

The new API is modeled exactly off of the PYBIND11_MODULE macro:
you write:

```
TORCH_LIBRARY(aten, m) {
  m.def("aten::add(Tensor self, Tensor other) -> Tensor");
  ...
}
```

in a non-chaining fashion, and under the hood the macro expands to
define a function, and define a static initializer that allocates
c10::Library (previously called c10::Module, but we renamed it
to avoid confusion with the existing NN module concept), passes
it to your function, and then retains it for the rest of the lifetime
of the program.  Specification of the namespace is mandatory,
and in a later commit I plan to make it a hard error to TORCH_LIBRARY
the same library name twice.

If you are specifying an implementation for an existing operator
(e.g., you're the XLA backend, or even if you're just putting
registrations for implementations at the implementation site),
you should use TORCH_LIBRARY_IMPL, which instead takes a backend
argument (instead of namespace) and can be used to specify an
implementation for a backend.  Unlike TORCH_LIBRARY, you can do
as many of these as you want for a backend.

This needs updates to the mobile code analyzer.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D20929257

Pulled By: ezyang

fbshipit-source-id: ba04d78492e8c93ae7190165fb936f6872896ada
2020-04-16 10:44:21 -07:00

165 lines
5.7 KiB
C++

#include <torch/csrc/utils/python_dispatch.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/ATen.h>
#include <pybind11/operators.h>
#include <iostream>
namespace py = pybind11;
namespace torch {
namespace impl {
namespace dispatch {
// Translate a lowercase dispatch key name into its c10::DispatchKey.
// The empty string maps to c10::nullopt (i.e., "no key" / catch-all).
// Raises (via TORCH_CHECK) if the name is not one of the known keys.
c10::optional<c10::DispatchKey> parseDispatchKey(const std::string& k) {
  static std::unordered_map<std::string, c10::DispatchKey> key_map = {
      {"cpu", c10::DispatchKey::CPU},
      {"cuda", c10::DispatchKey::CUDA},
      {"xla", c10::DispatchKey::XLA},
      {"autograd", c10::DispatchKey::Autograd},
      {"", c10::DispatchKey::Undefined},
  };
  const auto entry = key_map.find(k);
  TORCH_CHECK(entry != key_map.end(), "could not parse ", k);
  // Undefined is the sentinel for "no dispatch key requested".
  return entry->second == c10::DispatchKey::Undefined
      ? c10::nullopt
      : c10::make_optional(entry->second);
}
// Translate an alias-analysis kind name (as spelled in Python tests)
// into c10::AliasAnalysisKind.  The empty string selects FROM_SCHEMA.
// Raises (via TORCH_CHECK) if the name is not recognized.
c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
  static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = {
      {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
      {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA},
      {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION},
      {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default when unspecified
  };
  const auto entry = key_map.find(k);
  TORCH_CHECK(entry != key_map.end(), "could not parse ", k);
  return entry->second;
}
// Wrap a raw callable into a c10::CppFunction, tagging it with the
// dispatch key parsed from `key` when one is given.  An empty key
// produces an untagged (catch-all) CppFunction.
//
// NOTE: `raw_f` is a forwarding reference, so it must be passed along
// with std::forward in BOTH branches.  The previous code called
// std::move(raw_f) in the dispatch-key branch, which would silently
// move out of an lvalue argument supplied by the caller.
template <typename Func>
inline c10::CppFunction dispatch_str(const char* key, Func&& raw_f) {
  auto mb_key = parseDispatchKey(key);
  if (mb_key) {
    return c10::dispatch(*mb_key, std::forward<Func>(raw_f));
  }
  return c10::CppFunction(std::forward<Func>(raw_f));
}
// Install the private dispatcher test bindings (_DispatchModule,
// _dispatch_import, _dispatch_dump, ...) onto the given Python module.
// These exist purely so test_dispatch.py can exercise the low-level
// c10::Library registration API from Python; they are not public API.
void initDispatchBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
// Minimal handle wrapper: only exposes the operator's schema.
py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
.def("schema", &c10::OperatorHandle::schema);
// TODO: figure out how to do chaining
// Each binding below returns `self` so Python callers can chain calls
// manually even though pybind11 chaining isn't wired up yet.
py::class_<c10::Library>(m, "_DispatchModule")
// def(schema, alias): register a schema, with an optional alias
// analysis kind ("" -> FROM_SCHEMA default).
.def("def_", [](py::object self, const char* schema, const char* alias) {
self.cast<c10::Library&>().def(torch::schema(schema, parseAliasAnalysisKind(alias)));
return self;
}, "", py::arg("schema"), py::arg("alias") = "")
// Simulated "legacy" def where alias analysis kind is not set.
// Ordinarily this can only be exercised from RegisterOperators() API
// but I am not going to bind that here
.def("def_legacy", [](py::object self, const char* schema) {
self.cast<c10::Library&>().def(torch::jit::parseSchema(schema));
return self;
}, "", py::arg("schema"))
// We can't conveniently turn Python functions into valid functions
// in the dispatcher. So instead we provide a bunch of precanned
// functions for testing purposes. You're NOT intended to actually
// call these functions; they're just here so we can actually register
// something
//
// Mangling scheme: args_rets. One character per.
// t = Tensor
.def("def_name_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) {
self.cast<c10::Library&>().def(
name,
dispatch_str(dispatch, [](const at::Tensor& a) {
return a;
}).debug(debug)
);
return self;
}, "", py::arg("name"),
py::arg("dispatch") = "",
py::arg("debug") = "default_def_name_t_t")
// NOTE(review): the keyword below is py::arg("name") but the lambda
// parameter is `schema` — Python callers must pass keyword `name` to
// supply the schema string.  Looks like a copy-paste from
// def_name_t_t; confirm whether the keyword should be "schema".
.def("def_schema_t_t", [](py::object self, const char* schema, const char* dispatch, const char* alias, const char* debug) {
self.cast<c10::Library&>().def(
torch::schema(schema, parseAliasAnalysisKind(alias)),
dispatch_str(dispatch, [](const at::Tensor& a) {
return a;
}).debug(debug)
);
return self;
}, "", py::arg("name"),
py::arg("dispatch") = "",
py::arg("alias") = "",
py::arg("debug") = "default_def_schema_t_t")
// TODO: maybe consider deduplicating the definitions here, it's getting
// pretty long
.def("impl_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) {
self.cast<c10::Library&>().impl(
name,
dispatch_str(dispatch, [](const at::Tensor& a) {
return a;
}).debug(debug)
);
return self;
}, "", py::arg("name"),
py::arg("dispatch") = "",
py::arg("debug") = "impl_t_t")
// NOTE(review): unlike impl_t_t, the debug default here is "" rather
// than a descriptive string — presumably intentional for testing
// unnamed registrations; confirm.
.def("impl_tt_t", [](py::object self, const char* name, const char* dispatch, const char* debug) {
self.cast<c10::Library&>().impl(
name,
dispatch_str(dispatch, [](const at::Tensor& a, const at::Tensor& b) {
return a;
}).debug(debug)
);
return self;
}, "", py::arg("name"), py::arg("dispatch") = "", py::arg("debug") = "")
;
// Construct a fresh Library FRAGMENT for namespace `name` ("" selects
// the "_" namespace with CatchAll).  "/dev/null"/0 stand in for the
// file/line a TORCH_LIBRARY macro would normally capture.
m.def("_dispatch_import", [](std::string name) {
// This is a wee bit dodgy right now, but the "underlying" API is much
// easier to test than the high level (using TORCH_LIBRARY, e.g.)
if (name.empty()) {
return std::make_unique<c10::Library>(c10::Library::FRAGMENT, "_", c10::DispatchKey::CatchAll, "/dev/null", 0);
} else {
return std::make_unique<c10::Library>(c10::Library::FRAGMENT, name, c10::nullopt, "/dev/null", 0);
}
});
// Dump the dispatcher state of a single operator; "" if not found.
m.def("_dispatch_dump", [](const char* name) -> std::string {
auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
if (!op) {
return "";
} else {
return op->dumpState();
}
});
// Run invariant checks on a single operator; silently no-op if the
// operator isn't registered.  NOTE(review): the empty if-branch plus
// `return op->checkInvariants()` relies on checkInvariants() returning
// void — confirm, otherwise this wouldn't compile/behave as intended.
m.def("_dispatch_check_invariants", [](const char* name) {
auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
if (!op) {
} else {
return op->checkInvariants();
}
});
// Run invariant checks across every registered operator.
m.def("_dispatch_check_all_invariants", []() {
c10::Dispatcher::singleton().checkInvariants();
});
}
}}} // namespace torch::impl::dispatch