diff --git a/aten/src/ATen/core/PythonFallbackKernel.cpp b/aten/src/ATen/core/PythonFallbackKernel.cpp new file mode 100644 index 00000000000..276eabfe458 --- /dev/null +++ b/aten/src/ATen/core/PythonFallbackKernel.cpp @@ -0,0 +1,40 @@ +#include +#include + +namespace { + +void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_arguments = schema.arguments().size(); + // It is safe to dispatch on the very first Tensor with a pyobj_interpreter + // without checking the interpreters of any of the arguments, because when + // we actually run dispatch(), we will take out PyObjects in the context + // of that interpreter, and this will ensure that everyone is on the same + // interpreter. + for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) { + if (ivalue.isTensor()) { + auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter(); + if (interpreter) { + interpreter->dispatch(op, stack); + return; + } + } else if (ivalue.isTensorList()) { + // NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef + // is not a thing) + for (const auto& nv : ivalue.toListRef()) { + auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter(); + if (interpreter) { + interpreter->dispatch(op, stack); + return; + } + } + } + } + TORCH_INTERNAL_ASSERT(0, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)"); +} + +} // anonymous namespace + +TORCH_LIBRARY_IMPL(_, Python, m) { + m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallback>()); +} diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 02052a746b1..8f8e344ebaa 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -61,6 +61,9 @@ const char* toString(DispatchKey t) { case DispatchKey::NestedTensor: return "NestedTensor"; + case DispatchKey::Python: + return "Python"; + case DispatchKey::PrivateUse1: return "PrivateUse1"; case DispatchKey::PrivateUse2: diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index a3de003c63e..663c13e4b82 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -110,6 +110,7 @@ enum class DispatchKey : uint8_t { SparseCsrCUDA, NestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor + // Here are reserved backends for user-defined backends, see Note [Private use // DispatchKey] // To see some example about how to use this, check out MSNPU @@ -119,6 +120,8 @@ enum class DispatchKey : uint8_t { // Define an alias key to represent end of backend dispatch keys. // If you add new backend keys after PrivateUse3, please also update it here. + // (But you shouldn't: private use keys should have higher precedence than + // all built-in keys) EndOfBackendKeys = PrivateUse3, // In some situations, it is not immediately obvious what the correct @@ -128,6 +131,7 @@ enum class DispatchKey : uint8_t { // correct backend. BackendSelect, + Python, FuncTorchPython, // See Note [Out-of-tree vmap+grad prototype] // The named dispatch key is set for any tensors with named dimensions. 
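For orientation, here is a minimal Python sketch (not part of the patch) of the behavior this boxed fallback enables, modeled on the LoggingTensor test added later in this diff: a Tensor subclass that defines __torch_dispatch__ gets DispatchKey::Python set when it is wrapped, so every aten-level call on it lands in pythonFallback, which re-enters Python through PyInterpreter::dispatch. The EchoTensor name and the print call are illustrative only.

    import torch
    from torch.utils._pytree import tree_map

    class EchoTensor(torch.Tensor):
        # Meta-tensor wrapper in the style of the LoggingTensor test below:
        # the wrapper holds no data; the real tensor lives in .elem.
        @staticmethod
        def __new__(cls, elem):
            r = torch.Tensor._make_subclass(cls, elem.to('meta'), elem.requires_grad)
            r.elem = elem
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            unwrap = lambda e: e.elem if isinstance(e, EchoTensor) else e
            wrap = lambda e: EchoTensor(e) if isinstance(e, torch.Tensor) else e
            print(f"intercepted {func.__module__}.{func.__name__}")
            return tree_map(wrap, func(*tree_map(unwrap, args), **(kwargs or {})))

    t = EchoTensor(torch.ones(2))
    torch.add(t, t)  # prints "intercepted torch._ops.aten.add", then computes on t.elem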
diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 95005582987..33de1513851 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -31,9 +31,28 @@ static void noop_decref_fn(const PyInterpreter*, PyObject*) { // no-op } +static c10::intrusive_ptr noop_detach_fn( + const PyInterpreter*, + const TensorImpl*) { + TORCH_INTERNAL_ASSERT( + 0, + "attempted to detach (shallow_copy_and_detach) Tensor with nontrivial PyObject after corresponding interpreter died"); +} + +static void noop_dispatch_fn( + const PyInterpreter*, + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + TORCH_INTERNAL_ASSERT( + 0, + "attempted to dispatch (__torch_dispatch__) an operator on Tensor with nontrivial PyObject after corresponding interpreter died"); +} + void PyInterpreter::disarm() noexcept { name_fn_ = &noop_name_fn; decref_fn_ = &noop_decref_fn; + detach_fn_ = &noop_detach_fn; + dispatch_fn_ = &noop_dispatch_fn; } } // namespace impl @@ -95,6 +114,23 @@ TensorImpl::TensorImpl( data_type, storage.device()) {} +// [Note: Python key removal] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// In most constructors for TensorImpl, you will see Python key is removed from +// the passed in DispatchKeySet. Why? +// +// INVARIANT: Python dispatch key is set iff PyObject for the Tensor has a +// nontrivial __torch_dispatch__ implementation. +// +// When a fresh TensorImpl is created, there is *no* PyObject (this only gets +// initialized lazily at the first point in time the Tensor passes into Python). +// So we would violate the invariant. +// +// In practice, what will happen shortly afterwards is that the TensorImpl +// will get its PyObject initialized by Tensor._make_subclass; at this point +// the Python dispatch key will be set and all is well. The point is to delay +// the dispatch key setting until that point. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorImpl::TensorImpl( ImplType type, @@ -108,7 +144,8 @@ TensorImpl::TensorImpl( numel_(0), data_type_(data_type), device_opt_(storage_.device()), - key_set_(key_set) { + key_set_(key_set.remove( + DispatchKey::Python)) { // See [Note: Python key removal] init_bitfields(); // Inference tensor doesn't have version counter. if (!is_inference()) { @@ -153,6 +190,9 @@ TensorImpl::TensorImpl( key_set = key_set | getAutocastRelatedKeySetFromBackend(k); + key_set = + key_set.remove(DispatchKey::Python); // See [Note: Python key removal] + // Inference tensor doesn't have autograd related keys. if (inference_mode) { // See Note [Expected TLS state in InferenceMode] for why we exclude @@ -459,6 +499,17 @@ c10::AutogradMetaInterface* TensorImpl::autograd_meta() const { c10::intrusive_ptr TensorImpl::shallow_copy_and_detach( const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const { + if (key_set_.has(DispatchKey::Python) && + !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { + auto r = pyobj_interpreter_.load(std::memory_order_acquire)->detach(this); + if (r) { + r->set_version_counter(version_counter); + r->set_allow_tensor_metadata_change(allow_tensor_metadata_change); + return r; + } + // otherwise just copy the TensorImpl and not the PyObject. Since + // the interpreter is dead no one can call us out on it + } auto impl = c10::make_intrusive( // No need to populate Storage; copy_tensor_metadata will do it for us. 
key_set_, @@ -477,6 +528,17 @@ c10::intrusive_ptr TensorImpl::shallow_copy_and_detach( c10::intrusive_ptr TensorImpl::shallow_copy_and_detach( c10::VariableVersion&& version_counter, bool allow_tensor_metadata_change) const { + if (key_set_.has(DispatchKey::Python) && + !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { + auto r = pyobj_interpreter_.load(std::memory_order_acquire)->detach(this); + if (r) { + r->set_version_counter(std::move(version_counter)); + r->set_allow_tensor_metadata_change(allow_tensor_metadata_change); + return r; + } + // otherwise just copy the TensorImpl and not the PyObject. Since + // the interpreter is dead no one can call us out on it + } auto impl = c10::make_intrusive( // No need to populate Storage; copy_tensor_metadata will do it for us. key_set_, @@ -501,7 +563,7 @@ void TensorImpl::copy_tensor_metadata_except_version_counter( dest_impl->storage_offset_ = src_impl->storage_offset_; dest_impl->data_type_ = src_impl->data_type_; dest_impl->device_opt_ = src_impl->device_opt_; - dest_impl->key_set_ = src_impl->key_set_; + dest_impl->key_set_ = src_impl->key_set_.remove(DispatchKey::Python); dest_impl->is_contiguous_ = src_impl->is_contiguous_; dest_impl->has_contiguity_ = src_impl->has_contiguity_; dest_impl->is_channels_last_contiguous_ = diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index b4de8a781e0..f03ac628537 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -42,7 +42,18 @@ class Tensor; namespace c10 { class Scalar; +struct IValue; struct Storage; +class OperatorHandle; +} // namespace c10 + +namespace torch { +namespace jit { +using Stack = std::vector; +} +} // namespace torch + +namespace c10 { /** * A utility function to convert vector to vector. @@ -239,14 +250,27 @@ struct PyInterpreter; struct C10_API PyInterpreter { using name_sig = std::string(const PyInterpreter*); using decref_sig = void(const PyInterpreter*, PyObject*); + using detach_sig = + c10::intrusive_ptr(const PyInterpreter*, const TensorImpl*); + using dispatch_sig = void( + const PyInterpreter*, + const c10::OperatorHandle&, + torch::jit::Stack* stack); - PyInterpreter(name_sig* name_fn, decref_sig* decref_fn) - : name_fn_(name_fn), decref_fn_(decref_fn) {} + PyInterpreter( + name_sig* name_fn, + decref_sig* decref_fn, + detach_sig* detach, + dispatch_sig* dispatch) + : name_fn_(name_fn), + decref_fn_(decref_fn), + detach_fn_(detach), + dispatch_fn_(dispatch) {} - // For debugging purposes only name_sig* name_fn_; - decref_sig* decref_fn_; + detach_sig* detach_fn_; + dispatch_sig* dispatch_fn_; // UBSAN suppression fixes: "call to function // (anonymous namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*, @@ -254,6 +278,7 @@ struct C10_API PyInterpreter { // c10::impl::PyInterpreter *, _object *)'" See // https://github.com/google/sanitizers/issues/911 + // Report the name of this interpreter __ubsan_ignore_function__ std::string name() const { return (*name_fn_)(this); } @@ -263,6 +288,21 @@ struct C10_API PyInterpreter { return (*decref_fn_)(this, pyobj); } + // Perform a detach by deferring to the __torch_dispatch__ implementation of + // detach, which will also arrange for the PyObject to get copied in this + // situation + __ubsan_ignore_function__ c10::intrusive_ptr detach( + const TensorImpl* self) const { + return (*detach_fn_)(this, self); + } + + // Invoke the Python boxed fallback dispatch to go back into Python + __ubsan_ignore_function__ void dispatch( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) 
const { + return (*dispatch_fn_)(this, op, stack); + } + // Disarm this PyInterpreter, making all of its methods noops. // Because the function pointers are raw pointers (not atomics), // a disarm() invocation that is concurrent with active destructors @@ -1321,6 +1361,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } } + void set_python_dispatch(bool k) { + if (k) { + key_set_ = key_set_.add(DispatchKey::Python); + } else { + key_set_ = key_set_.remove(DispatchKey::Python); + } + } + + bool is_python_dispatch() const { + return key_set_.has(DispatchKey::Python); + } + /** * Return the pointer to named tensor metadata. */ @@ -1510,6 +1562,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { pyobj_ = pyobj; } + // Query the PyObject interpreter. This may return null if there is no + // interpreter. This is racy! + impl::PyInterpreter* pyobj_interpreter() { + return pyobj_interpreter_.load(std::memory_order_acquire); + } + // Test the interpreter tag. If tagged for the current interpreter, return // a non-nullopt (but possibly null) PyObject. If (possibly) untagged, // returns a nullopt. If it is definitely invalid, raises an error. diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py new file mode 100644 index 00000000000..5a55b8e28b3 --- /dev/null +++ b/test/test_python_dispatch.py @@ -0,0 +1,257 @@ +import torch +from torch.testing._internal.common_utils import TestCase, run_tests +from torch.utils._pytree import tree_map + +from typing import Iterator, List +import logging +import contextlib + +# TODO: move this into library proper +@contextlib.contextmanager +def no_dispatch() -> Iterator[None]: + guard = torch._C._DisableTorchDispatch() + try: + yield + finally: + del guard + + +# How the chain of calls works for LoggingTensor: +# 1. Call torch.sin +# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely +# 3. Enter dispatcher, wind your way through Autograd +# 4. Hit Python dispatch key, call __torch_dispatch__ + +# TODO: TensorBase should work +class LoggingTensor(torch.Tensor): + elem: torch.Tensor + + __slots__ = ['elem'] + + @staticmethod + def __new__(cls, elem, *args, **kwargs): + # The wrapping tensor (LoggingTensor) is just a meta tensor, so it + # doesn't hold any memory (meta tensor is generally the preferred type + # of tensor you want to make a subclass from)... + r = torch.Tensor._make_subclass(cls, elem.to('meta'), elem.requires_grad) + # ...the real tensor is held as an element on the tensor. 
+ r.elem = elem + return r + + def __repr__(self): + return f"LoggingTensor({self.elem})" + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def unwrap(e): + return e.elem if isinstance(e, LoggingTensor) else e + + def wrap(e): + return LoggingTensor(e) if isinstance(e, torch.Tensor) else e + + # TODO: handle kwargs + assert not kwargs + rs = tree_map(wrap, func(*tree_map(unwrap, args))) + logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, rs) + return rs + +# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list +class LoggingTensorHandler(logging.Handler): + log_list: List[str] + next_shortid: int + + def __init__(self, log_list: List[str]) -> None: + logging.Handler.__init__(self) + self.log_list = log_list + self.next_shortid = 0 + + # WARNING: not deterministic over multiple threads, this matters for + # autograd + def _shortid(self, o: object) -> int: + if not hasattr(o, '_shortid'): + o._shortid = self.next_shortid + self.next_shortid += 1 + return o._shortid + + def _fmt(self, a: object) -> str: + return f'${self._shortid(a)}' if isinstance(a, LoggingTensor) else repr(a) + + def emit(self, record): + fmt_args = "(" + ", ".join(self._fmt(a) for a in record.args[0]) + ")" + fmt_rets = ", ".join(self._fmt(a) for a in record.args[1]) \ + if isinstance(record.args[1], (list, tuple)) else self._fmt(record.args[1]) + self.log_list.append(f'{fmt_rets} = {record.msg}{fmt_args}') + +def log_input(name: str, var: object): + logging.getLogger("LoggingTensor").info("input", (name,), (var,)) + +@contextlib.contextmanager +def capture_logs() -> Iterator[List[str]]: + logger = logging.getLogger("LoggingTensor") + log_list = [] + handler = LoggingTensorHandler(log_list) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + try: + yield log_list + finally: + logger.removeHandler(handler) + +class TestPythonDispatch(TestCase): + def test_basic(self) -> None: + with capture_logs() as logs: + x = LoggingTensor(torch.tensor([3.0], requires_grad=True)) + log_input("x", x) + y = x * x + saved_x = y.grad_fn._saved_self + grad_y = LoggingTensor(torch.tensor([1.0])) + log_input("grad_y", grad_y) + g, = torch.autograd.grad((y,), (x,), (grad_y,)) + + self.assertEqual(g.elem, torch.tensor([6.0])) + with torch.no_grad(): + self.assertEqual(saved_x, x) + self.assertEqual(saved_x._version, x._version) + x.add_(2) + self.assertEqual(saved_x, x) + # TODO: figure out why broken + # self.assertEqual(saved_x._version, x._version) + self.assertExpectedInline('\n'.join(logs), '''\ +$0 = input('x') +$1 = torch._ops.aten.mul($0, $0) +$2 = input('grad_y') +$3 = torch._ops.aten.mul($2, $0) +$4 = torch._ops.aten.mul($2, $0) +$5 = torch._ops.aten.add($4, $3, 1)''') + + def test_out(self) -> None: + with capture_logs() as logs: + x = LoggingTensor(torch.ones(1)) + y = LoggingTensor(torch.zeros(1)) + log_input("x", x) + log_input("y", y) + torch.abs(x, out=y) + + self.assertEqual(y.elem, torch.ones(1)) + # TODO: arguably this shouldn't pass and we should complain + # that out isn't a kwarg + self.assertExpectedInline('\n'.join(logs), '''\ +$0 = input('x') +$1 = input('y') +$2 = torch._ops.aten.abs($0, $1)''') + + def test_list_ret(self) -> None: + # test all sequence types are permissible returns + for list_type in (list, tuple): + class A(torch._C._TensorBase): + @staticmethod + def __new__(cls, elem): + return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) + + @classmethod + def __torch_dispatch__(cls, func, 
types, args=(), kwargs=None): + if func == torch.ops.aten.split: + with no_dispatch(): + return list_type(torch.split(*args)) + else: + raise AssertionError(f"unrecognized func: {func}") + + self.assertEqual( + torch.split(A(torch.tensor([0, 1])), 2), + torch.split(torch.tensor([0, 1]), 2) + ) + + def test_invalid_ret(self) -> None: + # test invalid return gets reasonable error message + class A(torch._C._TensorBase): + @staticmethod + def __new__(cls, elem): + return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + return "arf" + + self.assertExpectedRaisesInline( + RuntimeError, lambda: A(torch.zeros(1)).neg(), + """Unable to cast Python instance of type to C++ type 'at::Tensor'""" + ) + self.assertExpectedRaisesInline( + RuntimeError, lambda: A(torch.zeros(1)).detach(), + """detach returned invalid type str, expected Tensor""" + ) + + def test_metadata_change_not_allowed(self) -> None: + x = LoggingTensor(torch.ones(1)) + y = x.data + self.assertIsInstance(y, LoggingTensor) + self.assertRaises(RuntimeError, lambda: y.resize_(4)) + + def test_version(self) -> None: + x = LoggingTensor(torch.ones(1)) + prev_vc = x._version + x.detach().add_(2) + cur_vc = x._version + self.assertNotEqual(prev_vc, cur_vc) + x.data.add_(2) + self.assertEqual(cur_vc, x._version) + + def test_format(self) -> None: + x = LoggingTensor(torch.ones(1)) + s1 = str(x) + s2 = repr(x) + s3 = f"{x}" + self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""") + self.assertEqual(s1, s2) + self.assertEqual(s1, s3) + + def test_custom_autograd(self) -> None: + escape = [None] + + class Square(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + y = x ** 2 + ctx.save_for_backward(x) + return y + + @staticmethod + def backward(ctx, grad_output): + assert isinstance(grad_output, LoggingTensor) + x, = ctx.saved_tensors + assert isinstance(x, LoggingTensor) + escape[0] = x + return grad_output * 2 * x + + with capture_logs() as logs: + x = LoggingTensor(torch.ones(1, requires_grad=True)) + log_input("x", x) + x.grad = LoggingTensor(torch.zeros(1)) + log_input("x.grad", x.grad) + y = Square.apply(x) + grad_output = LoggingTensor(torch.ones(1)) + log_input("grad_output", grad_output) + y.backward(grad_output) + + with torch.no_grad(): + self.assertEqual(escape[0], x) + self.assertEqual(escape[0]._version, x._version) + # TODO: figure out why x.requires_grad = False doesn't + # trigger an error for LoggingTensor + x.add_(2) + self.assertEqual(escape[0], x) + # TODO: figure out why this is broken + # self.assertEqual(escape[0]._version, x._version) + + self.assertExpectedInline('\n'.join(logs), '''\ +$0 = input('x') +$1 = input('x.grad') +$2 = torch._ops.aten.pow($0, 2) +$3 = input('grad_output') +$4 = torch._ops.aten.mul($3, tensor(2)) +$5 = torch._ops.aten.mul($4, $0) +$6 = torch._ops.aten.add_($1, $5, 1)''') + + +if __name__ == '__main__': + run_tests() diff --git a/torch/_tensor.py b/torch/_tensor.py index d979e6128d8..29c08139050 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1026,7 +1026,7 @@ def _convert(ret, cls): if cls is Tensor: return ret - if isinstance(ret, Tensor): + if isinstance(ret, Tensor) and not isinstance(ret, cls): ret = ret.as_subclass(cls) if isinstance(ret, (tuple, list)): diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 6163d23bd7b..56c4ae4c4b6 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -17,6 +17,12 @@ #include 
+struct DisableTorchDispatch { + DisableTorchDispatch() : guard_(c10::DispatchKey::Python) { + } + c10::impl::ExcludeDispatchKeyGuard guard_; +}; + PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { using namespace torch::autograd::profiler; auto tensor_module = THPObjectPtr(PyImport_ImportModule("torch._tensor")); @@ -254,6 +260,9 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { py::class_(_C_m, "_InferenceMode") .def(py::init()); + py::class_(_C_m, "_DisableTorchDispatch") + .def(py::init<>()); + Py_RETURN_TRUE; } diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index d1909547dd6..83cfee991ad 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -30,6 +30,10 @@ #include #include +#include +#include + + #include #include @@ -77,12 +81,17 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj) { Py_DECREF(pyobj); }; +c10::intrusive_ptr concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self); +void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack); + class PyInterpreterHolder { public: PyInterpreterHolder() : impl_(new c10::impl::PyInterpreter( &concrete_name_fn, - &concrete_decref_fn)) {} + &concrete_decref_fn, + &concrete_detach_fn, + &concrete_dispatch_fn)) {} // NB: intentionally leaks the memory ~PyInterpreterHolder() { impl_->disarm(); @@ -112,6 +121,15 @@ static const char* VOLATILE_WARNING = "volatile was removed and now has no effect. Use " "`with torch.no_grad():` instead."; +static bool check_has_torch_dispatch(PyObject *obj) { + PyTypeObject *tp = Py_TYPE(obj); + return ( + !THPVariable_CheckTypeExact(tp) && + // TODO: test if Python key is disabled + PyObject_FastGetAttrString(obj, "__torch_dispatch__").ptr() != nullptr + ); +} + // Creates a new Python object for a Variable. The status parameter // specifies what the interpreter tag status on the object is; for // example, if you ran check_pyobj, the return optional of this object @@ -121,17 +139,19 @@ static const char* VOLATILE_WARNING = // It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED. 
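The _DisableTorchDispatch binding above is what the new test file wraps in a no_dispatch() helper: while the guard is alive, DispatchKey::Python sits in the thread-local exclude set, so operators called from inside a __torch_dispatch__ implementation reach the regular kernels instead of recursing back into Python. Restated here as a sketch for reference (the helper itself lives in test_python_dispatch.py, with a TODO to move it into the library proper):

    import contextlib
    from typing import Iterator
    import torch

    @contextlib.contextmanager
    def no_dispatch() -> Iterator[None]:
        # ExcludeDispatchKeyGuard(DispatchKey::Python) on the C++ side:
        # ops invoked inside this block skip the Python fallback entirely.
        guard = torch._C._DisableTorchDispatch()
        try:
            yield
        finally:
            del guard

    # Typical use inside __torch_dispatch__ when the subclass wraps real
    # (non-meta) tensors and simply wants to defer to the normal kernel:
    #
    #     @classmethod
    #     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    #         with no_dispatch():
    #             return func(*args)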
static PyObject* THPVariable_NewWithVar( PyTypeObject* type, - Variable var, + Variable _var, c10::impl::PyInterpreterStatus status) { PyObject* obj = type->tp_alloc(type, 0); if (obj) { auto v = (THPVariable*) obj; // TODO: named constructor to avoid default initialization new (&v->cdata) MaybeOwned(); - v->cdata = MaybeOwned::owned(std::move(var)); - // cannot use var as it is moved out of - THPVariable_Unpack(v).unsafeGetTensorImpl()->init_pyobj( - self_interpreter.get(), obj, status); + v->cdata = MaybeOwned::owned(std::move(_var)); + const auto& var = THPVariable_Unpack(v); + var.unsafeGetTensorImpl()->init_pyobj(self_interpreter.get(), obj, status); + if (check_has_torch_dispatch(obj)) { + var.unsafeGetTensorImpl()->set_python_dispatch(true); + } } return obj; } @@ -338,10 +358,10 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P // rnn.flatten_parameters() // ``` data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true); - auto var = data.set_requires_grad(r.toBool(2)); + data.set_requires_grad(r.toBool(2)); return THPVariable_NewWithVar( (PyTypeObject*)cls, - std::move(var), + std::move(data), c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); END_HANDLE_TH_ERRORS } @@ -349,6 +369,14 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P typedef PyObject *(*getter)(PyObject *, void *); typedef int (*setter)(PyObject *, PyObject *, void *); +PyObject *THPVariable_get_python_dispatch(THPVariable *self, void *unused) +{ + HANDLE_TH_ERRORS + const auto& var = THPVariable_Unpack(self); + return torch::autograd::utils::wrap(var.unsafeGetTensorImpl()->is_python_dispatch()); + END_HANDLE_TH_ERRORS +} + PyObject *THPVariable_get_T(THPVariable *self, void *unused) { HANDLE_TH_ERRORS @@ -927,6 +955,7 @@ int THPVariable_set_imag(THPVariable* self, THPVariable *imag, void *unused) // manually. 
TODO: make declarable in native_functions // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) static struct PyGetSetDef THPVariable_properties[] = { + {"_python_dispatch", (getter)THPVariable_get_python_dispatch, nullptr, nullptr, nullptr}, {"T", (getter)THPVariable_get_T, nullptr, nullptr, nullptr}, {"_cdata", (getter)THPVariable_get_cdata, nullptr, nullptr, nullptr}, {"_version", (getter)THPVariable_get_version, nullptr, nullptr, nullptr}, @@ -1443,3 +1472,120 @@ bool THPVariable_initModule(PyObject *module) torch::autograd::initTensorImplConversion(module); return true; } + +namespace { + +bool isPythonTensor(const Tensor& tensor) { + return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python); +} + +void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack) { + const auto& schema = op.schema(); + const auto num_returns = schema.returns().size(); + + const auto num_arguments = schema.arguments().size(); + auto arguments = torch::jit::pop(*stack, num_arguments); + + // Parse the name into namespace and name (no overload_name) + // TODO: put this into the library + const auto& qualified_name = op.operator_name().name; + auto pos = qualified_name.find("::"); + TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name); + // Make me some null terminated strings + std::string ns_str = qualified_name.substr(0, pos); + const char* ns = ns_str.c_str(); + const char* func_name = qualified_name.c_str() + pos + strlen("::"); + + // The plan: convert all the arguments back into PyObjects, + // extracting out the tensor handles, then call + // handle_torch_function_no_python_arg_parser + // NB: at the point arguments are pushed to the stack, ALL defaults + // are already present + + py::gil_scoped_acquire g; + + std::vector overloaded_args; + auto args = py::reinterpret_steal(PyTuple_New(num_arguments)); + // TODO: actually populate kwargs sometimes? At the moment, every argument + // just gets passed positionally + py::dict kwargs; + // For now, overloads get coalesced. Might be easier for users if they get + // overload resolution but is more complicated (need to expose separate + // functions per overload) + py::handle torch_api_function = py::module::import("torch").attr("ops").attr(ns).attr(func_name); + std::string module_name_str = "torch.ops." 
+ ns_str; + + for (int64_t idx = 0; idx < arguments.size(); idx++) { + auto& ivalue = arguments[idx]; + // Search for Tensors (as they may have the torch functions we need) + if (ivalue.isTensor()) { + const auto& tensor = ivalue.toTensor(); + if (isPythonTensor(tensor)) { + overloaded_args.emplace_back(py::cast(tensor)); + } + } else if (ivalue.isList()) { + const auto& list = ivalue.toListRef(); + for (int64_t jdx = 0; jdx < list.size(); jdx++) { + const auto& nv = list[jdx]; + if (nv.isTensor()) { + const auto& tensor = nv.toTensor(); + if (isPythonTensor(tensor)) { + overloaded_args.emplace_back(py::cast(tensor)); + } + } + } + } + PyTuple_SET_ITEM(args.ptr(), idx, torch::jit::toPyObject(std::move(ivalue)).release().ptr()); + } + + auto out = py::reinterpret_steal(handle_torch_function_no_python_arg_parser( + overloaded_args, + args.ptr(), + kwargs.ptr(), + func_name, + torch_api_function.ptr(), + module_name_str.c_str(), + "__torch_dispatch__" + )); + + if (op.schema().returns().size() == 1) { + torch::jit::push(stack, torch::jit::toIValue(out.ptr(), op.schema().returns()[0].type())); + } else { + auto outs = py::cast(out); + for (unsigned idx = 0; idx < outs.size(); idx++) { + torch::jit::push(stack, torch::jit::toIValue(outs[idx].ptr(), op.schema().returns()[idx].type())); + } + } +} + +c10::intrusive_ptr concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self) { + pybind11::gil_scoped_acquire gil; + + // Setup the arguments expected for the detach call + std::vector overloaded_args; + // TODO: there should be a shorter way to spell this + // TODO: fix the constness of target + Tensor self_t = Tensor(c10::intrusive_ptr::unsafe_reclaim_from_nonowning(const_cast(self))); + auto self_p = py::reinterpret_steal(THPVariable_Wrap(self_t)); + overloaded_args.emplace_back(self_p); + auto args = py::reinterpret_steal(PyTuple_New(1)); + PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr()); + + py::dict kwargs; + + auto out = py::reinterpret_steal(handle_torch_function_no_python_arg_parser( + overloaded_args, + args.ptr(), + kwargs.ptr(), + "detach", + py::module::import("torch").attr("ops").attr("aten").attr("detach").ptr(), + "torch.ops.aten", + "__torch_dispatch__" + )); + + TORCH_CHECK(THPVariable_Check(out.ptr()), "detach returned invalid type ", py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), ", expected Tensor"); + const Tensor& res_t = THPVariable_Unpack(out.ptr()); + return res_t.getIntrusivePtr(); +} + +} // anonymous namespace diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 4e12d358895..64f2dcea9b4 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -201,7 +201,7 @@ auto handle_torch_function(PyObject* self, const std::string& func_name, PyObjec return ret.release().ptr(); } -auto handle_torch_function_no_python_arg_parser(const std::vector &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name) -> PyObject* { +auto handle_torch_function_no_python_arg_parser(const std::vector &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name, const char* torch_function_name) -> PyObject* { // overloaded_args already all have unique types std::vector overloaded_types; overloaded_types.reserve(overloaded_args.size()); @@ -212,7 +212,7 @@ auto handle_torch_function_no_python_arg_parser(const std::vector &o py::object 
ret; for (auto &arg : overloaded_args) { // NOLINTNEXTLINE(clang-diagnostic-writable-strings) - py::object torch_function = PyObject_FastGetAttrString(arg.ptr(), "__torch_function__"); + py::object torch_function = PyObject_FastGetAttrString(arg.ptr(), torch_function_name); ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function, py_types.ptr(), args, kwargs, NULL)); if (ret.ptr() != Py_NotImplemented) { // Return the reference to the result. This also covers the case where ret @@ -230,7 +230,7 @@ auto handle_torch_function_no_python_arg_parser(const std::vector &o // returned NotImplemented, so we raise a TypeError. std::stringstream ss; ss << "no implementation found for '" << module_name << "." << func_name - << "' on types that implement __torch_function__: ["; + << "' on types that implement " << torch_function_name << ": ["; for (auto &arg : overloaded_args) { ss << arg.ptr()->ob_type->tp_name; if (!arg.is(overloaded_args.back())) { diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 1f3a21b25c3..20b73865d46 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -767,7 +767,7 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args=nullptr, PyObject* kwargs=nullptr, PyObject* torch_api=THPVariableClass, const std::string& module_name="torch.Tensor") -> PyObject*; // Used for functions created in C++, e.g., C++ custom op, which doesn't use PythonArgParser to get overloaded_args. -auto handle_torch_function_no_python_arg_parser(const std::vector &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name) -> PyObject*; +auto handle_torch_function_no_python_arg_parser(const std::vector &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name, const char* torch_function_name = "__torch_function__") -> PyObject*; // Used for getters of Tensor properties auto handle_torch_function_getter(THPVariable* self, const std::string& property_name) -> PyObject*; diff --git a/torch/overrides.py b/torch/overrides.py index a617c8ec91b..63a93f6e7dd 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -228,6 +228,7 @@ def get_ignored_functions() -> Set[Callable]: Tensor.to_sparse_csr, Tensor._reduce_ex_internal, Tensor._fix_weakref, + Tensor._python_dispatch.__get__, Tensor._conj, Tensor._conj_physical, }
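To close the loop on the _python_dispatch property exposed above: it simply reports whether DispatchKey::Python is set on the TensorImpl, which THPVariable_NewWithVar turns on whenever the class defines __torch_dispatch__. A small sketch of how it can be probed (not part of the patch; the Passthrough class and relying on the default requires_grad argument of _make_subclass are assumptions):

    import torch

    class Passthrough(torch.Tensor):
        # Hypothetical subclass used only to probe the new property; it defers
        # to the regular kernels with the Python key excluded to avoid recursion.
        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            guard = torch._C._DisableTorchDispatch()
            try:
                return func(*args, **(kwargs or {}))
            finally:
                del guard

    x = torch.Tensor._make_subclass(Passthrough, torch.zeros(1))
    print(x._python_dispatch)               # True: the class defines __torch_dispatch__
    print(torch.zeros(1)._python_dispatch)  # False: plain tensors never carry the key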