Revert D34318185: [pytorch][PR] Ensure that call before redispatch work well for PythonTLSSnapshot

Test Plan: revert-hammer

Differential Revision:
D34318185 (04c9e52ecc)

Original commit changeset: abc30fe69176

Original Phabricator Diff: D34318185 (04c9e52ecc)

fbshipit-source-id: ba40c2e1eceb1c4b71ac6edefc64d01e174d9524
This commit is contained in:
Nikita Shulga 2022-02-22 09:44:06 -08:00 committed by Facebook GitHub Bot
parent a72621fbbd
commit f47961904d
2 changed files with 6 additions and 56 deletions

View File

@ -2,27 +2,15 @@
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/PythonModeTLS.h>
#include <stack>
namespace {
// TLS saving the state of the include/exclude sets on entry to the dispatcher
// This is set in the pythonTLSSnapshot fallback and used by the Python fallback.
thread_local std::stack<c10::impl::LocalDispatchKeySet> tls_on_entry;
struct C10_API StashTLSStateGuard {
public:
StashTLSStateGuard(const c10::impl::LocalDispatchKeySet& key_set) {
tls_on_entry.push(key_set);
}
~StashTLSStateGuard() {
tls_on_entry.pop();
}
};
thread_local c10::optional<c10::impl::LocalDispatchKeySet> tls_on_entry;
void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
TORCH_INTERNAL_ASSERT(tls_on_entry.size() > 0);
c10::impl::ForceDispatchKeyGuard guard(tls_on_entry.top());
TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
c10::impl::ForceDispatchKeyGuard guard(tls_on_entry.value());
// If Python Mode is active, use its PyInterpreter for dispatch
const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
@ -66,9 +54,11 @@ void pythonTLSSnapshotFallback(const c10::OperatorHandle& op, c10::DispatchKeySe
// A CompositeImplicitAutograd function may have been called just before this and so the tls here were never cleared
// This is also why we don't need an RAII to ensure the tls is reset when exceptions happen
StashTLSStateGuard guard(c10::impl::tls_local_dispatch_key_set());
tls_on_entry = c10::impl::tls_local_dispatch_key_set();
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack);
tls_on_entry = c10::nullopt;
}

View File

@ -562,46 +562,6 @@ $6 = torch._ops.aten.add_($1, $5)''')
self.assertIsNone(t.grad)
self.assertIsNotNone(t.elem.grad)
def test_multiple_ops_subclass(self):
# This is a Direct Subclass, don't do that!
class MySubclass(torch.Tensor):
@staticmethod
def __new__(cls, elem):
r = torch.Tensor._make_subclass(cls, elem)
return r
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
with no_dispatch():
return func(*args, **kwargs)
x = MySubclass(torch.rand(2, 2, dtype=torch.complex64))
y = x.conj()
# Details of the bug that this tests for:
# Here, y dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU}
# There are a few calls to the dispatcher that are going to happen here:
# - call_exp: User calling exp on y
# - PythonTLSSnapshot: records the TLS on entry and redispatch
# - AutogradCPU: no input requires grad, so does nothing and redispatch
# - Conjugate: no special implementation for exp: use the fallback that
# first clone the Tensor (to materialize the conj) then redispatch
# - call_clone: conjugate fallback calling clone on y
# - PythonTLSSnapshot: records the TLS on entry and redispatch
# - (AutogradCPU: skipped as autograd added itself to the exclude set above)
# - Conjugate: special implementation for clone: just skip this key
# - Python: Reset the TLS based on the snapshot above and call the user implementation (this
# actually calls into the dispatcher again but since we disable both our keys
# before, not detailed here)
# - exit Python: restore the TLS and exit
# - exit Conjugate: nothing was inplace so just exit
# - exit PythonTLSSnapshot: done with this call, reset the saved TLS to empty
# - Python: Reset the TLS again based on the snapshot. <- this used to fail
# - More steps....
y.exp()
if __name__ == '__main__':
run_tests()