Fix wrapper subclass reentrant dispatch + TorchDispatchMode (#136566)

Fixes #136565

This PR makes the python fallback robust to the case where there are no active modes & no tensors with the Python key. In this case, simply redispatch with the Python key disabled.

This was found when trying to use reentrant dispatch for NJT to get decompositions under `inference_mode()` when the autograd key is disabled.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136566
Approved by: https://github.com/bdhirsh
This commit is contained in:
Joel Schlosser 2024-09-25 15:52:52 -04:00 committed by PyTorch MergeBot
parent 963e793e1b
commit f8debd5d83
2 changed files with 68 additions and 9 deletions

View File

@ -45,7 +45,7 @@ private:
c10::impl::LocalDispatchKeySet saved_;
};
void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
void pythonFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
// c10::impl::ForceDispatchKeyGuard dispatcher_guard(tls_on_entry.value());
// StashTLSOnEntryGuard stash_guard;
@ -68,12 +68,20 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
// we actually run dispatch(), we will take out PyObjects in the context
// of that interpreter, and this will ensure that everyone is on the same
// interpreter.
bool tensors_with_python_key_present = false;
c10::impl::PyInterpreter* interpreter = nullptr;
for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
if (ivalue.isTensor()) {
auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
if (interpreter) {
(*interpreter)->dispatch(op, stack);
return;
auto* t = ivalue.unsafeToTensorImpl();
if (t->key_set().has(c10::DispatchKey::Python)) {
tensors_with_python_key_present = true;
}
if (!interpreter) {
auto* t_interpreter = t->pyobj_slot()->pyobj_interpreter();
if (t_interpreter) {
interpreter = t_interpreter;
}
}
} else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
// NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef
@ -82,14 +90,43 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
if (nv.isNone()) {
continue;
}
auto* interpreter = nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
if (interpreter) {
(*interpreter)->dispatch(op, stack);
return;
auto* t = nv.unsafeToTensorImpl();
if (t->key_set().has(c10::DispatchKey::Python)) {
tensors_with_python_key_present = true;
}
if (!interpreter) {
auto* t_interpreter = t->pyobj_slot()->pyobj_interpreter();
if (t_interpreter) {
interpreter = t_interpreter;
}
}
}
}
}
if (interpreter) {
if (tensors_with_python_key_present) {
(*interpreter)->dispatch(op, stack);
} else {
// At this point, there are no modes in the stack and no tensors with the python key.
// so disable the python key before redispatching.
// See https://github.com/pytorch/pytorch/issues/136565
c10::DispatchKeySet keyset = dispatch_keys.remove(c10::DispatchKey::Python);
// Remove Python key from the included set as well (modes add it there).
c10::impl::LocalDispatchKeySet local_keyset = c10::impl::tls_local_dispatch_key_set();
c10::impl::ForceDispatchKeyGuard no_python_guard(
local_keyset.included_.remove(c10::DispatchKey::Python),
local_keyset.excluded_
);
op.redispatchBoxed(keyset, stack);
}
return;
}
TORCH_INTERNAL_ASSERT(0, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)");
}

View File

@ -2595,6 +2595,28 @@ def forward(self, x_1):
e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
self.assertEqual(e.layout, torch.strided)
def test_wrapper_subclass_reentrant_dispatch_with_mode(self):
# Tests the interaction between a wrapper subclass using reentrant dispatch
# and a TorchDispatchMode. See https://github.com/pytorch/pytorch/issues/136565
# simple passthrough TorchDispatchMode
class CustomDispatchMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=..., kwargs=None):
return func(*args, **kwargs)
# derive from TwoTensor to minimize boilerplate
class MySubclass(TwoTensor):
def __torch_dispatch__(self, func, types, args, kwargs=None):
with torch.overrides.enable_reentrant_dispatch():
return func(args[0].a)
t = MySubclass(torch.rand(2), torch.rand(2))
with CustomDispatchMode():
res = t.clone()
self.assertEqual(res, t.a)
self.assertIs(type(res), torch.Tensor)
class TestPythonDispatcher(TestCase):
def test_basic(self):