Revert D34214953: Add new tls snapshot feature

Test Plan: revert-hammer

Differential Revision:
D34214953 (6199b5231f)

Original commit changeset: 7aa5d5e3540a

Original Phabricator Diff: D34214953 (6199b5231f)

fbshipit-source-id: 5d271e9a5ab021b8202402630dbf917b43c55421
This commit is contained in:
Brian Hirsh 2022-02-14 15:09:10 -08:00 committed by Facebook GitHub Bot
parent 4f7338d4f4
commit a12c630198
10 changed files with 21 additions and 67 deletions

View File

@@ -4,14 +4,7 @@
namespace {
// TLS saving the state of the include/exclude sets on entry to the dispatcher
// This is set in the pythonTLSSnapshot fallback and used by the Python fallback.
thread_local c10::optional<c10::impl::LocalDispatchKeySet> tls_on_entry;
void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
c10::impl::ForceDispatchKeyGuard guard(tls_on_entry.value());
// If Python Mode is active, use its PyInterpreter for dispatch
const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
if (maybe_python_mode_state) {
@@ -49,25 +42,8 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
TORCH_INTERNAL_ASSERT(0, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)");
}
void pythonTLSSnapshotFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
// It is ok for the tls to be already set here.
// A CompositeImplicitAutograd function may have been called just before this and so the tls here were never cleared
// This is also why we don't need an RAII to ensure the tls is reset when exceptions happen
tls_on_entry = c10::impl::tls_local_dispatch_key_set();
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack);
tls_on_entry = c10::nullopt;
}
} // anonymous namespace
TORCH_LIBRARY_IMPL(_, Python, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallback>());
}
TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
}

View File

@@ -8,7 +8,6 @@ void PythonModeTLS::set_state(const std::shared_ptr<TorchDispatchTypeObject>& st
pythonModeState = state;
if (state) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(DispatchKey::PythonTLSSnapshot, true);
} else {
PythonModeTLS::reset_state();
}
@@ -21,7 +20,6 @@ const std::shared_ptr<TorchDispatchTypeObject>& PythonModeTLS::get_state() {
void PythonModeTLS::reset_state() {
pythonModeState.reset((TorchDispatchTypeObject*)nullptr);
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(DispatchKey::PythonTLSSnapshot, false);
}
} // namespace impl

View File

