Add python mode (#63496)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63496

This PR adds a (private) enable_python_mode context manager
(see torch/utils/_python_dispatch.py).
enable_python_mode accepts the type of a __torch_dispatch__ object
as its argument. Whenever an operator gets called inside of the
context manager, it dispatches to the __torch_dispatch__ of
the passed-in type.

Example usage:
```
with enable_python_mode(LoggingTensor):
    z = torch.empty([])
    assert isinstance(z, LoggingTensor)
```

There are quite a few changes that were made to support this.

First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.

Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.

To get that to work, we changed how `handle_torch_function_no_python_arg_parser`
works. The "overloaded args list" previously consisted only of Tensor PyObjects,
but now it can hold types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_tensor`.
- We added a new `append_overloaded_type` that appends a type to
overloaded_args.
- We added special handling in `handle_torch_function_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors.

Then there are PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen. The split
is forced by the libtorch_python library boundary: we need to save the TLS
in ATen/ThreadLocalState, which cannot depend on the Python bindings.
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it.
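
At the Python level, all of this surfaces as two private bindings,
torch._C._enter_python_mode and torch._C._exit_python_mode (added in
torch/csrc/autograd/init.cpp below), which the context manager wraps.
A minimal sketch of the equivalence, using the LoggingTensor from the tests:
```
with enable_python_mode(LoggingTensor):
    z = torch.empty([])

# is roughly equivalent to:
torch._C._enter_python_mode(LoggingTensor)  # sets PythonModeTLS, enables DispatchKey::Python
try:
    z = torch.empty([])
finally:
    torch._C._exit_python_mode()  # resets the TLS, disables the key
```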

There is one more miscellaneous change:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor; the exclude guard prevents a bug there.
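
In practice this means the tensor constructors bypass the mode for now
(a hedged illustration of the limitation, not a test from this PR):
```
with enable_python_mode(LoggingTensor):
    x = torch.tensor([1., 2.])
    # internal_new_from_data excludes DispatchKey::Python, so x is a
    # plain Tensor here rather than a LoggingTensor.
    assert not isinstance(x, LoggingTensor)
```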

Future:
- This PR does not allow nesting of Python modes. We should be able to enable
that later with a saner no_dispatch API and by changing the TLS to a stack.
For now, nesting was not needed for CompositeImplicitAutograd testing.
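
Until then, nested modes raise an error, as the new test pins down:
```
with enable_python_mode(LoggingTensor):
    with enable_python_mode(LoggingTensor):  # RuntimeError: "...has already been set..."
        pass
```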

Test Plan: - new tests

Reviewed By: malfet, albanD

Differential Revision: D30543236

Pulled By: zou3519

fbshipit-source-id: ef5444d96a5a957d1657b7e37dce80f9a497d452
Commit: 4bd03b0242 (parent: ebc0aacf83)
Author: Richard Zou, 2021-08-30 18:39:50 -07:00, committed by Facebook GitHub Bot
19 changed files with 366 additions and 21 deletions

aten/src/ATen/PythonModeTLS.cpp
@@ -0,0 +1,26 @@
#include <ATen/PythonModeTLS.h>

namespace at { namespace impl {

thread_local std::shared_ptr<TorchDispatchTypeObject> pythonModeState;

void PythonModeTLS::set_state(const std::shared_ptr<TorchDispatchTypeObject>& state) {
  pythonModeState = state;
  if (state) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
  } else {
    PythonModeTLS::reset_state();
  }
}

const std::shared_ptr<TorchDispatchTypeObject>& PythonModeTLS::get_state() {
  return pythonModeState;
}

void PythonModeTLS::reset_state() {
  pythonModeState.reset((TorchDispatchTypeObject*)nullptr);
  c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
}

} // namespace impl
} // namespace at

aten/src/ATen/PythonModeTLS.h
@@ -0,0 +1,17 @@
#pragma once

#include <c10/macros/Macros.h>
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at {
namespace impl {

struct TORCH_API PythonModeTLS {
  static void set_state(const std::shared_ptr<TorchDispatchTypeObject>& state);
  static const std::shared_ptr<TorchDispatchTypeObject>& get_state();
  static void reset_state();
};

} // namespace impl
} // namespace at

aten/src/ATen/ThreadLocalState.cpp
@@ -17,6 +17,7 @@ ThreadLocalState::ThreadLocalState()
   saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
   bumped_record_all_functions_ = at::checkRecordAllFunctions();
+  python_mode_state_ = at::impl::PythonModeTLS::get_state();
 }
 
 void ThreadLocalState::set_grad_mode(bool enabled) {
@@ -30,6 +31,8 @@ void ThreadLocalState::setThreadLocalState(
   // restore the dispatch key set TLS at the same time.
   c10::AutogradState::set_tls_state(state.autograd_tls_);
 
+  at::impl::PythonModeTLS::set_state(state.python_mode_state_);
+
   at::set_record_function_tls_(state.rf_tls_);
 
   SavedTensorDefaultHooks::set_hooks(

aten/src/ATen/ThreadLocalState.h
@@ -6,6 +6,7 @@
 #include <c10/util/ThreadLocalDebugInfo.h>
 #include <ATen/record_function.h>
+#include <ATen/PythonModeTLS.h>
 
 namespace at {
 
@@ -40,6 +41,8 @@ class TORCH_API ThreadLocalState {
   // TLS for AutogradModes
   AutogradState autograd_tls_;
 
+  std::shared_ptr<TorchDispatchTypeObject> python_mode_state_;
+
   // TLS for saved tensors default hooks
   std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;

aten/src/ATen/core/PythonFallbackKernel.cpp
@@ -1,9 +1,18 @@
 #include <torch/library.h>
 #include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/PythonModeTLS.h>
 
 namespace {
 
 void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  // If Python Mode is active, use its PyInterpreter for dispatch
+  const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
+  if (maybe_python_mode_state) {
+    maybe_python_mode_state->pyinterpreter()->dispatch(op, stack, maybe_python_mode_state);
+    return;
+  }
+
+  // Otherwise, find a PyInterpreter on a Tensor
   const auto& schema = op.schema();
   const auto num_arguments = schema.arguments().size();
   // It is safe to dispatch on the very first Tensor with a pyobj_interpreter
@@ -15,7 +24,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
     if (ivalue.isTensor()) {
       auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter();
       if (interpreter) {
-        interpreter->dispatch(op, stack);
+        interpreter->dispatch(op, stack, nullptr);
         return;
       }
     } else if (ivalue.isTensorList()) {
@@ -24,7 +33,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
       for (const auto& nv : ivalue.toListRef()) {
         auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter();
         if (interpreter) {
-          interpreter->dispatch(op, stack);
+          interpreter->dispatch(op, stack, nullptr);
           return;
         }
       }

c10/core/TensorImpl.cpp
@@ -40,7 +40,8 @@ static c10::intrusive_ptr<TensorImpl> noop_detach_fn(
 static void noop_dispatch_fn(
     const PyInterpreter*,
     const c10::OperatorHandle& op,
-    torch::jit::Stack* stack) {
+    torch::jit::Stack* stack,
+    const std::shared_ptr<TorchDispatchTypeObject>& type) {
   TORCH_INTERNAL_ASSERT(
       0,
       "attempted to dispatch (__torch_dispatch__) an operator on Tensor with nontrivial PyObject after corresponding interpreter died");
@@ -608,6 +609,23 @@ void TensorImpl::copy_tensor_metadata(
   }
 }
 
+TorchDispatchTypeObject::TorchDispatchTypeObject(
+    PyObject* type_object,
+    c10::impl::PyInterpreter* pyinterpreter)
+    : data_(type_object), pyinterpreter_(pyinterpreter) {}
+
+TorchDispatchTypeObject::~TorchDispatchTypeObject() {
+  pyinterpreter_->decref(data_);
+}
+
+c10::impl::PyInterpreter* TorchDispatchTypeObject::pyinterpreter() const {
+  return pyinterpreter_;
+}
+
+PyObject* TorchDispatchTypeObject::ptr() const {
+  return data_;
+}
+
 namespace impl {
 namespace {

c10/core/TensorImpl.h
@@ -161,6 +161,9 @@ struct C10_API AutogradMetaInterface {
   virtual ~AutogradMetaInterface();
 };
 
+// forward declared
+struct TorchDispatchTypeObject;
+
 namespace impl {
 
 // Unfortunately, the definition of AutogradMeta lives in a separate
@@ -255,7 +258,8 @@ struct C10_API PyInterpreter {
   using dispatch_sig = void(
       const PyInterpreter*,
       const c10::OperatorHandle&,
-      torch::jit::Stack* stack);
+      torch::jit::Stack* stack,
+      const std::shared_ptr<TorchDispatchTypeObject>& type);
 
   PyInterpreter(
       name_sig* name_fn,
@@ -299,8 +303,9 @@ struct C10_API PyInterpreter {
   // Invoke the Python boxed fallback dispatch to go back into Python
   __ubsan_ignore_function__ void dispatch(
       const c10::OperatorHandle& op,
-      torch::jit::Stack* stack) const {
-    return (*dispatch_fn_)(this, op, stack);
+      torch::jit::Stack* stack,
+      const std::shared_ptr<TorchDispatchTypeObject>& type) const {
+    return (*dispatch_fn_)(this, op, stack, type);
   }
 
   // Disarm this PyInterpreter, making all of its methods noops.
@@ -348,6 +353,30 @@ struct C10_API NamedTensorMetaInterface {
   };
 };
 
+// NOTE [What is TorchDispatchTypeObject?]
+// A TorchDispatchTypeObject represents the type of a Tensor subclass that has
+// a __torch_dispatch__ classmethod. Concretely, it holds the class as a
+// PyObject* and a PyInterpreter* that says which python interpreter the class
+// came from.
+//
+// See NOTE [dispatch_fn's type argument] for more details
+struct C10_API TorchDispatchTypeObject {
+  // Steals a reference to type_object
+  TorchDispatchTypeObject(
+      PyObject* type_object,
+      c10::impl::PyInterpreter* pyinterpreter);
+
+  // Releases the stolen reference to type_object
+  ~TorchDispatchTypeObject();
+
+  c10::impl::PyInterpreter* pyinterpreter() const;
+  PyObject* ptr() const;
+
+ private:
+  PyObject* data_;
+  c10::impl::PyInterpreter* pyinterpreter_;
+};
+
 // NOTE [ Version Counter Sharing ]
 //
 // Every Tensor has a version counter. Version counters are incremented whenever

test/run_test.py
@@ -103,6 +103,7 @@ TESTS = [
     "test_optim",
     "test_functional_optim",
     "test_pytree",
+    "test_python_dispatch",
     "test_mobile_optimizer",
     "test_set_default_mobile_cpu_allocator",
     "test_xnnpack_integration",

test/test_python_dispatch.py
@@ -1,6 +1,7 @@
 import torch
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.utils._pytree import tree_map
+from torch.utils._python_dispatch import enable_python_mode
 
 from typing import Iterator, List
 import logging
@@ -50,6 +51,9 @@ class LoggingTensor(torch.Tensor):
         def wrap(e):
             return LoggingTensor(e) if isinstance(e, torch.Tensor) else e
 
-        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
+        # no_dispatch is only needed if you use enable_python_mode.
+        # It prevents infinite recursion.
+        with no_dispatch():
+            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
         logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
         return rs
@@ -335,6 +339,81 @@ $4 = torch._ops.aten.mul($3, tensor(2))
 $5 = torch._ops.aten.mul($4, $0)
 $6 = torch._ops.aten.add_($1, $5)''')
 
+    def test_enable_python_mode_error(self) -> None:
+        with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
+            with enable_python_mode(torch.Tensor):
+                pass
+        z = LoggingTensor(torch.empty([]))
+        with self.assertRaisesRegex(ValueError, "must be the type"):
+            with enable_python_mode(z):
+                pass
+
+    def test_enable_python_mode_basic(self) -> None:
+        with enable_python_mode(LoggingTensor):
+            z = torch.empty([])
+        self.assertTrue(isinstance(z, LoggingTensor))
+
+    def test_enable_python_mode_unrelated_tensors(self) -> None:
+        x = torch.randn([])
+        y = torch.randn([])
+        with enable_python_mode(LoggingTensor):
+            z = x + y
+        self.assertTrue(isinstance(z, LoggingTensor))
+
+    def test_enable_python_mode_subclass_priority(self) -> None:
+        class ErrorA(RuntimeError):
+            pass
+
+        class ErrorB(RuntimeError):
+            pass
+
+        class A(torch.Tensor):
+            @staticmethod
+            def __new__(cls, elem):
+                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                raise ErrorA
+
+        class B(A):
+            @staticmethod
+            def __new__(cls, elem):
+                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                raise ErrorB
+
+        a = A(torch.empty(1))
+        b = B(torch.empty(1))
+        with self.assertRaises(ErrorA):
+            a + a
+
+        # B has precedence over A due to the subclass relationship
+        with self.assertRaises(ErrorB):
+            with enable_python_mode(A):
+                b + b
+        with self.assertRaises(ErrorB):
+            with enable_python_mode(B):
+                a + a
+        with self.assertRaises(ErrorB):
+            with enable_python_mode(B):
+                a + b
+
+    def test_enable_python_mode_respects_no_dispatch(self) -> None:
+        with enable_python_mode(LoggingTensor):
+            z = torch.ones([2, 3])
+            self.assertTrue(isinstance(z, LoggingTensor))
+            with no_dispatch():
+                expected = torch.ones([2, 3])
+            self.assertEqual(z.elem, expected)
+
+    def test_nested_enable_python_mode(self) -> None:
+        with self.assertRaisesRegex(RuntimeError, "has already been set"):
+            with enable_python_mode(LoggingTensor):
+                with enable_python_mode(LoggingTensor):
+                    pass
+
 if __name__ == '__main__':
     run_tests()

tools/build_variables.bzl
@@ -666,6 +666,7 @@ libtorch_python_core_sources = [
     "torch/csrc/autograd/init.cpp",
     "torch/csrc/autograd/python_anomaly_mode.cpp",
     "torch/csrc/autograd/python_saved_variable_hooks.cpp",
+    "torch/csrc/autograd/python_mode.cpp",
     "torch/csrc/autograd/python_cpp_function.cpp",
     "torch/csrc/autograd/python_engine.cpp",
     "torch/csrc/autograd/python_function.cpp",
@@ -793,6 +794,7 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/ParallelNativeTBB.cpp",
     "aten/src/ATen/ParallelOpenMP.cpp",
     "aten/src/ATen/ParallelThreadPoolNative.cpp",
+    "aten/src/ATen/PythonModeTLS.cpp",
     "aten/src/ATen/ScalarOps.cpp",
     "aten/src/ATen/SequenceNumber.cpp",
     "aten/src/ATen/SparseTensorImpl.cpp",

torch/_C/__init__.pyi.in
@@ -652,6 +652,8 @@ def __set_forward_AD_enabled(enabled: _bool) -> None: ...
 def __is_forward_AD_enabled() -> _bool: ...
 def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
 def _reset_default_hooks() -> None: ...
+def _enter_python_mode(cls: Type) -> None: ...
+def _exit_python_mode() -> None: ...
 
 class _InferenceMode(object):
     def __init__(self, mode: _bool) -> None: ...

torch/csrc/autograd/init.cpp
@@ -14,6 +14,7 @@
 #include <torch/csrc/autograd/python_saved_variable_hooks.h>
 #include <torch/csrc/autograd/utils/wrap_outputs.h>
 #include <torch/csrc/autograd/utils/python_arg_parsing.h>
+#include <torch/csrc/autograd/python_mode.h>
 #include <torch/csrc/utils/pycfunction_helpers.h>
 #include <c10/core/ScalarType.h>
 
@@ -494,6 +495,20 @@ static PyObject * python_exit_dual_level(PyObject* _unused, PyObject* args, PyOb
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject * enter_python_mode(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  PythonMode::enter(arg);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * exit_python_mode(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  PythonMode::exit();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 // autograd methods on torch._C
 static PyMethodDef methods[] = { // NOLINT
   {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
@@ -514,6 +529,8 @@ static PyMethodDef methods[] = { // NOLINT
   {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
   {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
   {"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"_enter_python_mode", enter_python_mode, METH_O, nullptr},
+  {"_exit_python_mode", exit_python_mode, METH_NOARGS, nullptr},
   {nullptr, nullptr, 0, nullptr}
 };

torch/csrc/autograd/python_mode.cpp
@@ -0,0 +1,27 @@
#include <torch/csrc/autograd/python_mode.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/autograd/python_variable.h>
#include <ATen/PythonModeTLS.h>
#include <c10/core/TensorImpl.h>

namespace torch { namespace autograd {

void PythonMode::enter(PyObject* type) {
  if (at::impl::PythonModeTLS::get_state()) {
    TORCH_CHECK(
        false,
        "python mode has already been set. We do not yet support nested python ",
        "mode. Please file us an issue and reset it before setting it again.")
  }
  // TorchDispatchTypeObject steals a reference, See NOTE [What is TorchDispatchTypeObject?]
  Py_INCREF(type);
  auto state = std::make_shared<c10::TorchDispatchTypeObject>(type, getPyInterpreter());
  at::impl::PythonModeTLS::set_state(state);
}

void PythonMode::exit() {
  TORCH_INTERNAL_ASSERT(at::impl::PythonModeTLS::get_state(), "exiting Python Mode but it wasn't set!");
  at::impl::PythonModeTLS::reset_state();
}

}}

torch/csrc/autograd/python_mode.h
@@ -0,0 +1,17 @@
#pragma once

#include <torch/csrc/python_headers.h>
#include <c10/core/TensorImpl.h>

namespace torch { namespace autograd {

struct TORCH_API PythonMode {
  // Enter python mode, causing all operators to dispatch to the type's __torch_dispatch__.
  // `type` is the type of a Tensor subclass that has __torch_dispatch__.
  static void enter(PyObject* type);

  // Exit the current python mode.
  static void exit();
};

}}

torch/csrc/autograd/python_variable.cpp
@@ -32,6 +32,7 @@
 #include <torch/library.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
+#include <torch/csrc/autograd/python_mode.h>
 
 #include <ATen/ATen.h>
 
@@ -64,7 +65,12 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj) {
     return;
   pybind11::gil_scoped_acquire gil;
-  if (Py_REFCNT(pyobj) > 1) {
+  // Two possibilities:
+  // 1. We are decref-ing a tensor. Then we must be careful about
+  //    PyObject resurrection (this only applies to Tensors, see THPVariable_clear).
+  // 2. We are decref-ing some other Python object. We don't do
+  //    PyObject resurrection on non-Tensors, so we just carry on as usual
+  if (THPVariable_Check(pyobj) && Py_REFCNT(pyobj) > 1) {
     // It's still alive! This can happen if a weak ref resurrected
     // the PyObject without flipping ownership. At this point it is
     // too late to rescue the object, so just stub out the PyObject
@@ -82,7 +88,11 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj) {
 };
 
 c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
-void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack);
+void concrete_dispatch_fn(
+    const c10::impl::PyInterpreter*,
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack,
+    const std::shared_ptr<TorchDispatchTypeObject>& type);
 
 class PyInterpreterHolder {
  public:
@@ -1491,7 +1501,19 @@ bool isPythonTensor(const Tensor& tensor) {
   return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
 }
 
-void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+// NOTE [dispatch_fn's type argument]
+// `type` is nullable and represents the PythonMode going on.
+// Right now we only support a single PythonMode, but in the future we could
+// change this to a stack of PythonModes.
+//
+// If `type` isn't null, then we consider the type for dispatch by prepending
+// it to the overloaded_args list. `handle_torch_function_no_python_arg_parser`
+// is responsible for doing overload resolution.
+void concrete_dispatch_fn(
+    const c10::impl::PyInterpreter*,
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack,
+    const std::shared_ptr<TorchDispatchTypeObject>& type) {
   const auto& schema = op.schema();
   const auto num_returns = schema.returns().size();
 
@@ -1568,13 +1590,17 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
   auto args = py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));
   py::dict kwargs;
 
+  if (type) {
+    append_overloaded_type(&overloaded_args, type->ptr());
+  }
+
   // Find overloaded tensors
   for (int64_t idx = 0; idx < arguments.size(); idx++) {
     const auto& ivalue = arguments[idx];
     if (ivalue.isTensor()) {
       const auto& tensor = ivalue.toTensor();
       if (isPythonTensor(tensor)) {
-        append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
+        append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
       }
     } else if (ivalue.isList()) {
       const auto& list = ivalue.toListRef();
@@ -1583,7 +1609,7 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
       if (nv.isTensor()) {
         const auto& tensor = nv.toTensor();
         if (isPythonTensor(tensor)) {
-          append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
+          append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
         }
       }
     }
@@ -1633,7 +1659,7 @@ c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter
   Tensor self_t = Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
   auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
   TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
-  append_overloaded_arg(&overloaded_args, self_p.ptr());
+  append_overloaded_tensor(&overloaded_args, self_p.ptr());
   auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
   PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());

torch/csrc/utils/python_arg_parser.cpp
@@ -200,12 +200,28 @@ auto handle_torch_function(PyObject* self, const std::string& func_name, PyObjec
   return ret.release().ptr();
 }
 
+// Note: [Overloaded args]
+// An overloaded arg may be one of the following:
+// - an instance of an object that has a __torch_function__ method
+// - an instance of an object that has a __torch_dispatch__ classmethod
+// - a class type that has a __torch_dispatch__ classmethod
+//
+// This function returns the type of the arg (if the arg is an instance),
+// otherwise, it returns the arg.
+static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
+  if (PyType_Check(obj_or_type)) {
+    return obj_or_type;
+  }
+  return (PyObject*)Py_TYPE(obj_or_type);
+}
+
+// See Note: [Overloaded args] for what they hold
 auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name, const char* torch_function_name) -> PyObject* {
   // overloaded_args already all have unique types
   std::vector<py::object> overloaded_types;
   overloaded_types.reserve(overloaded_args.size());
   for (auto &arg : overloaded_args) {
-    overloaded_types.push_back(py::reinterpret_borrow<py::object>((PyObject *) Py_TYPE(arg.ptr())));
+    overloaded_types.push_back(py::reinterpret_borrow<py::object>(get_type_of_overloaded_arg(arg.ptr())));
   }
   py::tuple py_types = py::cast(overloaded_types);
   py::object ret;
@@ -231,7 +247,7 @@ auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &o
     ss << "no implementation found for '" << module_name << "." << func_name
        << "' on types that implement " << torch_function_name << ": [";
     for (auto &arg : overloaded_args) {
-      ss << arg.ptr()->ob_type->tp_name;
+      ss << PyObject_Repr(get_type_of_overloaded_arg(arg.ptr()));
       if (!arg.is(overloaded_args.back())) {
         ss << ", ";
       }
@@ -328,10 +344,11 @@ auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* v
  *
  */
-void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj) {
+static void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj, bool obj_is_type) {
   bool class_not_seen_yet = true;
+  PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
   for (auto &arg : *overloaded_args) {
-    if (Py_TYPE(obj) == Py_TYPE(arg.ptr())) {
+    if (obj_type == get_type_of_overloaded_arg(arg.ptr())) {
       // obj is the same type as another parameter we've seen in a prior
       // iteration of the loop over parameters so we already have an entry
       // with the proper __torch_function__ implementation to call, so skip
@@ -343,7 +360,7 @@ void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* o
   if (class_not_seen_yet) {
     int arg_index = overloaded_args->size();
     for(const auto j : c10::irange(arg_index)) {
-      if (PyObject_IsInstance(obj, (PyObject*)(Py_TYPE((*overloaded_args)[j].ptr())))) {
+      if (PyObject_IsSubclass(obj_type, (PyObject*)(get_type_of_overloaded_arg((*overloaded_args)[j].ptr())))) {
        // obj is a subclass of another object we've seen already so its
        // __torch_function__ should be called first, therefore we
        // insert it into overloaded_args before the superclass
@@ -358,6 +375,14 @@ void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* o
   }
 }
 
+void append_overloaded_tensor(std::vector<py::handle>* overloaded_args, PyObject* obj) {
+  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/false);
+}
+
+void append_overloaded_type(std::vector<py::handle>* overloaded_args, PyObject* obj) {
+  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/true);
+}
+
 bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args) {
   if (THPVariable_CheckExact(obj)) {
     // torch.Tensor instances (not subclasses, except for Parameter)
@@ -366,7 +391,7 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* ove
   if (check_has_torch_function(obj)) {
     // tensor subclasses and unrelated objects with __torch_function__
-    append_overloaded_arg(overloaded_args, obj);
+    append_overloaded_tensor(overloaded_args, obj);
     return true;
   } else if (THPVariable_Check(obj)) {
     // tensor subclasses without __torch_function__
@@ -905,7 +930,7 @@ bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs,
   int i = 0;
   if (self != nullptr && check_has_torch_function(self)) {
-    append_overloaded_arg(&this->overloaded_args, self);
+    append_overloaded_tensor(&this->overloaded_args, self);
   }
 
   for (auto& param : params) {
     PyObject* obj = nullptr;

torch/csrc/utils/python_arg_parser.h
@@ -818,6 +818,15 @@ bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector<py::handle>
  * 'overloaded_args': the vector to append the overloaded args
  * 'obj': the input tensor that is overloaded
  */
-void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj);
+void append_overloaded_tensor(std::vector<py::handle>* overloaded_args, PyObject* obj);
+
+/* Given an argument that is definitely a type and is definitely overloaded,
+ * append it to the overloaded arguments list. Use this only with __torch_dispatch__,
+ * where we operate on classes that have a __torch_dispatch__ classmethod.
+ *
+ * 'overloaded_args': the vector to append the overloaded type
+ * 'obj': the input class that has a __torch_dispatch__ classmethod.
+ */
+void append_overloaded_type(std::vector<py::handle>* overloaded_args, PyObject* obj);
 
 } // namespace torch

torch/csrc/utils/tensor_new.cpp
@@ -267,6 +267,7 @@ Tensor internal_new_from_data(
   {
     at::AutoDispatchBelowADInplaceOrView guard;  // TODO: remove
     at::tracer::impl::NoTracerDispatchMode tracer_guard;
+    c10::impl::ExcludeDispatchKeyGuard pythonmode_guard(c10::DispatchKey::Python);
 
     // functorch uses FuncTorchDynamicLayerBackMode as a mode key to wrap all
     // tensors returned from operators in special TensorWrapper tensor extension
     // The problem with this is that TensorWrapper does not have storage so

torch/utils/_python_dispatch.py
@@ -0,0 +1,34 @@
import torch
import contextlib
from typing import Iterator

# Context manager that causes all pytorch operators to dispatch to the passed-in
# type's __torch_dispatch__ function, including any
# operation that accepts no tensors but returns a tensor.
#
# enable_python_mode is affected by torch._C._DisableTorchDispatch.
#
# NB: Calling an operator inside __torch_dispatch__ does go through
# __torch_dispatch__ again. Please use _DisableTorchDispatch inside
# __torch_dispatch__ to prevent infinite recursion.
#
# TODO: Limitations and things about enable_python_mode we should fix before exposing it:
# - it currently cannot be nested. This should be simple to implement; we need a
#   stack of TorchDispatchTypeObjects and the next bullet point.
# - We need a better user-facing api for torch._C._DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
@contextlib.contextmanager
def enable_python_mode(cls) -> Iterator[None]:
    if not hasattr(cls, '__torch_dispatch__'):
        raise ValueError('The class passed to enable_python_mode '
                         'must have a __torch_dispatch__ classmethod')
    if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
        raise ValueError('The argument passed to enable_python_mode '
                         'must be the type of a Tensor subclass')
    torch._C._enter_python_mode(cls)
    try:
        yield
    finally:
        torch._C._exit_python_mode()
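
Putting the pieces together, a typical `__torch_dispatch__` implementation under this
mode uses the guard exactly as the NB above prescribes. This is a minimal sketch, not
part of the diff; PassthroughTensor is a hypothetical subclass:
```
import torch
from torch.utils._python_dispatch import enable_python_mode

class PassthroughTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Disable __torch_dispatch__ while running the real op; without this,
        # the redispatch would re-enter this method forever.
        guard = torch._C._DisableTorchDispatch()
        try:
            return func(*args, **(kwargs or {}))
        finally:
            del guard

with enable_python_mode(PassthroughTensor):
    z = torch.empty([])  # routed through PassthroughTensor.__torch_dispatch__
```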