mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add python mode (#63496)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63496 This PR adds a (private) enable_python_mode context manager. (see torch/utils/_python_dispatch.py). enable_python_mode accepts the type of a __torch_dispatch__ object as its argument. Whenever an operator gets called inside of the context manager, it dispatches to the __torch_dispatch__ of the passed-in type. Example usage: ``` with enable_python_mode(LoggingTensor): z = torch.empty([]) assert isinstance(z, LoggingTensor) ``` There are quite a few changes that were made to support this. First, we added TorchDispatchTypeObject, a C++ struct that represents the type of a `__torch_dispatch__` object (e.g. LoggingTensor). It holds both the PyObject* representing the class and a PyInterpreter* so we know which Python interpreter it came from. Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this is null, dispatching happens as usual. When it is non-null, we prepend the TorchDispatchTypeObject's PyObject* to the overloaded args list so that it is considered first for dispatch. To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser` works. The "overloaded args list" previously only consisted of Tensor PyObjects, but now it can have types in addition to Tensors! - We renamed `append_overloaded_arg` to `append_overloaded_arg` - We added a new `append_overloaded_type` that appends a type to overloaded_args - We added special handling in `handle_torch_dispatch_no_python_arg_parser` and `append_overloaded_arg` to handle types in addition to Tensors. Then, there is PythonMode and PythonModeTLS. - We reuse the DispatchKey::Python dispatch key as a mode key - We use PythonMode::enter and PythonMode::exit to enable/disable DispatchKey::Python and set the PythonModeTLS. - PythonModeTLS stores a TorchDispatchTypeObject as metadata. - PythonMode is in libtorch_python, and PythonModeTLS is in ATen. This split is due to the libtorch_python library boundary (because we need to save TLS in ATen/ThreadLocalState) - We modify the PythonFallbackKernel to look up the relevant TorchDispatchTypeObject (if Python Mode is active) and dispatch using it. There are two more miscellaneous changes: - internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an exclude guard. enable_python_mode currently does not handle torch.tensor and the exclude guard is to prevent a bug. Future: - This PR does not allow for the nesting of Python modes. In the future we should be able to enable this with a more sane no_dispatch API and by changing the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing. Test Plan: - new tests Reviewed By: malfet, albanD Differential Revision: D30543236 Pulled By: zou3519 fbshipit-source-id: ef5444d96a5a957d1657b7e37dce80f9a497d452
This commit is contained in:
parent
ebc0aacf83
commit
4bd03b0242
26
aten/src/ATen/PythonModeTLS.cpp
Normal file
26
aten/src/ATen/PythonModeTLS.cpp
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#include <ATen/PythonModeTLS.h>
|
||||
|
||||
namespace at { namespace impl {
|
||||
|
||||
thread_local std::shared_ptr<TorchDispatchTypeObject> pythonModeState;
|
||||
|
||||
void PythonModeTLS::set_state(const std::shared_ptr<TorchDispatchTypeObject>& state) {
|
||||
pythonModeState = state;
|
||||
if (state) {
|
||||
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
|
||||
} else {
|
||||
PythonModeTLS::reset_state();
|
||||
}
|
||||
}
|
||||
|
||||
const std::shared_ptr<TorchDispatchTypeObject>& PythonModeTLS::get_state() {
|
||||
return pythonModeState;
|
||||
}
|
||||
|
||||
void PythonModeTLS::reset_state() {
|
||||
pythonModeState.reset((TorchDispatchTypeObject*)nullptr);
|
||||
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
} // namespace at
|
||||
17
aten/src/ATen/PythonModeTLS.h
Normal file
17
aten/src/ATen/PythonModeTLS.h
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <torch/library.h>
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
|
||||
namespace at {
|
||||
namespace impl {
|
||||
|
||||
struct TORCH_API PythonModeTLS {
|
||||
static void set_state(const std::shared_ptr<TorchDispatchTypeObject>& state);
|
||||
static const std::shared_ptr<TorchDispatchTypeObject>& get_state();
|
||||
static void reset_state();
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
} // namespace at
|
||||
|
|
@ -17,6 +17,7 @@ ThreadLocalState::ThreadLocalState()
|
|||
saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
|
||||
|
||||
bumped_record_all_functions_ = at::checkRecordAllFunctions();
|
||||
python_mode_state_ = at::impl::PythonModeTLS::get_state();
|
||||
}
|
||||
|
||||
void ThreadLocalState::set_grad_mode(bool enabled) {
|
||||
|
|
@ -30,6 +31,8 @@ void ThreadLocalState::setThreadLocalState(
|
|||
// restore the dispatch key set TLS at the same time.
|
||||
c10::AutogradState::set_tls_state(state.autograd_tls_);
|
||||
|
||||
at::impl::PythonModeTLS::set_state(state.python_mode_state_);
|
||||
|
||||
at::set_record_function_tls_(state.rf_tls_);
|
||||
|
||||
SavedTensorDefaultHooks::set_hooks(
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <c10/util/ThreadLocalDebugInfo.h>
|
||||
|
||||
#include <ATen/record_function.h>
|
||||
#include <ATen/PythonModeTLS.h>
|
||||
|
||||
namespace at {
|
||||
|
||||
|
|
@ -40,6 +41,8 @@ class TORCH_API ThreadLocalState {
|
|||
// TLS for AutogradModes
|
||||
AutogradState autograd_tls_;
|
||||
|
||||
std::shared_ptr<TorchDispatchTypeObject> python_mode_state_;
|
||||
|
||||
// TLS for saved tensors default hooks
|
||||
std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,18 @@
|
|||
#include <torch/library.h>
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <ATen/PythonModeTLS.h>
|
||||
|
||||
namespace {
|
||||
|
||||
void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
// If Python Mode is active, use its PyInterpreter for dispatch
|
||||
const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
|
||||
if (maybe_python_mode_state) {
|
||||
maybe_python_mode_state->pyinterpreter()->dispatch(op, stack, maybe_python_mode_state);
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, find a PyInterpreter on a Tensor
|
||||
const auto& schema = op.schema();
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
// It is safe to dispatch on the very first Tensor with a pyobj_interpreter
|
||||
|
|
@ -15,7 +24,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
|||
if (ivalue.isTensor()) {
|
||||
auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter();
|
||||
if (interpreter) {
|
||||
interpreter->dispatch(op, stack);
|
||||
interpreter->dispatch(op, stack, nullptr);
|
||||
return;
|
||||
}
|
||||
} else if (ivalue.isTensorList()) {
|
||||
|
|
@ -24,7 +33,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
|||
for (const auto& nv : ivalue.toListRef()) {
|
||||
auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter();
|
||||
if (interpreter) {
|
||||
interpreter->dispatch(op, stack);
|
||||
interpreter->dispatch(op, stack, nullptr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,7 +40,8 @@ static c10::intrusive_ptr<TensorImpl> noop_detach_fn(
|
|||
static void noop_dispatch_fn(
|
||||
const PyInterpreter*,
|
||||
const c10::OperatorHandle& op,
|
||||
torch::jit::Stack* stack) {
|
||||
torch::jit::Stack* stack,
|
||||
const std::shared_ptr<TorchDispatchTypeObject>& type) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
0,
|
||||
"attempted to dispatch (__torch_dispatch__) an operator on Tensor with nontrivial PyObject after corresponding interpreter died");
|
||||
|
|
@ -608,6 +609,23 @@ void TensorImpl::copy_tensor_metadata(
|
|||
}
|
||||
}
|
||||
|
||||
TorchDispatchTypeObject::TorchDispatchTypeObject(
|
||||
PyObject* type_object,
|
||||
c10::impl::PyInterpreter* pyinterpreter)
|
||||
: data_(type_object), pyinterpreter_(pyinterpreter) {}
|
||||
|
||||
TorchDispatchTypeObject::~TorchDispatchTypeObject() {
|
||||
pyinterpreter_->decref(data_);
|
||||
}
|
||||
|
||||
c10::impl::PyInterpreter* TorchDispatchTypeObject::pyinterpreter() const {
|
||||
return pyinterpreter_;
|
||||
}
|
||||
|
||||
PyObject* TorchDispatchTypeObject::ptr() const {
|
||||
return data_;
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
|
||||
namespace {
|
||||
|
|
|
|||
|
|
@ -161,6 +161,9 @@ struct C10_API AutogradMetaInterface {
|
|||
virtual ~AutogradMetaInterface();
|
||||
};
|
||||
|
||||
// forward declared
|
||||
struct TorchDispatchTypeObject;
|
||||
|
||||
namespace impl {
|
||||
|
||||
// Unfortunately, the definition of AutogradMeta lives in a separate
|
||||
|
|
@ -255,7 +258,8 @@ struct C10_API PyInterpreter {
|
|||
using dispatch_sig = void(
|
||||
const PyInterpreter*,
|
||||
const c10::OperatorHandle&,
|
||||
torch::jit::Stack* stack);
|
||||
torch::jit::Stack* stack,
|
||||
const std::shared_ptr<TorchDispatchTypeObject>& type);
|
||||
|
||||
PyInterpreter(
|
||||
name_sig* name_fn,
|
||||
|
|
@ -299,8 +303,9 @@ struct C10_API PyInterpreter {
|
|||
// Invoke the Python boxed fallback dispatch to go back into Python
|
||||
__ubsan_ignore_function__ void dispatch(
|
||||
const c10::OperatorHandle& op,
|
||||
torch::jit::Stack* stack) const {
|
||||
return (*dispatch_fn_)(this, op, stack);
|
||||
torch::jit::Stack* stack,
|
||||
const std::shared_ptr<TorchDispatchTypeObject>& type) const {
|
||||
return (*dispatch_fn_)(this, op, stack, type);
|
||||
}
|
||||
|
||||
// Disarm this PyInterpreter, making all of its methods noops.
|
||||
|
|
@ -348,6 +353,30 @@ struct C10_API NamedTensorMetaInterface {
|
|||
};
|
||||
};
|
||||
|
||||
// NOTE [What is TorchDispatchTypeObject?]
|
||||
// A TorchDispatchTypeObject represents the type of a Tensor subclass that has
|
||||
// a __torch_dispatch__ classmethod. Concretely, it holds the class as a
|
||||
// PyObject* and a PyInterpreter* that says which python interpreter the class
|
||||
// came from.
|
||||
//
|
||||
// See NOTE [dispatch_fn's type argument] for more details
|
||||
struct C10_API TorchDispatchTypeObject {
|
||||
// Steals a reference to type_object
|
||||
TorchDispatchTypeObject(
|
||||
PyObject* type_object,
|
||||
c10::impl::PyInterpreter* pyinterpreter);
|
||||
|
||||
// Releases the stolen reference to type_object
|
||||
~TorchDispatchTypeObject();
|
||||
|
||||
c10::impl::PyInterpreter* pyinterpreter() const;
|
||||
PyObject* ptr() const;
|
||||
|
||||
private:
|
||||
PyObject* data_;
|
||||
c10::impl::PyInterpreter* pyinterpreter_;
|
||||
};
|
||||
|
||||
// NOTE [ Version Counter Sharing ]
|
||||
//
|
||||
// Every Tensor has a version counter. Version counters are incremented whenever
|
||||
|
|
|
|||
|
|
@ -103,6 +103,7 @@ TESTS = [
|
|||
"test_optim",
|
||||
"test_functional_optim",
|
||||
"test_pytree",
|
||||
"test_python_dispatch",
|
||||
"test_mobile_optimizer",
|
||||
"test_set_default_mobile_cpu_allocator",
|
||||
"test_xnnpack_integration",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.utils._python_dispatch import enable_python_mode
|
||||
|
||||
from typing import Iterator, List
|
||||
import logging
|
||||
|
|
@ -50,7 +51,10 @@ class LoggingTensor(torch.Tensor):
|
|||
def wrap(e):
|
||||
return LoggingTensor(e) if isinstance(e, torch.Tensor) else e
|
||||
|
||||
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
|
||||
# no_dispatch is only needed if you use enable_python_mode.
|
||||
# It prevents infinite recursion.
|
||||
with no_dispatch():
|
||||
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
|
||||
logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
|
||||
return rs
|
||||
|
||||
|
|
@ -335,6 +339,81 @@ $4 = torch._ops.aten.mul($3, tensor(2))
|
|||
$5 = torch._ops.aten.mul($4, $0)
|
||||
$6 = torch._ops.aten.add_($1, $5)''')
|
||||
|
||||
def test_enable_python_mode_error(self) -> None:
|
||||
with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
|
||||
with enable_python_mode(torch.Tensor):
|
||||
pass
|
||||
z = LoggingTensor(torch.empty([]))
|
||||
with self.assertRaisesRegex(ValueError, "must be the type"):
|
||||
with enable_python_mode(z):
|
||||
pass
|
||||
|
||||
def test_enable_python_mode_basic(self) -> None:
|
||||
with enable_python_mode(LoggingTensor):
|
||||
z = torch.empty([])
|
||||
self.assertTrue(isinstance(z, LoggingTensor))
|
||||
|
||||
def test_enable_python_mode_unrelated_tensors(self) -> None:
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
with enable_python_mode(LoggingTensor):
|
||||
z = x + y
|
||||
self.assertTrue(isinstance(z, LoggingTensor))
|
||||
|
||||
def test_enable_python_mode_subclass_priority(self) -> None:
|
||||
class ErrorA(RuntimeError):
|
||||
pass
|
||||
|
||||
class ErrorB(RuntimeError):
|
||||
pass
|
||||
|
||||
class A(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, elem):
|
||||
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
raise ErrorA
|
||||
|
||||
class B(A):
|
||||
@staticmethod
|
||||
def __new__(cls, elem):
|
||||
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
raise ErrorB
|
||||
|
||||
a = A(torch.empty(1))
|
||||
b = B(torch.empty(1))
|
||||
with self.assertRaises(ErrorA):
|
||||
a + a
|
||||
|
||||
# B has precedence over A due to the subclass relationship
|
||||
with self.assertRaises(ErrorB):
|
||||
with enable_python_mode(A):
|
||||
b + b
|
||||
with self.assertRaises(ErrorB):
|
||||
with enable_python_mode(B):
|
||||
a + a
|
||||
with self.assertRaises(ErrorB):
|
||||
with enable_python_mode(B):
|
||||
a + b
|
||||
|
||||
def test_enable_python_mode_respects_no_dispatch(self) -> None:
|
||||
with enable_python_mode(LoggingTensor):
|
||||
z = torch.ones([2, 3])
|
||||
self.assertTrue(isinstance(z, LoggingTensor))
|
||||
with no_dispatch():
|
||||
expected = torch.ones([2, 3])
|
||||
self.assertEqual(z.elem, expected)
|
||||
|
||||
def test_nested_enable_python_mode(self) -> None:
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been set"):
|
||||
with enable_python_mode(LoggingTensor):
|
||||
with enable_python_mode(LoggingTensor):
|
||||
pass
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -666,6 +666,7 @@ libtorch_python_core_sources = [
|
|||
"torch/csrc/autograd/init.cpp",
|
||||
"torch/csrc/autograd/python_anomaly_mode.cpp",
|
||||
"torch/csrc/autograd/python_saved_variable_hooks.cpp",
|
||||
"torch/csrc/autograd/python_mode.cpp",
|
||||
"torch/csrc/autograd/python_cpp_function.cpp",
|
||||
"torch/csrc/autograd/python_engine.cpp",
|
||||
"torch/csrc/autograd/python_function.cpp",
|
||||
|
|
@ -793,6 +794,7 @@ aten_cpu_source_non_codegen_list = [
|
|||
"aten/src/ATen/ParallelNativeTBB.cpp",
|
||||
"aten/src/ATen/ParallelOpenMP.cpp",
|
||||
"aten/src/ATen/ParallelThreadPoolNative.cpp",
|
||||
"aten/src/ATen/PythonModeTLS.cpp",
|
||||
"aten/src/ATen/ScalarOps.cpp",
|
||||
"aten/src/ATen/SequenceNumber.cpp",
|
||||
"aten/src/ATen/SparseTensorImpl.cpp",
|
||||
|
|
|
|||
|
|
@ -652,6 +652,8 @@ def __set_forward_AD_enabled(enabled: _bool) -> None: ...
|
|||
def __is_forward_AD_enabled() -> _bool: ...
|
||||
def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
|
||||
def _reset_default_hooks() -> None: ...
|
||||
def _enter_python_mode(cls: Type) -> None: ...
|
||||
def _exit_python_mode() -> None: ...
|
||||
|
||||
class _InferenceMode(object):
|
||||
def __init__(self, mode: _bool) -> None: ...
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
|
||||
#include <torch/csrc/autograd/utils/wrap_outputs.h>
|
||||
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
|
||||
#include <torch/csrc/autograd/python_mode.h>
|
||||
#include <torch/csrc/utils/pycfunction_helpers.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
|
||||
|
|
@ -494,6 +495,20 @@ static PyObject * python_exit_dual_level(PyObject* _unused, PyObject* args, PyOb
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * enter_python_mode(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
PythonMode::enter(arg);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * exit_python_mode(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
PythonMode::exit();
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// autograd methods on torch._C
|
||||
static PyMethodDef methods[] = { // NOLINT
|
||||
{"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
|
||||
|
|
@ -514,6 +529,8 @@ static PyMethodDef methods[] = { // NOLINT
|
|||
{"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
|
||||
{"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
|
||||
{"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr},
|
||||
{"_enter_python_mode", enter_python_mode, METH_O, nullptr},
|
||||
{"_exit_python_mode", exit_python_mode, METH_NOARGS, nullptr},
|
||||
{nullptr, nullptr, 0, nullptr}
|
||||
};
|
||||
|
||||
|
|
|
|||
27
torch/csrc/autograd/python_mode.cpp
Normal file
27
torch/csrc/autograd/python_mode.cpp
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
#include <torch/csrc/autograd/python_mode.h>
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/autograd/python_variable.h>
|
||||
#include <ATen/PythonModeTLS.h>
|
||||
#include <c10/core/TensorImpl.h>
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
void PythonMode::enter(PyObject* type) {
|
||||
if (at::impl::PythonModeTLS::get_state()) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"python mode has already been set. We do not yet support nested python ",
|
||||
"mode. Please file us an issue and reset it before setting it again.")
|
||||
}
|
||||
// TorchDispatchTypeObject steals a reference, See NOTE [What is TorchDispatchTypeObject?]
|
||||
Py_INCREF(type);
|
||||
auto state = std::make_shared<c10::TorchDispatchTypeObject>(type, getPyInterpreter());
|
||||
at::impl::PythonModeTLS::set_state(state);
|
||||
}
|
||||
|
||||
void PythonMode::exit() {
|
||||
TORCH_INTERNAL_ASSERT(at::impl::PythonModeTLS::get_state(), "exiting Python Mode but it wasn't set!");
|
||||
at::impl::PythonModeTLS::reset_state();
|
||||
}
|
||||
|
||||
}}
|
||||
17
torch/csrc/autograd/python_mode.h
Normal file
17
torch/csrc/autograd/python_mode.h
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <c10/core/TensorImpl.h>
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
struct TORCH_API PythonMode {
|
||||
// Enter python mode, causing all operators to dispatch to the type's __torch_dispatch__.
|
||||
// `type` is the type of a Tensor subclass that has __torch_dispatch__.
|
||||
static void enter(PyObject* type);
|
||||
|
||||
// Exit the current python mode.
|
||||
static void exit();
|
||||
};
|
||||
|
||||
}}
|
||||
|
|
@ -32,6 +32,7 @@
|
|||
|
||||
#include <torch/library.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
#include <torch/csrc/autograd/python_mode.h>
|
||||
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
|
@ -64,7 +65,12 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj) {
|
|||
return;
|
||||
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
if (Py_REFCNT(pyobj) > 1) {
|
||||
// Two possibilities:
|
||||
// 1. We are decref-ing a tensor. Then we must be careful about
|
||||
// PyObject resurrection (this only applies to Tensors, see THPVariable_clear).
|
||||
// 2. We are decref-ing some other Python object. We don't do
|
||||
// PyObject resurrection on non-Tensors, so we just carry on as usual
|
||||
if (THPVariable_Check(pyobj) && Py_REFCNT(pyobj) > 1) {
|
||||
// It's still alive! This can happen if a weak ref resurrected
|
||||
// the PyObject without flipping ownership. At this point it is
|
||||
// too late to rescue the object, so just stub out the PyObject
|
||||
|
|
@ -82,7 +88,11 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj) {
|
|||
};
|
||||
|
||||
c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
|
||||
void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
||||
void concrete_dispatch_fn(
|
||||
const c10::impl::PyInterpreter*,
|
||||
const c10::OperatorHandle& op,
|
||||
torch::jit::Stack* stack,
|
||||
const std::shared_ptr<TorchDispatchTypeObject>& type);
|
||||
|
||||
class PyInterpreterHolder {
|
||||
public:
|
||||
|
|
@ -1491,7 +1501,19 @@ bool isPythonTensor(const Tensor& tensor) {
|
|||
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
|
||||
}
|
||||
|
||||
void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
// NOTE [dispatch_fn's type argument]
|
||||
// `type` is nullable and represents the PythonMode going on.
|
||||
// Right now we only support a single PythonMode, but in the future we could
|
||||
// change this to a stack of PythonModes.
|
||||
//
|
||||
// If `type` isn't null, then we consider the type for dispatch by prepending
|
||||
// it to the overloaded_args list. `handle_torch_funciton_no_python_arg_parser`
|
||||
// is responsible for doing overload resolution.
|
||||
void concrete_dispatch_fn(
|
||||
const c10::impl::PyInterpreter*,
|
||||
const c10::OperatorHandle& op,
|
||||
torch::jit::Stack* stack,
|
||||
const std::shared_ptr<TorchDispatchTypeObject>& type) {
|
||||
const auto& schema = op.schema();
|
||||
const auto num_returns = schema.returns().size();
|
||||
|
||||
|
|
@ -1568,13 +1590,17 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
|
|||
auto args = py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));
|
||||
py::dict kwargs;
|
||||
|
||||
if (type) {
|
||||
append_overloaded_type(&overloaded_args, type->ptr());
|
||||
}
|
||||
|
||||
// Find overloaded tensors
|
||||
for (int64_t idx = 0; idx < arguments.size(); idx++) {
|
||||
const auto& ivalue = arguments[idx];
|
||||
if (ivalue.isTensor()) {
|
||||
const auto& tensor = ivalue.toTensor();
|
||||
if (isPythonTensor(tensor)) {
|
||||
append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
|
||||
append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
|
||||
}
|
||||
} else if (ivalue.isList()) {
|
||||
const auto& list = ivalue.toListRef();
|
||||
|
|
@ -1583,7 +1609,7 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
|
|||
if (nv.isTensor()) {
|
||||
const auto& tensor = nv.toTensor();
|
||||
if (isPythonTensor(tensor)) {
|
||||
append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
|
||||
append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1633,7 +1659,7 @@ c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter
|
|||
Tensor self_t = Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
|
||||
auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
|
||||
TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
|
||||
append_overloaded_arg(&overloaded_args, self_p.ptr());
|
||||
append_overloaded_tensor(&overloaded_args, self_p.ptr());
|
||||
auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
|
||||
PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
|
||||
|
||||
|
|
|
|||
|
|
@ -200,12 +200,28 @@ auto handle_torch_function(PyObject* self, const std::string& func_name, PyObjec
|
|||
return ret.release().ptr();
|
||||
}
|
||||
|
||||
// Note: [Overloaded args]
|
||||
// An overloaded arg may be one of the following:
|
||||
// - an instance of an object that has a __torch_function__ method
|
||||
// - an instance of an object that has a __torch_dispatch__ classmethod
|
||||
// - a class type that has a __torch_dispatch__ classmethod
|
||||
//
|
||||
// This function returns the type of the arg (if the arg is an instance),
|
||||
// otherwise, it returns the arg.
|
||||
static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
|
||||
if (PyType_Check(obj_or_type)) {
|
||||
return obj_or_type;
|
||||
}
|
||||
return (PyObject*)Py_TYPE(obj_or_type);
|
||||
}
|
||||
|
||||
// See Note: [Overloaded args] for what they hold
|
||||
auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name, const char* torch_function_name) -> PyObject* {
|
||||
// overloaded_args already all have unique types
|
||||
std::vector<py::object> overloaded_types;
|
||||
overloaded_types.reserve(overloaded_args.size());
|
||||
for (auto &arg : overloaded_args) {
|
||||
overloaded_types.push_back(py::reinterpret_borrow<py::object>((PyObject *) Py_TYPE(arg.ptr())));
|
||||
overloaded_types.push_back(py::reinterpret_borrow<py::object>(get_type_of_overloaded_arg(arg.ptr())));
|
||||
}
|
||||
py::tuple py_types = py::cast(overloaded_types);
|
||||
py::object ret;
|
||||
|
|
@ -231,7 +247,7 @@ auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &o
|
|||
ss << "no implementation found for '" << module_name << "." << func_name
|
||||
<< "' on types that implement " << torch_function_name << ": [";
|
||||
for (auto &arg : overloaded_args) {
|
||||
ss << arg.ptr()->ob_type->tp_name;
|
||||
ss << PyObject_Repr(get_type_of_overloaded_arg(arg.ptr()));
|
||||
if (!arg.is(overloaded_args.back())) {
|
||||
ss << ", ";
|
||||
}
|
||||
|
|
@ -328,10 +344,11 @@ auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* v
|
|||
*
|
||||
*/
|
||||
|
||||
void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj) {
|
||||
static void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj, bool obj_is_type) {
|
||||
bool class_not_seen_yet = true;
|
||||
PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
|
||||
for (auto &arg : *overloaded_args) {
|
||||
if (Py_TYPE(obj) == Py_TYPE(arg.ptr())) {
|
||||
if (obj_type == get_type_of_overloaded_arg(arg.ptr())) {
|
||||
// obj is the same type as another parameter we've seen in a prior
|
||||
// iteration of the loop over parameters so we already have an entry
|
||||
// with the proper __torch_function__ implementation to call, so skip
|
||||
|
|
@ -343,7 +360,7 @@ void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* o
|
|||
if (class_not_seen_yet) {
|
||||
int arg_index = overloaded_args->size();
|
||||
for(const auto j : c10::irange(arg_index)) {
|
||||
if (PyObject_IsInstance(obj, (PyObject*)(Py_TYPE((*overloaded_args)[j].ptr())))) {
|
||||
if (PyObject_IsSubclass(obj_type, (PyObject*)(get_type_of_overloaded_arg((*overloaded_args)[j].ptr())))) {
|
||||
// obj is a subclass of another object we've seen already so its
|
||||
// __torch_function__ should be called first, therefore we
|
||||
// insert it into overloaded_args before the superclass
|
||||
|
|
@ -358,6 +375,14 @@ void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* o
|
|||
}
|
||||
}
|
||||
|
||||
void append_overloaded_tensor(std::vector<py::handle>* overloaded_args, PyObject* obj) {
|
||||
append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/false);
|
||||
}
|
||||
|
||||
void append_overloaded_type(std::vector<py::handle>* overloaded_args, PyObject* obj) {
|
||||
append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/true);
|
||||
}
|
||||
|
||||
bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args) {
|
||||
if (THPVariable_CheckExact(obj)) {
|
||||
// torch.Tensor instances (not subclasses, except for Parameter)
|
||||
|
|
@ -366,7 +391,7 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* ove
|
|||
|
||||
if (check_has_torch_function(obj)) {
|
||||
// tensor subclasses and unrelated objects with __torch_function__
|
||||
append_overloaded_arg(overloaded_args, obj);
|
||||
append_overloaded_tensor(overloaded_args, obj);
|
||||
return true;
|
||||
} else if (THPVariable_Check(obj)) {
|
||||
// tensor subclasses without __torch_function__
|
||||
|
|
@ -905,7 +930,7 @@ bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs,
|
|||
|
||||
int i = 0;
|
||||
if (self != nullptr && check_has_torch_function(self)) {
|
||||
append_overloaded_arg(&this->overloaded_args, self);
|
||||
append_overloaded_tensor(&this->overloaded_args, self);
|
||||
}
|
||||
for (auto& param : params) {
|
||||
PyObject* obj = nullptr;
|
||||
|
|
|
|||
|
|
@ -818,6 +818,15 @@ bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector<py::handle>
|
|||
* 'overloaded_args': the vector to append the overloaded args
|
||||
* 'obj': the input tensor that is overloaded
|
||||
*/
|
||||
void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj);
|
||||
void append_overloaded_tensor(std::vector<py::handle>* overloaded_args, PyObject* obj);
|
||||
|
||||
/* Given an argument that is definitely a type and is definitely overloaded,
|
||||
* append it to the overloaded arguments list. Use this only with __torch_dispatch__,
|
||||
* where we operate on classes that have a __torch_dispatch__ classmethod.
|
||||
*
|
||||
* 'overloaded_args': the vector to append the overloaded type
|
||||
* 'obj': the input class that has a __torch_dispatch__ classmethod.
|
||||
*/
|
||||
void append_overloaded_type(std::vector<py::handle>* overloaded_args, PyObject* obj);
|
||||
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -267,6 +267,7 @@ Tensor internal_new_from_data(
|
|||
{
|
||||
at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
|
||||
at::tracer::impl::NoTracerDispatchMode tracer_guard;
|
||||
c10::impl::ExcludeDispatchKeyGuard pythonmode_guard(c10::DispatchKey::Python);
|
||||
// functorch uses FuncTorchDynamicLayerBackMode as a mode key to wrap all
|
||||
// tensors returned from operators in special TensorWrapper tensor extension
|
||||
// The problem with this is that TensorWrapper does not have storage so
|
||||
|
|
|
|||
34
torch/utils/_python_dispatch.py
Normal file
34
torch/utils/_python_dispatch.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import torch
|
||||
import contextlib
|
||||
from typing import Iterator
|
||||
|
||||
# Context manager that causes all pytorch operators to dispatch to the passed-in
|
||||
# type's __torch_dispatch__ function.
|
||||
# operation that accepts no tensors but returns a tensor.
|
||||
#
|
||||
# enable_python_mode is affected by torch._C._DisableTorchDispatch.
|
||||
#
|
||||
# NB: Calling an operator inside __torch_dispatch__ does go through
|
||||
# __torch_dispatch__ again. Please use _DisableTorchDispatch inside
|
||||
# __torch_dispatch__ to prevent infinite recursion.
|
||||
#
|
||||
# TODO: Limitations and things about enable_python_mode we should fix before exposing it:
|
||||
# - it currently cannot be nested. This should be simple to implement; we need a
|
||||
# stack of TorchDispatchTypeObjects and the next bullet point.
|
||||
# - We need a better user-facing api for torch._C._DisableTorchDispatch that
|
||||
# is able to selectively disable __torch_dispatch__ of a particular class.
|
||||
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
|
||||
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
|
||||
@contextlib.contextmanager
|
||||
def enable_python_mode(cls) -> Iterator[None]:
|
||||
if not hasattr(cls, '__torch_dispatch__'):
|
||||
raise ValueError('The class passed to enable_python_mode '
|
||||
'must have a __torch_dispatch__ classmethod')
|
||||
if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
|
||||
raise ValueError('The argument passed to enable_python_mode '
|
||||
'must be the type of a Tensor subclass')
|
||||
torch._C._enter_python_mode(cls)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._C._exit_python_mode()
|
||||
Loading…
Reference in New Issue
Block a user