pytorch/torch/csrc/utils/python_raii.h
Richard Zou 08fb648fe1 Add mechanism to turn any RAII guard into a Python Context Manager (#102037)
This PR:
- adds a mechanism to turn any RAII guard into a Python Context Manager
- turns ExcludeDispatchKeyGuard into a context manager, and purges usages
of the older torch._C.ExcludeDispatchKeyGuard from the codebase.

The mechanism: given a RAII guard, we construct a context manager
object that holds an optional guard. Entering the context manager
populates the guard; exiting resets it.
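
The mechanism described above can be sketched without pybind or c10. This is a minimal standalone illustration, assuming C++17 and using std::optional in place of c10::optional; DepthGuard, g_depth, ContextManagerSketch, and depth_inside_context are hypothetical names invented for the example and are not part of the PR.

```cpp
#include <cassert>
#include <optional>
#include <tuple>
#include <utility>

// Hypothetical state that the example guard modifies while it is alive.
inline int g_depth = 0;

// Hypothetical RAII guard: adds `amount` to g_depth on construction and
// subtracts it again on destruction.
struct DepthGuard {
  explicit DepthGuard(int amount) : amount_(amount) { g_depth += amount_; }
  ~DepthGuard() { g_depth -= amount_; }
  DepthGuard(const DepthGuard&) = delete;
  DepthGuard& operator=(const DepthGuard&) = delete;

 private:
  int amount_;
};

// Stores the guard's constructor arguments up front and only constructs
// the guard between enter() and exit().
template <typename GuardT, typename... Args>
class ContextManagerSketch {
 public:
  explicit ContextManagerSketch(Args... args) : args_(std::move(args)...) {}

  // Plays the role of __enter__: build the guard in place.
  void enter() {
    std::apply(
        [this](const Args&... args) { guard_.emplace(args...); }, args_);
  }

  // Plays the role of __exit__: destroy the guard, running its destructor.
  void exit() { guard_.reset(); }

 private:
  std::optional<GuardT> guard_;
  std::tuple<Args...> args_;
};

// Runs one enter/exit cycle and returns the depth observed inside it.
inline int depth_inside_context(int amount) {
  const int before = g_depth;
  ContextManagerSketch<DepthGuard, int> cm(amount);
  assert(g_depth == before); // no guard exists until enter()
  cm.enter();
  const int inside = g_depth;
  cm.exit();
  assert(g_depth == before); // exit() ran the guard's destructor
  return inside;
}
```

Because the optional starts empty, constructing the manager has no side effects; the guard's constructor and destructor run only at enter() and exit(), which is what lets Python's `with` statement drive the guard's lifetime.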

We don't delete torch._C.ExcludeDispatchKeyGuard for BC reasons (people
are using it in fbcode). This code uses C++17 (std::apply), which
worries me a bit. If it sticks, I'll apply the change to the other RAII
guards we have; otherwise, we can write our own std::apply.
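
For reference, the fallback of writing our own std::apply could look roughly like the sketch below. It assumes only C++14 (std::index_sequence) and is not code from this PR; the sketch namespace, apply_impl, sum3, and apply_demo are hypothetical names. Unlike std::apply, it calls the callable directly instead of going through std::invoke, so it does not handle pointers to members.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>

namespace sketch {

// Expands the tuple into individual arguments using an index pack.
template <typename F, typename Tuple, std::size_t... I>
constexpr decltype(auto) apply_impl(
    F&& f,
    Tuple&& t,
    std::index_sequence<I...>) {
  return std::forward<F>(f)(std::get<I>(std::forward<Tuple>(t))...);
}

// Calls f with the elements of tuple t as its arguments.
template <typename F, typename Tuple>
constexpr decltype(auto) apply(F&& f, Tuple&& t) {
  return sketch::apply_impl(
      std::forward<F>(f),
      std::forward<Tuple>(t),
      std::make_index_sequence<
          std::tuple_size<std::decay_t<Tuple>>::value>{});
}

} // namespace sketch

// Hypothetical function used only to exercise sketch::apply.
inline int sum3(int a, int b, int c) {
  return a + b + c;
}

// Unpacks a three-element tuple into sum3.
inline int apply_demo() {
  auto args = std::make_tuple(1, 2, 3);
  const int total = sketch::apply(sum3, args);
  assert(total == 6);
  return total;
}
```

The call is written as sketch::apply rather than a bare apply because, under C++17, argument-dependent lookup on a std::tuple argument would also find std::apply and make an unqualified call ambiguous.
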
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102037
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
2023-05-24 14:20:52 +00:00


#pragma once

#include <c10/util/Optional.h>
#include <torch/csrc/utils/pybind.h>
#include <tuple>
#include <utility>

namespace torch {
namespace impl {

// Stores the constructor arguments for GuardT and keeps the guard alive
// only between enter() and exit().
template <typename GuardT, typename... Args>
struct RAIIContextManager {
  explicit RAIIContextManager(Args&&... args)
      : args_(std::forward<Args>(args)...) {}

  // Constructs the guard in place from the stored arguments.
  void enter() {
    auto emplace = [&](Args... args) {
      return guard_.emplace(std::forward<Args>(args)...);
    };
    std::apply(std::move(emplace), args_);
  }

  // Destroys the guard, which runs its destructor.
  void exit() {
    guard_ = c10::nullopt;
  }

 private:
  c10::optional<GuardT> guard_;
  std::tuple<Args...> args_;
};

// Turns a C++ RAII guard into a Python context manager.
// See _ExcludeDispatchKeyGuard in python_dispatch.cpp for an example.
template <typename GuardT, typename... GuardArgs>
void py_context_manager(const py::module& m, const char* name) {
  using ContextManagerT = RAIIContextManager<GuardT, GuardArgs...>;
  py::class_<ContextManagerT>(m, name)
      .def(py::init<GuardArgs...>())
      .def("__enter__", [](ContextManagerT& guard) { guard.enter(); })
      .def(
          "__exit__",
          // The exception arguments are ignored. The lambda returns None,
          // so an exception raised inside the `with` block still propagates.
          [](ContextManagerT& guard,
             py::object exc_type,
             py::object exc_value,
             py::object traceback) { guard.exit(); });
}

} // namespace impl
} // namespace torch