Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Dispatch to Python via __torch_dispatch__ (#59760)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59760. See https://github.com/pytorch/pytorch/issues/59049. There are several moving parts to this PR; this explanation covers the straightforward parts first and the less straightforward parts afterwards.

**The actual dispatch to Python.** The core logic of dispatch to Python lives in `concrete_dispatch_fn` in `torch/csrc/autograd/python_variable.cpp`. It takes the input IValue stack, scans all the arguments for Tensor arguments, and defers most of the heavy lifting to `handle_torch_function_no_python_arg_parser`, which actually does all of the logic for calling out to torch dispatch (in particular, this function handles multiple-dispatch situations for you). Because we have a different function name than regular `__torch_function__` handling, `handle_torch_function_no_python_arg_parser` is generalized to accept a magic method name to look for when testing whether Tensors have custom handling. Unlike `__torch_function__`, by default there is no `__torch_dispatch__` on Tensor classes.

**Maintaining the Python dispatch key.** In order to get to the dispatch-to-Python logic, we must tag Tensors that have a `__torch_dispatch__` magic method with the newly added Python dispatch key (kept separate from FuncTorchPython to allow for a transitional period while it migrates to this mechanism). We expose a new private property `_python_dispatch` that assists in debugging whether a Tensor is participating in Python dispatch. We apply the Python dispatch key the first time a PyObject for a Tensor is constructed (`THPVariable_NewWithVar`), testing whether `__torch_dispatch__` exists with the newly added `check_has_torch_dispatch`.

**Shallow copy and detach.** For the simple examples tested in this PR, most creations of Tensor route through the dispatcher. The exception is `shallow_copy_and_detach`, which bypasses the dispatcher and is used when saving tensors for backwards. When a Tensor participates in Python dispatch, we override the behavior of `shallow_copy_and_detach` to instead call directly into `__torch_dispatch__` to perform a `detach` operation (in the same way it would be invoked if you called `detach` directly). Because this Python call is triggered directly from c10::TensorImpl, it must be indirected through `PyInterpreter::detach`, which is the general mechanism for dynamic dispatch to the Python interpreter associated with a TensorImpl.

**torchdeploy compatibility.** The dispatch-to-Python logic cannot be directly registered to the dispatcher, as it is compiled in the Python library, which gets loaded multiple times per torchdeploy interpreter. Thus, we employ a two-phase process. First, we register a fallback inside a non-Python library (aten/src/ATen/core/PythonFallbackKernel.cpp). Its job is to determine the appropriate PyInterpreter to handle the Python dispatch by going through all of the arguments and finding the first argument that has a PyObject/PyInterpreter. With this PyInterpreter, it makes another dynamic dispatch via `dispatch`, which goes to the correct torchdeploy interpreter to handle dispatching to actual Python.

**Testing.** We provide a simple example of a LoggingTensor for testing, which can be used to generate TorchScript-like traces to observe what operations are being called when a Tensor is invoked. Although a LoggingTensor would be better implemented via an is-a relationship rather than a has-a relationship (as is done in the test), we've done it this way to show that arbitrarily complex compositions of tensors inside a tensor work properly.

**Known limitations.**
* We haven't adjusted any operator code, so some patterns may not work (they lose the Python subclass in an unrecoverable way).
* `__torch_function__` must be explicitly disabled with `_disabled_torch_function_impl`, otherwise things don't work quite correctly (in particular, what is being disabled is the default subclass preservation behavior).
* We don't ever populate kwargs, even when an argument is kwarg-only.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D29017912

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Pulled By: ezyang

fbshipit-source-id: a67714d9e541d09203a8cfc85345b8967db86238
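For a sense of the user-facing extension point, here is a minimal sketch modeled on the LoggingTensor added in `test/test_python_dispatch.py` below. The class name `PrintingTensor` and the print-based trace are illustrative only; the sketch assumes the `Tensor._make_subclass` / `_disabled_torch_function_impl` idioms used in that test and the `_python_dispatch` property added by this PR.

```python
import torch
from torch.utils._pytree import tree_map

class PrintingTensor(torch.Tensor):
    elem: torch.Tensor
    __slots__ = ['elem']

    # Per the known limitations above, default __torch_function__ handling
    # must be explicitly disabled on the subclass.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapper itself is a meta tensor; the real data lives in .elem
        # (a has-a composition, exactly like the LoggingTensor test).
        r = torch.Tensor._make_subclass(cls, elem.to('meta'), elem.requires_grad)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        assert not kwargs  # kwargs are never populated by this PR

        def unwrap(e):
            return e.elem if isinstance(e, PrintingTensor) else e

        def wrap(e):
            return PrintingTensor(e) if isinstance(e, torch.Tensor) else e

        print(f"{func.__module__}.{func.__name__}")
        return tree_map(wrap, func(*tree_map(unwrap, args)))

x = PrintingTensor(torch.ones(2))
assert x._python_dispatch              # property added by this PR
y = x + x                              # prints something like torch._ops.aten.add
assert isinstance(y, PrintingTensor)   # results are re-wrapped by __torch_dispatch__
```

The LoggingTensor test below exercises the same pattern end to end, including through autograd.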
parent: a53d7f8f7c, commit: aacc722aec

aten/src/ATen/core/PythonFallbackKernel.cpp (new file, 40 lines)

@@ -0,0 +1,40 @@
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace {

void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  // It is safe to dispatch on the very first Tensor with a pyobj_interpreter
  // without checking the interpreters of any of the arguments, because when
  // we actually run dispatch(), we will take out PyObjects in the context
  // of that interpreter, and this will ensure that everyone is on the same
  // interpreter.
  for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
    if (ivalue.isTensor()) {
      auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter();
      if (interpreter) {
        interpreter->dispatch(op, stack);
        return;
      }
    } else if (ivalue.isTensorList()) {
      // NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef
      // is not a thing)
      for (const auto& nv : ivalue.toListRef()) {
        auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter();
        if (interpreter) {
          interpreter->dispatch(op, stack);
          return;
        }
      }
    }
  }
  TORCH_INTERNAL_ASSERT(0, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)");
}

} // anonymous namespace

TORCH_LIBRARY_IMPL(_, Python, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallback>());
}

c10/core/DispatchKey.cpp

@@ -61,6 +61,9 @@ const char* toString(DispatchKey t) {
    case DispatchKey::NestedTensor:
      return "NestedTensor";

    case DispatchKey::Python:
      return "Python";

    case DispatchKey::PrivateUse1:
      return "PrivateUse1";
    case DispatchKey::PrivateUse2:

c10/core/DispatchKey.h

@@ -110,6 +110,7 @@ enum class DispatchKey : uint8_t {
  SparseCsrCUDA,

  NestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor

  // Here are reserved backends for user-defined backends, see Note [Private use
  // DispatchKey]
  // To see some example about how to use this, check out MSNPU

@@ -119,6 +120,8 @@ enum class DispatchKey : uint8_t {

  // Define an alias key to represent end of backend dispatch keys.
  // If you add new backend keys after PrivateUse3, please also update it here.
  // (But you shouldn't: private use keys should have higher precedence than
  // all built-in keys)
  EndOfBackendKeys = PrivateUse3,

  // In some situations, it is not immediately obvious what the correct

@@ -128,6 +131,7 @@ enum class DispatchKey : uint8_t {
  // correct backend.
  BackendSelect,

  Python,
  FuncTorchPython, // See Note [Out-of-tree vmap+grad prototype]

  // The named dispatch key is set for any tensors with named dimensions.

c10/core/TensorImpl.cpp

@@ -31,9 +31,28 @@ static void noop_decref_fn(const PyInterpreter*, PyObject*) {
  // no-op
}

static c10::intrusive_ptr<TensorImpl> noop_detach_fn(
    const PyInterpreter*,
    const TensorImpl*) {
  TORCH_INTERNAL_ASSERT(
      0,
      "attempted to detach (shallow_copy_and_detach) Tensor with nontrivial PyObject after corresponding interpreter died");
}

static void noop_dispatch_fn(
    const PyInterpreter*,
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  TORCH_INTERNAL_ASSERT(
      0,
      "attempted to dispatch (__torch_dispatch__) an operator on Tensor with nontrivial PyObject after corresponding interpreter died");
}

void PyInterpreter::disarm() noexcept {
  name_fn_ = &noop_name_fn;
  decref_fn_ = &noop_decref_fn;
  detach_fn_ = &noop_detach_fn;
  dispatch_fn_ = &noop_dispatch_fn;
}

} // namespace impl

@@ -95,6 +114,23 @@ TensorImpl::TensorImpl(
          data_type,
          storage.device()) {}

// [Note: Python key removal]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In most constructors for TensorImpl, you will see Python key is removed from
// the passed in DispatchKeySet. Why?
//
// INVARIANT: Python dispatch key is set iff PyObject for the Tensor has a
// nontrivial __torch_dispatch__ implementation.
//
// When a fresh TensorImpl is created, there is *no* PyObject (this only gets
// initialized lazily at the first point in time the Tensor passes into Python).
// So we would violate the invariant.
//
// In practice, what will happen shortly afterwards is that the TensorImpl
// will get its PyObject initialized by Tensor._make_subclass; at this point
// the Python dispatch key will be set and all is well. The point is to delay
// the dispatch key setting until that point.

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(
    ImplType type,

@@ -108,7 +144,8 @@ TensorImpl::TensorImpl(
      numel_(0),
      data_type_(data_type),
      device_opt_(storage_.device()),
-     key_set_(key_set) {
+     key_set_(key_set.remove(
+         DispatchKey::Python)) { // See [Note: Python key removal]
  init_bitfields();
  // Inference tensor doesn't have version counter.
  if (!is_inference()) {

@@ -153,6 +190,9 @@ TensorImpl::TensorImpl(

  key_set = key_set | getAutocastRelatedKeySetFromBackend(k);

  key_set =
      key_set.remove(DispatchKey::Python); // See [Note: Python key removal]

  // Inference tensor doesn't have autograd related keys.
  if (inference_mode) {
    // See Note [Expected TLS state in InferenceMode] for why we exclude

@@ -459,6 +499,17 @@ c10::AutogradMetaInterface* TensorImpl::autograd_meta() const {
c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  if (key_set_.has(DispatchKey::Python) &&
      !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
    auto r = pyobj_interpreter_.load(std::memory_order_acquire)->detach(this);
    if (r) {
      r->set_version_counter(version_counter);
      r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      return r;
    }
    // otherwise just copy the TensorImpl and not the PyObject. Since
    // the interpreter is dead no one can call us out on it
  }
  auto impl = c10::make_intrusive<TensorImpl>(
      // No need to populate Storage; copy_tensor_metadata will do it for us.
      key_set_,

@@ -477,6 +528,17 @@ c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  if (key_set_.has(DispatchKey::Python) &&
      !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
    auto r = pyobj_interpreter_.load(std::memory_order_acquire)->detach(this);
    if (r) {
      r->set_version_counter(std::move(version_counter));
      r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      return r;
    }
    // otherwise just copy the TensorImpl and not the PyObject. Since
    // the interpreter is dead no one can call us out on it
  }
  auto impl = c10::make_intrusive<TensorImpl>(
      // No need to populate Storage; copy_tensor_metadata will do it for us.
      key_set_,

@@ -501,7 +563,7 @@ void TensorImpl::copy_tensor_metadata_except_version_counter(
  dest_impl->storage_offset_ = src_impl->storage_offset_;
  dest_impl->data_type_ = src_impl->data_type_;
  dest_impl->device_opt_ = src_impl->device_opt_;
- dest_impl->key_set_ = src_impl->key_set_;
+ dest_impl->key_set_ = src_impl->key_set_.remove(DispatchKey::Python);
  dest_impl->is_contiguous_ = src_impl->is_contiguous_;
  dest_impl->has_contiguity_ = src_impl->has_contiguity_;
  dest_impl->is_channels_last_contiguous_ =

c10/core/TensorImpl.h

@@ -42,7 +42,18 @@ class Tensor;

namespace c10 {
class Scalar;
struct IValue;
struct Storage;
class OperatorHandle;
} // namespace c10

namespace torch {
namespace jit {
using Stack = std::vector<c10::IValue>;
}
} // namespace torch

namespace c10 {

/**
 * A utility function to convert vector<int> to vector<int64_t>.

@@ -239,14 +250,27 @@ struct PyInterpreter;
struct C10_API PyInterpreter {
  using name_sig = std::string(const PyInterpreter*);
  using decref_sig = void(const PyInterpreter*, PyObject*);
  using detach_sig =
      c10::intrusive_ptr<TensorImpl>(const PyInterpreter*, const TensorImpl*);
  using dispatch_sig = void(
      const PyInterpreter*,
      const c10::OperatorHandle&,
      torch::jit::Stack* stack);

- PyInterpreter(name_sig* name_fn, decref_sig* decref_fn)
-     : name_fn_(name_fn), decref_fn_(decref_fn) {}
+ PyInterpreter(
+     name_sig* name_fn,
+     decref_sig* decref_fn,
+     detach_sig* detach,
+     dispatch_sig* dispatch)
+     : name_fn_(name_fn),
+       decref_fn_(decref_fn),
+       detach_fn_(detach),
+       dispatch_fn_(dispatch) {}

  // For debugging purposes only
  name_sig* name_fn_;

  decref_sig* decref_fn_;
  detach_sig* detach_fn_;
  dispatch_sig* dispatch_fn_;

  // UBSAN suppression fixes: "call to function
  // (anonymous namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*,

@@ -254,6 +278,7 @@ struct C10_API PyInterpreter {
  // c10::impl::PyInterpreter *, _object *)'" See
  // https://github.com/google/sanitizers/issues/911

  // Report the name of this interpreter
  __ubsan_ignore_function__ std::string name() const {
    return (*name_fn_)(this);
  }

@@ -263,6 +288,21 @@ struct C10_API PyInterpreter {
    return (*decref_fn_)(this, pyobj);
  }

  // Perform a detach by deferring to the __torch_dispatch__ implementation of
  // detach, which will also arrange for the PyObject to get copied in this
  // situation
  __ubsan_ignore_function__ c10::intrusive_ptr<TensorImpl> detach(
      const TensorImpl* self) const {
    return (*detach_fn_)(this, self);
  }

  // Invoke the Python boxed fallback dispatch to go back into Python
  __ubsan_ignore_function__ void dispatch(
      const c10::OperatorHandle& op,
      torch::jit::Stack* stack) const {
    return (*dispatch_fn_)(this, op, stack);
  }

  // Disarm this PyInterpreter, making all of its methods noops.
  // Because the function pointers are raw pointers (not atomics),
  // a disarm() invocation that is concurrent with active destructors

@@ -1321,6 +1361,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    }
  }

  void set_python_dispatch(bool k) {
    if (k) {
      key_set_ = key_set_.add(DispatchKey::Python);
    } else {
      key_set_ = key_set_.remove(DispatchKey::Python);
    }
  }

  bool is_python_dispatch() const {
    return key_set_.has(DispatchKey::Python);
  }

  /**
   * Return the pointer to named tensor metadata.
   */

@@ -1510,6 +1562,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    pyobj_ = pyobj;
  }

  // Query the PyObject interpreter. This may return null if there is no
  // interpreter. This is racy!
  impl::PyInterpreter* pyobj_interpreter() {
    return pyobj_interpreter_.load(std::memory_order_acquire);
  }

  // Test the interpreter tag. If tagged for the current interpreter, return
  // a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
  // returns a nullopt. If it is definitely invalid, raises an error.

test/test_python_dispatch.py (new file, 257 lines)

@@ -0,0 +1,257 @@
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.utils._pytree import tree_map

from typing import Iterator, List
import logging
import contextlib

# TODO: move this into library proper
@contextlib.contextmanager
def no_dispatch() -> Iterator[None]:
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


# How the chain of calls works for LoggingTensor:
# 1. Call torch.sin
# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely
# 3. Enter dispatcher, wind your way through Autograd
# 4. Hit Python dispatch key, call __torch_dispatch__

# TODO: TensorBase should work
class LoggingTensor(torch.Tensor):
    elem: torch.Tensor

    __slots__ = ['elem']

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapping tensor (LoggingTensor) is just a meta tensor, so it
        # doesn't hold any memory (meta tensor is generally the preferred type
        # of tensor you want to make a subclass from)...
        r = torch.Tensor._make_subclass(cls, elem.to('meta'), elem.requires_grad)
        # ...the real tensor is held as an element on the tensor.
        r.elem = elem
        return r

    def __repr__(self):
        return f"LoggingTensor({self.elem})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, LoggingTensor) else e

        def wrap(e):
            return LoggingTensor(e) if isinstance(e, torch.Tensor) else e

        # TODO: handle kwargs
        assert not kwargs
        rs = tree_map(wrap, func(*tree_map(unwrap, args)))
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, rs)
        return rs

# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
class LoggingTensorHandler(logging.Handler):
    log_list: List[str]
    next_shortid: int

    def __init__(self, log_list: List[str]) -> None:
        logging.Handler.__init__(self)
        self.log_list = log_list
        self.next_shortid = 0

    # WARNING: not deterministic over multiple threads, this matters for
    # autograd
    def _shortid(self, o: object) -> int:
        if not hasattr(o, '_shortid'):
            o._shortid = self.next_shortid
            self.next_shortid += 1
        return o._shortid

    def _fmt(self, a: object) -> str:
        return f'${self._shortid(a)}' if isinstance(a, LoggingTensor) else repr(a)

    def emit(self, record):
        fmt_args = "(" + ", ".join(self._fmt(a) for a in record.args[0]) + ")"
        fmt_rets = ", ".join(self._fmt(a) for a in record.args[1]) \
            if isinstance(record.args[1], (list, tuple)) else self._fmt(record.args[1])
        self.log_list.append(f'{fmt_rets} = {record.msg}{fmt_args}')

def log_input(name: str, var: object):
    logging.getLogger("LoggingTensor").info("input", (name,), (var,))

@contextlib.contextmanager
def capture_logs() -> Iterator[List[str]]:
    logger = logging.getLogger("LoggingTensor")
    log_list = []
    handler = LoggingTensorHandler(log_list)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    try:
        yield log_list
    finally:
        logger.removeHandler(handler)

class TestPythonDispatch(TestCase):
    def test_basic(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0], requires_grad=True))
            log_input("x", x)
            y = x * x
            saved_x = y.grad_fn._saved_self
            grad_y = LoggingTensor(torch.tensor([1.0]))
            log_input("grad_y", grad_y)
            g, = torch.autograd.grad((y,), (x,), (grad_y,))

        self.assertEqual(g.elem, torch.tensor([6.0]))
        with torch.no_grad():
            self.assertEqual(saved_x, x)
            self.assertEqual(saved_x._version, x._version)
            x.add_(2)
            self.assertEqual(saved_x, x)
            # TODO: figure out why broken
            # self.assertEqual(saved_x._version, x._version)
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.mul($0, $0)
$2 = input('grad_y')
$3 = torch._ops.aten.mul($2, $0)
$4 = torch._ops.aten.mul($2, $0)
$5 = torch._ops.aten.add($4, $3, 1)''')

    def test_out(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.zeros(1))
            log_input("x", x)
            log_input("y", y)
            torch.abs(x, out=y)

        self.assertEqual(y.elem, torch.ones(1))
        # TODO: arguably this shouldn't pass and we should complain
        # that out isn't a kwarg
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = torch._ops.aten.abs($0, $1)''')

    def test_list_ret(self) -> None:
        # test all sequence types are permissible returns
        for list_type in (list, tuple):
            class A(torch._C._TensorBase):
                @staticmethod
                def __new__(cls, elem):
                    return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

                @classmethod
                def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                    if func == torch.ops.aten.split:
                        with no_dispatch():
                            return list_type(torch.split(*args))
                    else:
                        raise AssertionError(f"unrecognized func: {func}")

            self.assertEqual(
                torch.split(A(torch.tensor([0, 1])), 2),
                torch.split(torch.tensor([0, 1]), 2)
            )

    def test_invalid_ret(self) -> None:
        # test invalid return gets reasonable error message
        class A(torch._C._TensorBase):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return "arf"

        self.assertExpectedRaisesInline(
            RuntimeError, lambda: A(torch.zeros(1)).neg(),
            """Unable to cast Python instance of type <class 'str'> to C++ type 'at::Tensor'"""
        )
        self.assertExpectedRaisesInline(
            RuntimeError, lambda: A(torch.zeros(1)).detach(),
            """detach returned invalid type str, expected Tensor"""
        )

    def test_metadata_change_not_allowed(self) -> None:
        x = LoggingTensor(torch.ones(1))
        y = x.data
        self.assertIsInstance(y, LoggingTensor)
        self.assertRaises(RuntimeError, lambda: y.resize_(4))

    def test_version(self) -> None:
        x = LoggingTensor(torch.ones(1))
        prev_vc = x._version
        x.detach().add_(2)
        cur_vc = x._version
        self.assertNotEqual(prev_vc, cur_vc)
        x.data.add_(2)
        self.assertEqual(cur_vc, x._version)

    def test_format(self) -> None:
        x = LoggingTensor(torch.ones(1))
        s1 = str(x)
        s2 = repr(x)
        s3 = f"{x}"
        self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
        self.assertEqual(s1, s2)
        self.assertEqual(s1, s3)

    def test_custom_autograd(self) -> None:
        escape = [None]

        class Square(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x ** 2
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                assert isinstance(grad_output, LoggingTensor)
                x, = ctx.saved_tensors
                assert isinstance(x, LoggingTensor)
                escape[0] = x
                return grad_output * 2 * x

        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1, requires_grad=True))
            log_input("x", x)
            x.grad = LoggingTensor(torch.zeros(1))
            log_input("x.grad", x.grad)
            y = Square.apply(x)
            grad_output = LoggingTensor(torch.ones(1))
            log_input("grad_output", grad_output)
            y.backward(grad_output)

        with torch.no_grad():
            self.assertEqual(escape[0], x)
            self.assertEqual(escape[0]._version, x._version)
            # TODO: figure out why x.requires_grad = False doesn't
            # trigger an error for LoggingTensor
            x.add_(2)
            self.assertEqual(escape[0], x)
            # TODO: figure out why this is broken
            # self.assertEqual(escape[0]._version, x._version)

        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('x.grad')
$2 = torch._ops.aten.pow($0, 2)
$3 = input('grad_output')
$4 = torch._ops.aten.mul($3, tensor(2))
$5 = torch._ops.aten.mul($4, $0)
$6 = torch._ops.aten.add_($1, $5, 1)''')


if __name__ == '__main__':
    run_tests()

@@ -1026,7 +1026,7 @@ def _convert(ret, cls):
    if cls is Tensor:
        return ret

-   if isinstance(ret, Tensor):
+   if isinstance(ret, Tensor) and not isinstance(ret, cls):
        ret = ret.as_subclass(cls)

    if isinstance(ret, (tuple, list)):

torch/csrc/autograd/init.cpp

@@ -17,6 +17,12 @@

#include <set>

struct DisableTorchDispatch {
  DisableTorchDispatch() : guard_(c10::DispatchKey::Python) {
  }
  c10::impl::ExcludeDispatchKeyGuard guard_;
};

PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
  using namespace torch::autograd::profiler;
  auto tensor_module = THPObjectPtr(PyImport_ImportModule("torch._tensor"));

@@ -254,6 +260,9 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
  py::class_<c10::InferenceMode>(_C_m, "_InferenceMode")
      .def(py::init<bool>());

  py::class_<DisableTorchDispatch>(_C_m, "_DisableTorchDispatch")
      .def(py::init<>());

  Py_RETURN_TRUE;
}

torch/csrc/autograd/python_variable.cpp

@@ -30,6 +30,10 @@
#include <c10/util/DeadlockDetection.h>
#include <c10/util/irange.h>

#include <torch/library.h>
#include <torch/csrc/jit/python/pybind_utils.h>

#include <ATen/ATen.h>
#include <pybind11/pybind11.h>

@@ -77,12 +81,17 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj) {
  Py_DECREF(pyobj);
};

c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack);

class PyInterpreterHolder {
 public:
  PyInterpreterHolder()
      : impl_(new c10::impl::PyInterpreter(
            &concrete_name_fn,
-           &concrete_decref_fn)) {}
+           &concrete_decref_fn,
+           &concrete_detach_fn,
+           &concrete_dispatch_fn)) {}
  // NB: intentionally leaks the memory
  ~PyInterpreterHolder() {
    impl_->disarm();

@@ -112,6 +121,15 @@ static const char* VOLATILE_WARNING =
    "volatile was removed and now has no effect. Use "
    "`with torch.no_grad():` instead.";

static bool check_has_torch_dispatch(PyObject *obj) {
  PyTypeObject *tp = Py_TYPE(obj);
  return (
    !THPVariable_CheckTypeExact(tp) &&
    // TODO: test if Python key is disabled
    PyObject_FastGetAttrString(obj, "__torch_dispatch__").ptr() != nullptr
  );
}

// Creates a new Python object for a Variable. The status parameter
// specifies what the interpreter tag status on the object is; for
// example, if you ran check_pyobj, the return optional of this object

@@ -121,17 +139,19 @@ static const char* VOLATILE_WARNING =
// It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED.
static PyObject* THPVariable_NewWithVar(
    PyTypeObject* type,
-   Variable var,
+   Variable _var,
    c10::impl::PyInterpreterStatus status) {
  PyObject* obj = type->tp_alloc(type, 0);
  if (obj) {
    auto v = (THPVariable*) obj;
    // TODO: named constructor to avoid default initialization
    new (&v->cdata) MaybeOwned<Variable>();
-   v->cdata = MaybeOwned<Variable>::owned(std::move(var));
-   // cannot use var as it is moved out of
-   THPVariable_Unpack(v).unsafeGetTensorImpl()->init_pyobj(
-       self_interpreter.get(), obj, status);
+   v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
+   const auto& var = THPVariable_Unpack(v);
+   var.unsafeGetTensorImpl()->init_pyobj(self_interpreter.get(), obj, status);
+   if (check_has_torch_dispatch(obj)) {
+     var.unsafeGetTensorImpl()->set_python_dispatch(true);
+   }
  }
  return obj;
}

@@ -338,10 +358,10 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P
  //   rnn.flatten_parameters()
  // ```
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
- auto var = data.set_requires_grad(r.toBool(2));
+ data.set_requires_grad(r.toBool(2));
  return THPVariable_NewWithVar(
      (PyTypeObject*)cls,
-     std::move(var),
+     std::move(data),
      c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  END_HANDLE_TH_ERRORS
}

@@ -349,6 +369,14 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P
typedef PyObject *(*getter)(PyObject *, void *);
typedef int (*setter)(PyObject *, PyObject *, void *);

PyObject *THPVariable_get_python_dispatch(THPVariable *self, void *unused)
{
  HANDLE_TH_ERRORS
  const auto& var = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(var.unsafeGetTensorImpl()->is_python_dispatch());
  END_HANDLE_TH_ERRORS
}

PyObject *THPVariable_get_T(THPVariable *self, void *unused)
{
  HANDLE_TH_ERRORS

@@ -927,6 +955,7 @@ int THPVariable_set_imag(THPVariable* self, THPVariable *imag, void *unused)
// manually. TODO: make declarable in native_functions
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyGetSetDef THPVariable_properties[] = {
  {"_python_dispatch", (getter)THPVariable_get_python_dispatch, nullptr, nullptr, nullptr},
  {"T", (getter)THPVariable_get_T, nullptr, nullptr, nullptr},
  {"_cdata", (getter)THPVariable_get_cdata, nullptr, nullptr, nullptr},
  {"_version", (getter)THPVariable_get_version, nullptr, nullptr, nullptr},

@@ -1443,3 +1472,120 @@ bool THPVariable_initModule(PyObject *module)
  torch::autograd::initTensorImplConversion(module);
  return true;
}

namespace {

bool isPythonTensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
}

void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();

  const auto num_arguments = schema.arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);

  // Parse the name into namespace and name (no overload_name)
  // TODO: put this into the library
  const auto& qualified_name = op.operator_name().name;
  auto pos = qualified_name.find("::");
  TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
  // Make me some null terminated strings
  std::string ns_str = qualified_name.substr(0, pos);
  const char* ns = ns_str.c_str();
  const char* func_name = qualified_name.c_str() + pos + strlen("::");

  // The plan: convert all the arguments back into PyObjects,
  // extracting out the tensor handles, then call
  // handle_torch_function_no_python_arg_parser
  // NB: at the point arguments are pushed to the stack, ALL defaults
  // are already present

  py::gil_scoped_acquire g;

  std::vector<py::handle> overloaded_args;
  auto args = py::reinterpret_steal<py::object>(PyTuple_New(num_arguments));
  // TODO: actually populate kwargs sometimes? At the moment, every argument
  // just gets passed positionally
  py::dict kwargs;
  // For now, overloads get coalesced. Might be easier for users if they get
  // overload resolution but is more complicated (need to expose separate
  // functions per overload)
  py::handle torch_api_function = py::module::import("torch").attr("ops").attr(ns).attr(func_name);
  std::string module_name_str = "torch.ops." + ns_str;

  for (int64_t idx = 0; idx < arguments.size(); idx++) {
    auto& ivalue = arguments[idx];
    // Search for Tensors (as they may have the torch functions we need)
    if (ivalue.isTensor()) {
      const auto& tensor = ivalue.toTensor();
      if (isPythonTensor(tensor)) {
        overloaded_args.emplace_back(py::cast(tensor));
      }
    } else if (ivalue.isList()) {
      const auto& list = ivalue.toListRef();
      for (int64_t jdx = 0; jdx < list.size(); jdx++) {
        const auto& nv = list[jdx];
        if (nv.isTensor()) {
          const auto& tensor = nv.toTensor();
          if (isPythonTensor(tensor)) {
            overloaded_args.emplace_back(py::cast(tensor));
          }
        }
      }
    }
    PyTuple_SET_ITEM(args.ptr(), idx, torch::jit::toPyObject(std::move(ivalue)).release().ptr());
  }

  auto out = py::reinterpret_steal<py::object>(handle_torch_function_no_python_arg_parser(
    overloaded_args,
    args.ptr(),
    kwargs.ptr(),
    func_name,
    torch_api_function.ptr(),
    module_name_str.c_str(),
    "__torch_dispatch__"
  ));

  if (op.schema().returns().size() == 1) {
    torch::jit::push(stack, torch::jit::toIValue(out.ptr(), op.schema().returns()[0].type()));
  } else {
    auto outs = py::cast<py::sequence>(out);
    for (unsigned idx = 0; idx < outs.size(); idx++) {
      torch::jit::push(stack, torch::jit::toIValue(outs[idx].ptr(), op.schema().returns()[idx].type()));
    }
  }
}

c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self) {
  pybind11::gil_scoped_acquire gil;

  // Setup the arguments expected for the detach call
  std::vector<py::handle> overloaded_args;
  // TODO: there should be a shorter way to spell this
  // TODO: fix the constness of target
  Tensor self_t = Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
  overloaded_args.emplace_back(self_p);
  auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
  PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());

  py::dict kwargs;

  auto out = py::reinterpret_steal<py::object>(handle_torch_function_no_python_arg_parser(
    overloaded_args,
    args.ptr(),
    kwargs.ptr(),
    "detach",
    py::module::import("torch").attr("ops").attr("aten").attr("detach").ptr(),
    "torch.ops.aten",
    "__torch_dispatch__"
  ));

  TORCH_CHECK(THPVariable_Check(out.ptr()), "detach returned invalid type ", py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), ", expected Tensor");
  const Tensor& res_t = THPVariable_Unpack(out.ptr());
  return res_t.getIntrusivePtr();
}

} // anonymous namespace

torch/csrc/utils/python_arg_parser.cpp

@@ -201,7 +201,7 @@ auto handle_torch_function(PyObject* self, const std::string& func_name, PyObjec
  return ret.release().ptr();
}

- auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name) -> PyObject* {
+ auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name, const char* torch_function_name) -> PyObject* {
  // overloaded_args already all have unique types
  std::vector<py::object> overloaded_types;
  overloaded_types.reserve(overloaded_args.size());

@@ -212,7 +212,7 @@ auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &o
  py::object ret;
  for (auto &arg : overloaded_args) {
    // NOLINTNEXTLINE(clang-diagnostic-writable-strings)
-   py::object torch_function = PyObject_FastGetAttrString(arg.ptr(), "__torch_function__");
+   py::object torch_function = PyObject_FastGetAttrString(arg.ptr(), torch_function_name);
    ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function, py_types.ptr(), args, kwargs, NULL));
    if (ret.ptr() != Py_NotImplemented) {
      // Return the reference to the result. This also covers the case where ret

@@ -230,7 +230,7 @@ auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &o
  // returned NotImplemented, so we raise a TypeError.
  std::stringstream ss;
  ss << "no implementation found for '" << module_name << "." << func_name
-    << "' on types that implement __torch_function__: [";
+    << "' on types that implement " << torch_function_name << ": [";
  for (auto &arg : overloaded_args) {
    ss << arg.ptr()->ob_type->tp_name;
    if (!arg.is(overloaded_args.back())) {

torch/csrc/utils/python_arg_parser.h

@@ -767,7 +767,7 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb
auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args=nullptr, PyObject* kwargs=nullptr, PyObject* torch_api=THPVariableClass, const std::string& module_name="torch.Tensor") -> PyObject*;

// Used for functions created in C++, e.g., C++ custom op, which doesn't use PythonArgParser to get overloaded_args.
- auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name) -> PyObject*;
+ auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name, const char* torch_function_name = "__torch_function__") -> PyObject*;

// Used for getters of Tensor properties
auto handle_torch_function_getter(THPVariable* self, const std::string& property_name) -> PyObject*;

torch/overrides.py

@@ -228,6 +228,7 @@ def get_ignored_functions() -> Set[Callable]:
        Tensor.to_sparse_csr,
        Tensor._reduce_ex_internal,
        Tensor._fix_weakref,
        Tensor._python_dispatch.__get__,
        Tensor._conj,
        Tensor._conj_physical,
    }