**Update:** I refactored the original PR. See the original description below; here I'll describe the updates.

**(1) TLS changes in `TorchDispatchModeTLS.h/cpp`**

I added a `TorchDispatchModeKey` enum that (for now) just contains PROXY and FAKE. The mode TLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack).

`TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode key, which, if set, tells us to add the given mode directly to our "infra modes" array. Popping will first check the "user mode" stack before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped, if there was one - that way, if we push that same mode back onto the TLS later, we know where it goes.

`TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query whether there are "any modes at all" or only "any user modes".

`TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes).

There were also some mild codegen changes to support the new enum.

**(2) `fake_tensor.py` / `proxy_tensor.py` / `_python_dispatch.py`**

The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter__/__exit__()` (in `_python_dispatch.py`) now check whether the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`.

I also had to change both of these modes' enter/exit to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack and save it in `enter_stack`, so that on exit we can reset the state properly.

**(3) Dispatching logic in `python_arg_parser.cpp`**

This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now:
```
(a) dispatch_on_mode()      # try user modes first (the mode stack automatically considers infra modes last)
(b) dispatch_on_subclass()  # try user subclasses next (skipping infra subclasses)
(c) dispatch_on_subclass()  # try infra subclasses next (skipping user subclasses)
```
Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this works today: if the proxy/fake modes get picked up in step (a), they'll return NotImplemented when they see a user subclass, allowing us to redispatch to the user subclass in step (b).

How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass, so we filter on whether arguments do or don't have the `_mode_key`. A toy model of this control flow is sketched below.
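To make the new order in (3) concrete, here is a toy, self-contained Python model of the control flow. It does not call into PyTorch and is not the actual `python_arg_parser.cpp` logic; the helper names and signatures are simplified stand-ins.

```python
# Toy model of the dispatch order described above (illustrative only; the
# real implementation is C++ in python_arg_parser.cpp).
def dispatch_on_mode(user_modes, infra_modes, op):
    # (a) User modes run first; the infra-mode slots (PROXY, FAKE) are only
    #     consulted after every user mode has declined. An infra mode is also
    #     expected to return NotImplemented when it sees a user subclass,
    #     which is what lets step (b) handle user subclasses before infra.
    for mode in list(user_modes) + list(infra_modes):
        result = mode(op)
        if result is not NotImplemented:
            return result
    return NotImplemented

def dispatch_on_subclass(subclasses, op, *, infra):
    # Infra subclasses (e.g. FakeTensor) are the ones carrying a hidden
    # `_mode_key`; user subclasses have no such attribute.
    for sub in subclasses:
        if (getattr(sub, "_mode_key", None) is not None) == infra:
            result = sub.handle(op)
            if result is not NotImplemented:
                return result
    return NotImplemented

def handle_torch_dispatch(op, user_modes, infra_modes, subclasses):
    for attempt in (
        lambda: dispatch_on_mode(user_modes, infra_modes, op),      # (a)
        lambda: dispatch_on_subclass(subclasses, op, infra=False),  # (b)
        lambda: dispatch_on_subclass(subclasses, op, infra=True),   # (c)
    ):
        result = attempt()
        if result is not NotImplemented:
            return result
    raise TypeError(f"no mode or subclass handled {op}")
```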
**(4) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion** (@albanD pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends).

----- original description below -----

The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit

The full set of changes is below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, I'm happy to try to do that.

**(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode**

This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes.

I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots and returns the highest-priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest-priority mode, *including* fake or proxy, if one is set. This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler.

**(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()**

This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: first we inspect the mode stack and any non-fake subclass inputs, then we check the proxy mode slot, then we check the fake mode slot, before finally checking for any fake subclass inputs.

**(3) New python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` APIs**

Before, if you wanted to see whether proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new APIs to directly check if either proxy or fake modes are active.

**(4) Allow traceable tensor subclasses to access storages from python**

This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable".

**(5) Fixed subclass fakeification**

@wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct.
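From the Python side, the practical effect of the ordering fix described above is that a user-defined `TorchDispatchMode` should observe ops before `FakeTensorMode` handles them, no matter which context manager is entered first. A minimal sketch, assuming current `FakeTensorMode` / `TorchDispatchMode` behavior rather than code from this PR:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    """A user mode that logs every op it sees and then redispatches."""
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(f"user mode saw: {func}")
        return func(*args, **(kwargs or {}))

# Even with FakeTensorMode entered *outside* the user mode, the fake mode is
# treated as an "infra" mode and runs after LoggingMode: the log lines should
# still appear, and the results should still be fake tensors.
with FakeTensorMode():
    with LoggingMode():
        x = torch.ones(2, 2)
        y = x + x
```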
**(6) Make tensor subclasses resizeable**

Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`.

**(7) Added a basic DoubleTensor subclass for testing**

I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR (a sketch of this kind of subclass follows below).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482
Approved by: https://github.com/ezyang
ghstack dependencies: #107415
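For readers unfamiliar with the test subclass referred to in (7) (later renamed to `TwoTensor`, per the update above), the following is a minimal sketch in the same spirit: a wrapper subclass holding two inner tensors that runs every op on both. It is illustrative only - the class name and details here are made up, and this is not the actual implementation that lives in PyTorch's test utilities.

```python
import torch
from torch.utils._pytree import tree_map

class TwoTensorSketch(torch.Tensor):
    """Wrapper subclass carrying two inner tensors; every op runs on both."""

    @staticmethod
    def __new__(cls, a, b):
        # The wrapper advertises a's metadata; the real data lives in a and b.
        return torch.Tensor._make_wrapper_subclass(
            cls, a.shape, dtype=a.dtype, device=a.device)

    def __init__(self, a, b):
        self.a = a
        self.b = b

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap_a = lambda t: t.a if isinstance(t, TwoTensorSketch) else t
        unwrap_b = lambda t: t.b if isinstance(t, TwoTensorSketch) else t
        # Run the op once per inner tensor, on plain tensors.
        out_a = func(*tree_map(unwrap_a, args), **tree_map(unwrap_a, kwargs))
        out_b = func(*tree_map(unwrap_b, args), **tree_map(unwrap_b, kwargs))
        if isinstance(out_a, torch.Tensor):
            return TwoTensorSketch(out_a, out_b)
        if isinstance(out_a, (tuple, list)):
            return type(out_a)(
                TwoTensorSketch(oa, ob) if isinstance(oa, torch.Tensor) else oa
                for oa, ob in zip(out_a, out_b))
        return out_a

# Example: each aten op below runs once on `a` and once on `b`.
t = TwoTensorSketch(torch.randn(3), torch.randn(3))
out = t.sin() + t
```

The real test subclass also deals with details omitted here (autograd interaction, tracing support, and careful handling of multi-output ops).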
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <ATen/ATen.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

#include <c10/core/SafePyObject.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>

#include <c10/core/SingletonSymNodeImpl.h>
#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_raii.h>

#include <iostream>
#include <utility>

namespace py = pybind11;

namespace torch {
namespace impl {
namespace dispatch {

// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finished doing all registrations before
// the other interpreters start banging on it
static ska::flat_hash_map<
    c10::OperatorName,
    ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
    python_registrations_;

static torch::Library::Kind parseKind(const std::string& k) {
  static std::unordered_map<std::string, torch::Library::Kind> kind_map = {
      {"DEF", torch::Library::DEF},
      {"IMPL", torch::Library::IMPL},
      {"FRAGMENT", torch::Library::FRAGMENT},
  };
  auto it = kind_map.find(k);
  TORCH_CHECK(it != kind_map.end(), "could not parse ", k);
  return it->second;
}
static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
  static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = {
      {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
      {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA},
      {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION},
      {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default
  };
  auto it = key_map.find(k);
  TORCH_CHECK(it != key_map.end(), "could not parse ", k);
  return it->second;
}

template <typename Func>
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
  auto mb_key = std::string(key).empty()
      ? c10::nullopt
      : c10::make_optional(c10::parseDispatchKey(key));
  if (mb_key) {
    return torch::dispatch(*mb_key, std::forward<Func>(raw_f));
  } else {
    torch::CppFunction f(std::forward<Func>(raw_f));
    return f;
  }
}

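// RAII guard used below while invoking a Python-registered kernel: it turns on
// hermetic PyObject mode and temporarily masks the Python and
// PythonTLSSnapshot dispatch keys, restoring the previous TLS state on
// destruction.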
struct EnableHermeticPyObject {
  EnableHermeticPyObject()
      : old_(c10::impl::HermeticPyObjectTLS::get_state()),
        old_excluded_python_(
            c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)),
        old_python_(
            c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)),
        old_python_snapshot_(c10::impl::tls_is_dispatch_key_included(
            at::DispatchKey::PythonTLSSnapshot)) {
    c10::impl::HermeticPyObjectTLS::set_state(true);
    c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, false);
  }
  ~EnableHermeticPyObject() {
    c10::impl::HermeticPyObjectTLS::set_state(old_);
    c10::impl::tls_set_dispatch_key_excluded(
        at::DispatchKey::Python, old_excluded_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::Python, old_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
  }
  bool old_;
  bool old_excluded_python_;
  bool old_python_;
  bool old_python_snapshot_;
};

class PythonKernelHolder : public c10::OperatorKernel {
  c10::SafePyObject func_;
  c10::DispatchKey dispatch_key_;

 public:
  PythonKernelHolder(py::object func, c10::DispatchKey dispatch_key)
      : func_(func.release().ptr(), getPyInterpreter()),
        dispatch_key_(dispatch_key) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack) {
    // Figure out if we can handle it hermetically, or if we have
    // to double dispatch

    // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    if (mode_stack_len > 0) {
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      cur_torch_dispatch_mode_state->pyinterpreter()
          ->python_op_registration_trampoline(op, dispatch_key_, stack);
      return;
    }

    const auto& schema = op.schema();
    const auto num_arguments = schema.arguments().size();

    // Otherwise, find a PyInterpreter on a Tensor IF it has the Python key
    // (which means it's a nontrivial tensor subclass)
    for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
      if (ivalue.isTensor()) {
        auto* interpreter =
            ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
        if (interpreter &&
            ivalue.unsafeToTensorImpl()->key_set().has(
                at::DispatchKey::Python)) {
          (*interpreter)
              ->python_op_registration_trampoline(op, dispatch_key_, stack);
          return;
        }
      } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
        // NB: use toListRef as it doesn't induce refcount bumps
        // (toTensorListRef is not a thing)
        for (const auto& nv : ivalue.toListRef()) {
          if (nv.isNone()) {
            continue;
          }
          auto* interpreter =
              nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
          if (interpreter &&
              nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
            (*interpreter)
                ->python_op_registration_trampoline(op, dispatch_key_, stack);
            return;
          }
        }
      }
    }

    // Nothing requires the operator to be homed to a specific interpreter, so
    // run it on the current interpreter

    auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
    py::gil_scoped_acquire g;
    EnableHermeticPyObject g2;
    auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
    auto obj = py::reinterpret_steal<py::object>(PyObject_Call(
        func_.ptr(getPyInterpreter()),
        args_kwargs.first.ptr(),
        args_kwargs.second.ptr()));
    if (!obj) {
      throw python_error();
    }
    pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
  }
};

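// On the main Python interpreter we actually register; on other interpreters
// (e.g. under multipy / torch::deploy) we only verify that a matching
// registration already exists.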
static torch::_RegisterOrVerify register_or_verify() {
  if (isMainPyInterpreter()) {
    return torch::_RegisterOrVerify::REGISTER;
  } else {
    return torch::_RegisterOrVerify::VERIFY;
  }
}

static py::object ophandle_call_boxed(
    const c10::OperatorHandle& handle,
    py::args args,
    const py::kwargs& kwargs) {
  auto stack = torch::jit::createStackForSchema(
      handle.schema(),
      std::move(args),
      kwargs,
      /*self=*/c10::nullopt);
  {
    pybind11::gil_scoped_release no_gil_guard;
    handle.callBoxed(stack);
  }
  return torch::jit::createPyObjectForStack(std::move(stack));
}

void initDispatchBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
      .def("schema", &c10::OperatorHandle::schema);

  m.def("_dispatch_call_boxed", &ophandle_call_boxed);

  // TODO: figure out how to do chaining
  py::class_<torch::Library>(m, "_DispatchModule")
      // Some of these APIs are only for testing and do not work in multipy
      // environment
      .def(
          "def_",
          [](py::object self, const char* schema, const char* alias) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("alias") = "")
      // Simulated "legacy" def where alias analysis kind is not set.
      // Ordinarily this can only be exercised from RegisterOperators() API
      // but I am not going to bind that here
      .def(
          "def_legacy",
          [](py::object self, const char* schema) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
            return self;
          },
          "",
          py::arg("schema"))
      // We can't conveniently turn Python functions into valid functions
      // in the dispatcher. So instead we provide a bunch of precanned
      // functions for testing purposes. You're NOT intended to actually
      // call these functions; they're just here so we can actually register
      // something
      //
      // Mangling scheme: args_rets. One character per.
      //  t = Tensor
      .def(
          "def_name_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "default_def_name_t_t")
      .def(
          "def_schema_t_t",
          [](py::object self,
             const char* schema,
             const char* dispatch,
             const char* alias,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)),
                dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("alias") = "",
          py::arg("debug") = "default_def_schema_t_t")
      // TODO: maybe consider deduplicating the definitions here, it's getting
      // pretty long
      .def(
          "impl_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().impl(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "impl_t_t")
      .def(
          "impl",
          [](const py::object& self,
             const char* name,
             // TODO: empty string no longer works
             c10::DispatchKey dispatch,
             py::object func) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.impl(
                  name,
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()),
                  register_or_verify());
            } else {
              lib.impl(
                  name,
                  torch::dispatch(
                      dispatch,
                      CppFunction::makeFromBoxedFunctor(
                          std::make_unique<PythonKernelHolder>(
                              func, dispatch))),
                  register_or_verify());
              python_registrations_[lib._resolve(name)].insert_or_assign(
                  dispatch,
                  std::make_shared<c10::SafePyObject>(
                      func.release().ptr(), getPyInterpreter()));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("name"),
          py::arg("dispatch"),
          py::arg("func"))
      .def(
          "define",
          [](const py::object& self,
             const char* schema,
             const char* alias_analysis) {
            auto parsed_schema =
                torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
            self.cast<torch::Library&>().def(
                std::move(parsed_schema), {}, register_or_verify());
            // TODO: this is dumb, had to make a second copy
            return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
                .name();
          },
          "",
          py::arg("schema"),
          py::arg("alias_analysis") = "")
      .def(
          "fallback_fallthrough",
          [](py::object self, const char* dispatch) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().fallback(
                dispatch_str(dispatch, CppFunction::makeFallthrough()));
            return self;
          },
          "",
          py::arg("dispatch") = "");

  m.def(
      "_dispatch_library",
      [](const char* kind,
         std::string name,
         const char* dispatch,
         const char* file,
         uint32_t linenum) {
        HANDLE_TH_ERRORS
        return std::make_unique<torch::Library>(
            parseKind(kind),
            std::move(name),
            std::string(dispatch).empty()
                ? c10::nullopt
                : c10::make_optional(c10::parseDispatchKey(dispatch)),
            "/dev/null", // temporary workaround
            linenum);
        END_HANDLE_TH_ERRORS_PYBIND
      },
      "",
      py::arg("kind"),
      py::arg("name"),
      py::arg("dispatch"),
      py::arg("file") = "/dev/null",
      py::arg("linenum") = 0);

  m.def(
      "_dispatch_find_schema_or_throw",
      [](const char* name, const char* overload_name) -> c10::OperatorHandle {
        return c10::Dispatcher::singleton().findSchemaOrThrow(
            name, overload_name);
      });

  m.def("_dispatch_dump", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpState();
    }
  });

  m.def("_dispatch_dump_table", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpComputedTable();
    }
  });

  m.def("_dispatch_check_invariants", [](const char* name) {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
    } else {
      return op->checkInvariants();
    }
  });

  m.def("_dispatch_check_all_invariants", []() {
    c10::Dispatcher::singleton().checkInvariants();
  });

  m.def("_dispatch_has_kernel", [](const char* name) -> bool {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    return static_cast<bool>(op);
  });

  m.def(
      // Returns whether or not a direct kernel registration exists
      // for this <op_name, dispatch_key> pair.
      "_dispatch_has_kernel_for_dispatch_key",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForDispatchKey(dispatch);
      });

  m.def(
      "_dispatch_has_kernel_for_any_dispatch_key",
      [](const char* name, c10::DispatchKeySet ks) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForAnyDispatchKey(ks);
      });

  m.def(
      // Returns whether or not there is an entry in the runtime computed
      // dispatch table, for this <op_name, dispatch_key> pair. For example, if
      // "op" has a `CompositeImplicitAutograd` kernel, then
      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will return
      // true for all backends that are part of the alias set for
      // CompositeImplicitAutograd.
      "_dispatch_has_computed_kernel_for_dispatch_key",
      [](const char* name, const char* dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasComputedKernelForDispatchKey(
            c10::parseDispatchKey(dispatch));
      });

  m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
    auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();

    std::vector<std::string> states;
    states.reserve(danglingImpls.size());
    for (auto& danglingImpl : danglingImpls) {
      states.emplace_back(danglingImpl.dumpState());
    }

    return states;
  });

  m.def("_dispatch_get_all_op_names", []() -> std::vector<std::string> {
    auto op_names = c10::Dispatcher::singleton().getAllOpNames();

    std::vector<std::string> names;
    names.reserve(op_names.size());
    for (auto& op : op_names) {
      std::stringstream ss;
      ss << op.name;
      if (!op.overload_name.empty()) {
        ss << "." << op.overload_name;
      }
      names.emplace_back(ss.str());
    }

    return names;
  });

  m.def(
      "_dispatch_tls_set_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
      });
  m.def(
      "_dispatch_tls_set_dispatch_key_included",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_included(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_included",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_included(dispatch_key);
      });

  m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
    return at::isTensorSubclassLike(tensor);
  });

  m.def("_dispatch_key_name", [](c10::DispatchKey k) {
    return c10::toString(k);
  });
  m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
  m.def("_to_functionality_key", [](c10::DispatchKey k) {
    return c10::toFunctionalityKey(k);
  });
  // E.g. given `DispatchKey::AutogradFunctionality`, returns a keyset of:
  //  AutogradCPU
  //  AutogradCUDA
  //  ...
  //  AutogradPrivateUse3
  m.def("_functionality_to_backend_keys", [](c10::DispatchKey key) {
    std::vector<c10::DispatchKey> keys;
    if (c10::isPerBackendFunctionalityKey(key)) {
      auto ks = c10::DispatchKeySet(key) |
          c10::DispatchKeySet(c10::DispatchKeySet::RAW, c10::full_backend_mask);
      for (auto k : ks) {
        keys.push_back(k);
      }
    } else {
      keys.push_back(key);
    }
    return keys;
  });
  m.def("_dispatch_num_backends", []() { return c10::num_backends; });

#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)

  py::enum_<c10::DispatchKey>(m, "DispatchKey")
      // clang-format off
      DEF_ONE(Undefined)
      DEF_ONE(CompositeExplicitAutogradNonFunctional)
      DEF_ONE(CompositeExplicitAutograd)
      DEF_ONE(CompositeImplicitAutogradNestedTensor)
      DEF_ONE(CompositeImplicitAutograd)
      DEF_ONE(AutogradOther)
      DEF_ONE(Autograd)
      DEF_ONE(BackendSelect)
      DEF_ONE(ADInplaceOrView)
      DEF_ONE(PythonTLSSnapshot)
      DEF_ONE(Python)
      DEF_ONE(FuncTorchDynamicLayerFrontMode)
      DEF_ONE(FuncTorchDynamicLayerBackMode)
      DEF_ONE(FuncTorchBatchedDecomposition)
      DEF_ONE(FuncTorchBatched)
      DEF_ONE(FuncTorchVmapMode)
      DEF_ONE(FuncTorchGradWrapper)
      DEF_ONE(PythonDispatcher)
      DEF_ONE(PreDispatch)
      DEF_ONE(Functionalize)
      DEF_ONE(AutocastCPU)
      DEF_ONE(AutocastXPU)
      DEF_ONE(AutocastHPU)
      DEF_ONE(AutocastIPU)
      DEF_ONE(AutocastCUDA)
      DEF_ONE(AutocastPrivateUse1)
      // clang-format on

#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
#define DEF_MULTIPLE(fullname, prefix)              \
  DEF_SINGLE(, fullname)                            \
  DEF_SINGLE(, StartOf##fullname##Backends)         \
  C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
  DEF_SINGLE(, EndOf##fullname##Backends)

  // clang-format off
  C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)
  // clang-format on

#undef DEF_MULTIPLE
#undef DEF_SINGLE
      ;

  py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
      .def(py::init<c10::DispatchKey>())
      .def("__or__", &c10::DispatchKeySet::operator|)
      .def("__sub__", &c10::DispatchKeySet::operator-)
      .def("__and__", &c10::DispatchKeySet::operator&)
      .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
      .def(
          "remove",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.remove(k);
          })
      .def(
          "add",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.add(k);
          })
      .def("has", &c10::DispatchKeySet::has)
      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });

  m.attr("_dispatch_autogradother_backends") =
      py::cast(c10::autogradother_backends);

  m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
    return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
  });

  m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);
  });

  m.def("_dispatch_keyset_full", []() {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL);
  });

  m.def("_dispatch_is_alias_key", c10::isAliasDispatchKey);

  m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) {
    return c10::toString(keyset);
  });

  m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
    return c10::getBackendKeySetFromAutograd(k);
  });

  m.def("_dispatch_keys", [](const at::Tensor& tensor) {
    auto* impl = tensor.unsafeGetTensorImpl();
    return impl->key_set();
  });
  m.def("_dispatch_tls_local_include_set", []() {
    return c10::impl::tls_local_dispatch_key_set().included_;
  });
  m.def("_dispatch_tls_local_exclude_set", []() {
    return c10::impl::tls_local_dispatch_key_set().excluded_;
  });
  m.def("_functionalization_reapply_views_tls", []() {
    return at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  });
  m.def(
      "_dispatch_is_included_in_alias",
      [](c10::DispatchKey a, c10::DispatchKey b) {
        return c10::isIncludedInAlias(a, b);
      });

  // DEPRECATED, please don't use this. Instead use
  // torch._C._ExcludeDispatchKeyGuard
  py_context_manager_DEPRECATED<
      c10::impl::ExcludeDispatchKeyGuard,
      c10::DispatchKeySet>(m, "ExcludeDispatchKeyGuard");

  py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
      m, "_ExcludeDispatchKeyGuard");

  py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
      m, "_AutoDispatchBelowAutograd");

  // Prints out the name of every operator that has a kernel registered to the
  // Dispatcher under [dispatch_key]. If no arguments are specified, it'll print
  // out the name of every operator that the Dispatcher knows of. This can be
  // useful to answer questions like "list all operators that do not have a CPU
  // kernel".
  m.def(
      "_dispatch_print_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? c10::nullopt
            : c10::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        for (auto& op : op_names) {
          std::cout << op << std::endl;
        }
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  m.def(
      "_dispatch_get_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? c10::nullopt
            : c10::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        std::vector<std::string> names;
        names.reserve(op_names.size());
        for (auto& op : op_names) {
          names.emplace_back(
              op.name +
              (op.overload_name.empty() ? "" : "." + op.overload_name));
        }
        return names;
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));
  m.def(
      "_dispatch_set_report_error_callback",
      [](c10::OperatorHandle& handle, py::object callback) {
        auto obj = callback.release().ptr();
        auto callback_obj =
            std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
        handle.setReportErrorCallback_(std::move(callback_obj));
      });

  m.def(
      "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });

  m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
    return at::functionalization::impl::replace_(a, b);
  });
  m.def("_propagate_xla_data", [](const at::Tensor& a, const at::Tensor& b) {
    at::functionalization::impl::propagate_xla_data(a, b);
  });
  m.def("_commit_update", [](const at::Tensor& a) {
    return at::functionalization::impl::commit_update(a);
  });

  m.def("_are_functorch_transforms_active", []() {
    auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
    return (
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
  });

  m.def("_get_singleton_int", [](int64_t data) {
    return c10::SymInt(
        c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(data)));
  });

  using c10::impl::TorchDispatchModeKey;
  py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
      .value("PROXY", TorchDispatchModeKey::PROXY)
      .value("FAKE", TorchDispatchModeKey::FAKE);
}

// TODO: dedupe with the kernel
void python_op_registration_trampoline_impl(
    const c10::OperatorHandle& op,
    c10::DispatchKey key,
    torch::jit::Stack* stack) {
  auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
  py::gil_scoped_acquire g;
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  const auto& func = python_registrations_[op.operator_name()][key];
  TORCH_INTERNAL_ASSERT(func != nullptr);
  auto* pyobj = func->ptr(getPyInterpreter());
  TORCH_INTERNAL_ASSERT(pyobj != nullptr);
  auto obj = py::reinterpret_steal<py::object>(
      PyObject_Call(pyobj, args_kwargs.first.ptr(), args_kwargs.second.ptr()));
  if (!obj) {
    throw python_error();
  }
  pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}

} // namespace dispatch
} // namespace impl
} // namespace torch