Fix wrapper subclass reentrant dispatch + TorchDispatchMode (#136566)
Fixes #136565

This PR makes the Python fallback robust to the case where there are no active modes and no tensors with the Python key. In this case, simply redispatch with the Python key disabled. This was found when trying to use reentrant dispatch for NJT to get decompositions under `inference_mode()` when the autograd key is disabled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136566
Approved by: https://github.com/bdhirsh
This commit is contained in:
parent 963e793e1b
commit f8debd5d83
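For context, the failure mode can be reproduced standalone in a few lines of Python. The sketch below is adapted from the regression test added by this PR; `PassthroughMode` is an illustrative name, and the `TwoTensor` import path from PyTorch's internal test utilities is an assumption that may vary across versions. The C++ changes below implement the fix.

    import torch
    from torch.overrides import enable_reentrant_dispatch
    from torch.utils._python_dispatch import TorchDispatchMode
    from torch.testing._internal.two_tensor import TwoTensor  # assumed import path

    # A do-nothing mode: its presence alone pushes the mode stack and adds
    # DispatchKey.Python to the thread-local include set.
    class PassthroughMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            return func(*args, **(kwargs or {}))

    class MySubclass(TwoTensor):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            # Re-enter the dispatcher from inside __torch_dispatch__. Before
            # this fix, redispatching here with no active modes and no
            # Python-key tensors on the stack failed (see gh-136565).
            with enable_reentrant_dispatch():
                return func(args[0].a)

    t = MySubclass(torch.rand(2), torch.rand(2))
    with PassthroughMode():
        res = t.clone()  # failed before the fix
    assert type(res) is torch.Tensor

After the fix, `res` is a plain `torch.Tensor` equal to `t.a`, matching the assertions in the new test below.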
@@ -45,7 +45,7 @@ private:
   c10::impl::LocalDispatchKeySet saved_;
 };
 
-void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+void pythonFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
   TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
   // c10::impl::ForceDispatchKeyGuard dispatcher_guard(tls_on_entry.value());
   // StashTLSOnEntryGuard stash_guard;
@@ -68,12 +68,20 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   // we actually run dispatch(), we will take out PyObjects in the context
   // of that interpreter, and this will ensure that everyone is on the same
   // interpreter.
+  bool tensors_with_python_key_present = false;
+  c10::impl::PyInterpreter* interpreter = nullptr;
   for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
     if (ivalue.isTensor()) {
-      auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
-      if (interpreter) {
-        (*interpreter)->dispatch(op, stack);
-        return;
+      auto* t = ivalue.unsafeToTensorImpl();
+      if (t->key_set().has(c10::DispatchKey::Python)) {
+        tensors_with_python_key_present = true;
+      }
+
+      if (!interpreter) {
+        auto* t_interpreter = t->pyobj_slot()->pyobj_interpreter();
+        if (t_interpreter) {
+          interpreter = t_interpreter;
+        }
       }
     } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
       // NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef
@@ -82,14 +90,43 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
         if (nv.isNone()) {
           continue;
         }
-        auto* interpreter = nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
-        if (interpreter) {
-          (*interpreter)->dispatch(op, stack);
-          return;
+
+        auto* t = nv.unsafeToTensorImpl();
+        if (t->key_set().has(c10::DispatchKey::Python)) {
+          tensors_with_python_key_present = true;
+        }
+
+        if (!interpreter) {
+          auto* t_interpreter = t->pyobj_slot()->pyobj_interpreter();
+          if (t_interpreter) {
+            interpreter = t_interpreter;
+          }
         }
       }
     }
   }
+
+  if (interpreter) {
+    if (tensors_with_python_key_present) {
+      (*interpreter)->dispatch(op, stack);
+    } else {
+      // At this point, there are no modes in the stack and no tensors with the python key.
+      // so disable the python key before redispatching.
+      // See https://github.com/pytorch/pytorch/issues/136565
+      c10::DispatchKeySet keyset = dispatch_keys.remove(c10::DispatchKey::Python);
+
+      // Remove Python key from the included set as well (modes add it there).
+      c10::impl::LocalDispatchKeySet local_keyset = c10::impl::tls_local_dispatch_key_set();
+      c10::impl::ForceDispatchKeyGuard no_python_guard(
+        local_keyset.included_.remove(c10::DispatchKey::Python),
+        local_keyset.excluded_
+      );
+
+      op.redispatchBoxed(keyset, stack);
+    }
+    return;
+  }
+
 
   TORCH_INTERNAL_ASSERT(0, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)");
 }
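To make the dispatcher-state manipulation above concrete, the same state can be inspected from Python. This sketch is not part of the PR; it relies on private, version-dependent `torch._C` introspection helpers (`torch._C._dispatch_keys`, `torch._C._dispatch_tls_local_include_set`) and the assumed `TwoTensor` test-utility import, shown only for illustration:

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode
    from torch.testing._internal.two_tensor import TwoTensor  # assumed import path

    # A plain tensor's key set carries no DispatchKey.Python; a wrapper
    # subclass's does. The new tensors_with_python_key_present flag tracks
    # exactly this distinction instead of merely finding a PyInterpreter.
    plain = torch.rand(2)
    sub = TwoTensor(torch.rand(2), torch.rand(2))
    print(torch._C._dispatch_keys(plain))  # no Python key
    print(torch._C._dispatch_keys(sub))    # includes Python

    # Active modes add DispatchKey.Python to the thread-local include set;
    # that is what the new ForceDispatchKeyGuard strips before redispatching.
    class NoopMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            return func(*args, **(kwargs or {}))

    print(torch._C._dispatch_tls_local_include_set())      # Python absent
    with NoopMode():
        print(torch._C._dispatch_tls_local_include_set())  # Python present

The regression test added by the PR follows.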
@@ -2595,6 +2595,28 @@ def forward(self, x_1):
         e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
         self.assertEqual(e.layout, torch.strided)
 
+    def test_wrapper_subclass_reentrant_dispatch_with_mode(self):
+        # Tests the interaction between a wrapper subclass using reentrant dispatch
+        # and a TorchDispatchMode. See https://github.com/pytorch/pytorch/issues/136565
+
+        # simple passthrough TorchDispatchMode
+        class CustomDispatchMode(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=..., kwargs=None):
+                return func(*args, **kwargs)
+
+        # derive from TwoTensor to minimize boilerplate
+        class MySubclass(TwoTensor):
+            def __torch_dispatch__(self, func, types, args, kwargs=None):
+                with torch.overrides.enable_reentrant_dispatch():
+                    return func(args[0].a)
+
+        t = MySubclass(torch.rand(2), torch.rand(2))
+        with CustomDispatchMode():
+            res = t.clone()
+
+        self.assertEqual(res, t.a)
+        self.assertIs(type(res), torch.Tensor)
+
 
 class TestPythonDispatcher(TestCase):
     def test_basic(self):
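The motivating case from the commit message, NJT decompositions under `inference_mode()`, can be exercised roughly as follows. This is a sketch assuming the public jagged-layout nested tensor API, not a test from this PR:

    import torch

    # Jagged NestedTensors (NJT) are backed by a Python __torch_dispatch__
    # subclass, so their ops route through the Python fallback patched above.
    nt = torch.nested.nested_tensor(
        [torch.rand(3), torch.rand(5)], layout=torch.jagged
    )

    # Under inference_mode() the autograd key is disabled, so NJT relies on
    # reentrant dispatch to reach its decompositions, the path hardened here.
    with torch.inference_mode():
        out = nt.sin()
    print(out.is_nested)  # True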