mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63496

This PR adds a (private) enable_python_mode context manager (see torch/utils/_python_dispatch.py). enable_python_mode accepts the type of a `__torch_dispatch__` object as its argument. Whenever an operator gets called inside the context manager, it dispatches to the `__torch_dispatch__` of the passed-in type.

Example usage:
```
with enable_python_mode(LoggingTensor):
    z = torch.empty([])
    assert isinstance(z, LoggingTensor)
```

Quite a few changes were made to support this.

First, we added TorchDispatchTypeObject, a C++ struct that represents the type of a `__torch_dispatch__` object (e.g. LoggingTensor). It holds both the PyObject* representing the class and a PyInterpreter* so we know which Python interpreter it came from.

Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this is null, dispatching happens as usual. When it is non-null, we prepend the TorchDispatchTypeObject's PyObject* to the overloaded args list so that it is considered first for dispatch.

To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser` works. The "overloaded args list" previously consisted only of Tensor PyObjects, but now it can contain types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_tensor`
- We added a new `append_overloaded_type` that appends a type to overloaded_args
- We added special handling in `handle_torch_dispatch_no_python_arg_parser` and `append_overloaded_arg` to handle types in addition to Tensors

Then, there are PythonMode and PythonModeTLS:
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable DispatchKey::Python and set the PythonModeTLS
- PythonModeTLS stores a TorchDispatchTypeObject as metadata
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen. This split is due to the libtorch_python library boundary (because we need to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up the relevant TorchDispatchTypeObject (if Python Mode is active) and dispatch using it

There is one more miscellaneous change:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an exclude guard. enable_python_mode currently does not handle torch.tensor, and the exclude guard is there to prevent a bug.

Future:
- This PR does not allow for the nesting of Python modes. In the future we should be able to enable this with a saner no_dispatch API and by changing the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.

Test Plan:
- new tests

Reviewed By: malfet, albanD
Differential Revision: D30543236
Pulled By: zou3519
fbshipit-source-id: ef5444d96a5a957d1657b7e37dce80f9a497d452
28 lines
951 B
C++
#include <torch/csrc/autograd/python_mode.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/autograd/python_variable.h>
#include <ATen/PythonModeTLS.h>
#include <c10/core/TensorImpl.h>

namespace torch { namespace autograd {

void PythonMode::enter(PyObject* type) {
  if (at::impl::PythonModeTLS::get_state()) {
    TORCH_CHECK(
        false,
        "python mode has already been set. We do not yet support nested python ",
        "mode. Please file us an issue and reset it before setting it again.")
  }
  // TorchDispatchTypeObject steals a reference; see NOTE [What is TorchDispatchTypeObject?]
  Py_INCREF(type);
  auto state = std::make_shared<c10::TorchDispatchTypeObject>(type, getPyInterpreter());
  at::impl::PythonModeTLS::set_state(state);
}

void PythonMode::exit() {
  TORCH_INTERNAL_ASSERT(at::impl::PythonModeTLS::get_state(), "exiting Python Mode but it wasn't set!");
  at::impl::PythonModeTLS::reset_state();
}

}}
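For intuition, the enter/exit contract implemented above can be mimicked with plain thread-local state. This is a simplified sketch under stated assumptions: `SimpleModeTLS`, `mode_enter`, and `mode_exit` are hypothetical stand-ins for `at::impl::PythonModeTLS` and `PythonMode`, and a `std::string` stands in for the TorchDispatchTypeObject payload.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for at::impl::PythonModeTLS: a single slot of
// thread-local state holding the active mode's metadata.
struct SimpleModeTLS {
    static const std::shared_ptr<std::string>& get_state() { return state_; }
    static void set_state(std::shared_ptr<std::string> s) { state_ = std::move(s); }
    static void reset_state() { state_.reset(); }
private:
    static thread_local std::shared_ptr<std::string> state_;
};
thread_local std::shared_ptr<std::string> SimpleModeTLS::state_ = nullptr;

// Mirrors PythonMode::enter: reject nesting, then install the state.
void mode_enter(const std::string& type_name) {
    if (SimpleModeTLS::get_state()) {
        throw std::runtime_error("python mode has already been set");
    }
    SimpleModeTLS::set_state(std::make_shared<std::string>(type_name));
}

// Mirrors PythonMode::exit: a mode must be active, then it is cleared.
void mode_exit() {
    assert(SimpleModeTLS::get_state() && "exiting mode but it wasn't set!");
    SimpleModeTLS::reset_state();
}
```

As in the Python context manager, `mode_enter` would run on `__enter__` and `mode_exit` on `__exit__`; a nested `mode_enter` throws, matching the TORCH_CHECK in the real file, and keeping the state thread-local is what lets each thread carry its own active mode.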