@@ -100,8 +100,6 @@ const char* toString(DispatchKey t) {
case DispatchKey::Python:
return "Python";
case DispatchKey::PythonTLSSnapshot:
return "PythonTLSSnapshot";
case DispatchKey::PrivateUse1:
return "PrivateUse1";
@@ -253,7 +251,6 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
{"BackendSelect", c10::DispatchKey::BackendSelect},
{"Python", c10::DispatchKey::Python},
{"PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot},
{"Named", c10::DispatchKey::Named},
{"Conjugate", c10::DispatchKey::Conjugate},
{"Negative", c10::DispatchKey::Negative},

View File

@@ -354,11 +354,6 @@ enum class DispatchKey : uint16_t {
Functionalize,
FuncTorchDynamicLayerFrontMode, // See Note [Out-of-tree vmap+grad prototype]
// Used by Python key logic to know the set of tls on entry to the dispatcher
// This kernel assumes it is at the very top of the dispatcher. If you add
// a key above, make sure to update the fallback implementation for this.
PythonTLSSnapshot,
// TESTING: This is intended to be a generic testing tensor type id.
// Don't use it for anything real; its only acceptable use is within a single
// process test. Use it by creating a TensorImpl with this DispatchKey, and

View File

@@ -606,10 +606,7 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView);
constexpr DispatchKeySet python_ks = DispatchKeySet({
DispatchKey::Python,
DispatchKey::PythonTLSSnapshot,
});
constexpr DispatchKeySet python_ks = DispatchKeySet(DispatchKey::Python);
constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);

View File

@@ -120,11 +120,11 @@ TensorImpl::TensorImpl(
// [Note: Python key removal]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In most constructors for TensorImpl, you will see Python and PythonTLSSnapshot
// keys are removed from the passed in DispatchKeySet. Why?
// In most constructors for TensorImpl, you will see Python key is removed from
// the passed in DispatchKeySet. Why?
//
// INVARIANT: Python and PythonTLSSnapshot dispatch keys are set iff PyObject for
// the Tensor has a nontrivial __torch_dispatch__ implementation.
// INVARIANT: Python dispatch key is set iff PyObject for the Tensor has a
// nontrivial __torch_dispatch__ implementation.
//
// When a fresh TensorImpl is created, there is *no* PyObject (this only gets
// initialized lazily at the first point in time the Tensor passes into Python).
@@ -132,8 +132,8 @@ TensorImpl::TensorImpl(
//
// In practice, what will happen shortly afterwards is that the TensorImpl
// will get its PyObject initialized by Tensor._make_subclass; at this point
// the Python and PythonTLSSnapshot dispatch keys will be set and all is well.
// The point is to delay the dispatch key setting until that point.
// the Python dispatch key will be set and all is well. The point is to delay
// the dispatch key setting until that point.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(
@@ -552,7 +552,7 @@ void TensorImpl::copy_tensor_metadata_except_version_counter(
dest_impl->storage_offset_ = src_impl->storage_offset_;
dest_impl->data_type_ = src_impl->data_type_;
dest_impl->device_opt_ = src_impl->device_opt_;
dest_impl->key_set_ = src_impl->key_set_.remove(DispatchKey::Python).remove(DispatchKey::PythonTLSSnapshot);
dest_impl->key_set_ = src_impl->key_set_.remove(DispatchKey::Python);
dest_impl->is_contiguous_ = src_impl->is_contiguous_;
dest_impl->has_contiguity_ = src_impl->has_contiguity_;
dest_impl->is_channels_last_contiguous_ =

View File

@@ -117,20 +117,6 @@ class C10_API ExcludeDispatchKeyGuard {
DispatchKeySet exclude_;
};
struct C10_API ForceDispatchKeyGuard {
public:
ForceDispatchKeyGuard(c10::impl::LocalDispatchKeySet key_set) :
saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {
c10::impl::_force_tls_local_dispatch_key_set(key_set);
}
~ForceDispatchKeyGuard() {
c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_);
}
private:
c10::impl::LocalDispatchKeySet saved_keyset_;
};
// Non-RAII API for manipulating the thread-local dispatch state.
// Please prefer the RAII API. The non-RAII API may be useful when
// the included/excluded state of a given DispatchKey must span

View File

@@ -551,16 +551,21 @@ $6 = torch._ops.aten.add_($1, $5)''')
self.assertFalse(out.requires_grad)
self.assertIsNone(out.grad_fn)
self.assertTrue(out.elem.requires_grad)
self.assertIsNotNone(out.elem.grad_fn)
# TODO: this should be True
self.assertFalse(out.elem.requires_grad)
# TODO: this should be not None
self.assertIsNone(out.elem.grad_fn)
with self.assertRaisesRegex(RuntimeError, "does not require grad"):
out.sum().backward()
out.backward()
out.elem.sum().backward()
# TODO: this should not raise
with self.assertRaisesRegex(RuntimeError, "does not require grad"):
out.elem.backward()
self.assertIsNone(t.grad)
self.assertIsNotNone(t.elem.grad)
# TODO: this should not be None
self.assertIsNone(t.elem.grad)
if __name__ == '__main__':

View File

@@ -27,10 +27,9 @@
#include <unordered_set>
struct DisableTorchDispatch {
DisableTorchDispatch() : guard_(c10::DispatchKey::Python),
guard_tls_snapshot_(c10::DispatchKey::PythonTLSSnapshot) {}
DisableTorchDispatch() : guard_(c10::DispatchKey::Python) {
}
c10::impl::ExcludeDispatchKeyGuard guard_;
c10::impl::ExcludeDispatchKeyGuard guard_tls_snapshot_;
};
PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {

View File

@@ -27,6 +27,7 @@ def no_dispatch() -> Iterator[None]:
# can require gradients if the user asks for it as a constructor kwarg.
# - The wrapped Tensor can require gradients. In that case autograd will be tracked
# for the wrapped Tensor and the LoggingTensor itself cannot require gradients.
# Note that this second one is not possible today as dispatcher exclude keys are not properly reset
# WARNING: We allow these two possibilities for testing purposes. You should NEVER use both in a single
# test or you might get surprising behavior.