mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert D34318185: [pytorch][PR] Ensure that call before redispatch work well for PythonTLSSnapshot
Test Plan: revert-hammer Differential Revision: D34318185 (04c9e52ecc) Original commit changeset: abc30fe69176 Original Phabricator Diff: D34318185 (04c9e52ecc) fbshipit-source-id: ba40c2e1eceb1c4b71ac6edefc64d01e174d9524 (cherry picked from commitf47961904d)
This commit is contained in:
parent
932adf26e4
commit
9a96604800
|
|
@ -2,27 +2,15 @@
|
|||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <ATen/core/PythonModeTLS.h>
|
||||
|
||||
#include <stack>
|
||||
|
||||
namespace {
|
||||
|
||||
// TLS saving the state of the include/exclude sets on entry to the dispatcher
|
||||
// This is set in the pythonTLSSnapshot fallback and used by the Python fallback.
|
||||
thread_local std::stack<c10::impl::LocalDispatchKeySet> tls_on_entry;
|
||||
|
||||
struct C10_API StashTLSStateGuard {
|
||||
public:
|
||||
StashTLSStateGuard(const c10::impl::LocalDispatchKeySet& key_set) {
|
||||
tls_on_entry.push(key_set);
|
||||
}
|
||||
~StashTLSStateGuard() {
|
||||
tls_on_entry.pop();
|
||||
}
|
||||
};
|
||||
thread_local c10::optional<c10::impl::LocalDispatchKeySet> tls_on_entry;
|
||||
|
||||
void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
TORCH_INTERNAL_ASSERT(tls_on_entry.size() > 0);
|
||||
c10::impl::ForceDispatchKeyGuard guard(tls_on_entry.top());
|
||||
TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
|
||||
c10::impl::ForceDispatchKeyGuard guard(tls_on_entry.value());
|
||||
|
||||
// If Python Mode is active, use its PyInterpreter for dispatch
|
||||
const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
|
||||
|
|
@ -66,9 +54,11 @@ void pythonTLSSnapshotFallback(const c10::OperatorHandle& op, c10::DispatchKeySe
|
|||
// A CompositeImplicitAutograd function may have been called just before this and so the tls here were never cleared
|
||||
// This is also why we don't need an RAII to ensure the tls is reset when exceptions happen
|
||||
|
||||
StashTLSStateGuard guard(c10::impl::tls_local_dispatch_key_set());
|
||||
tls_on_entry = c10::impl::tls_local_dispatch_key_set();
|
||||
|
||||
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack);
|
||||
|
||||
tls_on_entry = c10::nullopt;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -562,46 +562,6 @@ $6 = torch._ops.aten.add_($1, $5)''')
|
|||
self.assertIsNone(t.grad)
|
||||
self.assertIsNotNone(t.elem.grad)
|
||||
|
||||
def test_multiple_ops_subclass(self):
|
||||
# This is a Direct Subclass, don't do that!
|
||||
class MySubclass(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, elem):
|
||||
r = torch.Tensor._make_subclass(cls, elem)
|
||||
return r
|
||||
|
||||
__torch_function__ = torch._C._disabled_torch_function_impl
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
with no_dispatch():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
x = MySubclass(torch.rand(2, 2, dtype=torch.complex64))
|
||||
y = x.conj()
|
||||
# Details of the bug that this tests for:
|
||||
# Here, y dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU}
|
||||
# There are a few calls to the dispatcher that are going to happen here:
|
||||
# - call_exp: User calling exp on y
|
||||
# - PythonTLSSnapshot: records the TLS on entry and redispatch
|
||||
# - AutogradCPU: no input requires grad, so does nothing and redispatch
|
||||
# - Conjugate: no special implementation for exp: use the fallback that
|
||||
# first clone the Tensor (to materialize the conj) then redispatch
|
||||
# - call_clone: conjugate fallback calling clone on y
|
||||
# - PythonTLSSnapshot: records the TLS on entry and redispatch
|
||||
# - (AutogradCPU: skipped as autograd added itself to the exclude set above)
|
||||
# - Conjugate: special implementation for clone: just skip this key
|
||||
# - Python: Reset the TLS based on the snapshot above and call the user implementation (this
|
||||
# actually calls into the dispatcher again but since we disable both our keys
|
||||
# before, not detailed here)
|
||||
# - exit Python: restore the TLS and exit
|
||||
# - exit Conjugate: nothing was inplace so just exit
|
||||
# - exit PythonTLSSnapshot: done with this call, reset the saved TLS to empty
|
||||
# - Python: Reset the TLS again based on the snapshot. <- this used to fail
|
||||
# - More steps....
|
||||
y.exp()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user