From 9a96604800ad77ed95d52641f51890ecb54753de Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 22 Feb 2022 09:44:06 -0800 Subject: [PATCH] Revert D34318185: [pytorch][PR] Ensure that call before redispatch work well for PythonTLSSnapshot Test Plan: revert-hammer Differential Revision: D34318185 (https://github.com/pytorch/pytorch/commit/04c9e52ecc73f0d29c815afcc1371f4f985b6d97) Original commit changeset: abc30fe69176 Original Phabricator Diff: D34318185 (https://github.com/pytorch/pytorch/commit/04c9e52ecc73f0d29c815afcc1371f4f985b6d97) fbshipit-source-id: ba40c2e1eceb1c4b71ac6edefc64d01e174d9524 (cherry picked from commit f47961904d0bbc75fb7dc4e8d11fcb4eca8cbc2b) --- aten/src/ATen/core/PythonFallbackKernel.cpp | 22 ++++-------- test/test_python_dispatch.py | 40 --------------------- 2 files changed, 6 insertions(+), 56 deletions(-) diff --git a/aten/src/ATen/core/PythonFallbackKernel.cpp b/aten/src/ATen/core/PythonFallbackKernel.cpp index a6897dabd82..b5861253c1e 100644 --- a/aten/src/ATen/core/PythonFallbackKernel.cpp +++ b/aten/src/ATen/core/PythonFallbackKernel.cpp @@ -2,27 +2,15 @@ #include #include -#include - namespace { // TLS saving the state of the include/exclude sets on entry to the dispatcher // This is set in the pythonTLSSnapshot fallback and used by the Python fallback. -thread_local std::stack tls_on_entry; - -struct C10_API StashTLSStateGuard { - public: - StashTLSStateGuard(const c10::impl::LocalDispatchKeySet& key_set) { - tls_on_entry.push(key_set); - } - ~StashTLSStateGuard() { - tls_on_entry.pop(); - } -}; +thread_local c10::optional tls_on_entry; void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { - TORCH_INTERNAL_ASSERT(tls_on_entry.size() > 0); - c10::impl::ForceDispatchKeyGuard guard(tls_on_entry.top()); + TORCH_INTERNAL_ASSERT(tls_on_entry.has_value()); + c10::impl::ForceDispatchKeyGuard guard(tls_on_entry.value()); // If Python Mode is active, use its PyInterpreter for dispatch const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state(); @@ -66,9 +54,11 @@ void pythonTLSSnapshotFallback(const c10::OperatorHandle& op, c10::DispatchKeySe // A CompositeImplicitAutograd function may have been called just before this and so the tls here were never cleared // This is also why we don't need an RAII to ensure the tls is reset when exceptions happen - StashTLSStateGuard guard(c10::impl::tls_local_dispatch_key_set()); + tls_on_entry = c10::impl::tls_local_dispatch_key_set(); op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack); + + tls_on_entry = c10::nullopt; } diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index fafcf534e43..a3e7e545799 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -562,46 +562,6 @@ $6 = torch._ops.aten.add_($1, $5)''') self.assertIsNone(t.grad) self.assertIsNotNone(t.elem.grad) - def test_multiple_ops_subclass(self): - # This is a Direct Subclass, don't do that! - class MySubclass(torch.Tensor): - @staticmethod - def __new__(cls, elem): - r = torch.Tensor._make_subclass(cls, elem) - return r - - __torch_function__ = torch._C._disabled_torch_function_impl - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - with no_dispatch(): - return func(*args, **kwargs) - - x = MySubclass(torch.rand(2, 2, dtype=torch.complex64)) - y = x.conj() - # Details of the bug that this tests for: - # Here, y dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU} - # There are a few calls to the dispatcher that are going to happen here: - # - call_exp: User calling exp on y - # - PythonTLSSnapshot: records the TLS on entry and redispatch - # - AutogradCPU: no input requires grad, so does nothing and redispatch - # - Conjugate: no special implementation for exp: use the fallback that - # first clone the Tensor (to materialize the conj) then redispatch - # - call_clone: conjugate fallback calling clone on y - # - PythonTLSSnapshot: records the TLS on entry and redispatch - # - (AutogradCPU: skipped as autograd added itself to the exclude set above) - # - Conjugate: special implementation for clone: just skip this key - # - Python: Reset the TLS based on the snapshot above and call the user implementation (this - # actually calls into the dispatcher again but since we disable both our keys - # before, not detailed here) - # - exit Python: restore the TLS and exit - # - exit Conjugate: nothing was inplace so just exit - # - exit PythonTLSSnapshot: done with this call, reset the saved TLS to empty - # - Python: Reset the TLS again based on the snapshot. <- this used to fail - # - More steps.... - y.exp() - - if __name__ == '__main__': run_tests()