mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert D34214953: Add new tls snapshot feature
Test Plan: revert-hammer Differential Revision: D34214953 (6199b5231f) Original commit changeset: 7aa5d5e3540a Original Phabricator Diff: D34214953 (6199b5231f) fbshipit-source-id: 5d271e9a5ab021b8202402630dbf917b43c55421
This commit is contained in:
parent
4f7338d4f4
commit
a12c630198
|
|
@ -4,14 +4,7 @@
|
|||
|
||||
namespace {
|
||||
|
||||
// TLS saving the state of the include/exclude sets on entry to the dispatcher
|
||||
// This is set in the pythonTLSSnapshot fallback and used by the Python fallback.
|
||||
thread_local c10::optional<c10::impl::LocalDispatchKeySet> tls_on_entry;
|
||||
|
||||
void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
|
||||
c10::impl::ForceDispatchKeyGuard guard(tls_on_entry.value());
|
||||
|
||||
// If Python Mode is active, use its PyInterpreter for dispatch
|
||||
const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
|
||||
if (maybe_python_mode_state) {
|
||||
|
|
@ -49,25 +42,8 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
|||
TORCH_INTERNAL_ASSERT(0, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)");
|
||||
}
|
||||
|
||||
void pythonTLSSnapshotFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
|
||||
// It is ok for the tls to be already set here.
|
||||
// A CompositeImplicitAutograd function may have been called just before this and so the tls here were never cleared
|
||||
// This is also why we don't need an RAII to ensure the tls is reset when exceptions happen
|
||||
|
||||
tls_on_entry = c10::impl::tls_local_dispatch_key_set();
|
||||
|
||||
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack);
|
||||
|
||||
tls_on_entry = c10::nullopt;
|
||||
}
|
||||
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
TORCH_LIBRARY_IMPL(_, Python, m) {
|
||||
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallback>());
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) {
|
||||
m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ void PythonModeTLS::set_state(const std::shared_ptr<TorchDispatchTypeObject>& st
|
|||
pythonModeState = state;
|
||||
if (state) {
|
||||
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
|
||||
c10::impl::tls_set_dispatch_key_included(DispatchKey::PythonTLSSnapshot, true);
|
||||
} else {
|
||||
PythonModeTLS::reset_state();
|
||||
}
|
||||
|
|
@ -21,7 +20,6 @@ const std::shared_ptr<TorchDispatchTypeObject>& PythonModeTLS::get_state() {
|
|||
void PythonModeTLS::reset_state() {
|
||||
pythonModeState.reset((TorchDispatchTypeObject*)nullptr);
|
||||
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
|
||||
c10::impl::tls_set_dispatch_key_included(DispatchKey::PythonTLSSnapshot, false);
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
|
|
|
|||
|
|
@ -100,8 +100,6 @@ const char* toString(DispatchKey t) {
|
|||
|
||||
case DispatchKey::Python:
|
||||
return "Python";
|
||||
case DispatchKey::PythonTLSSnapshot:
|
||||
return "PythonTLSSnapshot";
|
||||
|
||||
case DispatchKey::PrivateUse1:
|
||||
return "PrivateUse1";
|
||||
|
|
@ -253,7 +251,6 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
|
|||
{"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
|
||||
{"BackendSelect", c10::DispatchKey::BackendSelect},
|
||||
{"Python", c10::DispatchKey::Python},
|
||||
{"PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot},
|
||||
{"Named", c10::DispatchKey::Named},
|
||||
{"Conjugate", c10::DispatchKey::Conjugate},
|
||||
{"Negative", c10::DispatchKey::Negative},
|
||||
|
|
|
|||
|
|
@ -354,11 +354,6 @@ enum class DispatchKey : uint16_t {
|
|||
Functionalize,
|
||||
FuncTorchDynamicLayerFrontMode, // See Note [Out-of-tree vmap+grad prototype]
|
||||
|
||||
// Used by Python key logic to know the set of tls on entry to the dispatcher
|
||||
// This kernel assumes it is at the very top of the dispatcher. If you add
|
||||
// a key above, make sure to update the fallback implementation for this.
|
||||
PythonTLSSnapshot,
|
||||
|
||||
// TESTING: This is intended to be a generic testing tensor type id.
|
||||
// Don't use it for anything real; its only acceptable use is within a single
|
||||
// process test. Use it by creating a TensorImpl with this DispatchKey, and
|
||||
|
|
|
|||
|
|
@ -606,10 +606,7 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
|
|||
constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
|
||||
autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView);
|
||||
|
||||
constexpr DispatchKeySet python_ks = DispatchKeySet({
|
||||
DispatchKey::Python,
|
||||
DispatchKey::PythonTLSSnapshot,
|
||||
});
|
||||
constexpr DispatchKeySet python_ks = DispatchKeySet(DispatchKey::Python);
|
||||
|
||||
constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);
|
||||
|
||||
|
|
|
|||
|
|
@ -120,11 +120,11 @@ TensorImpl::TensorImpl(
|
|||
|
||||
// [Note: Python key removal]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// In most constructors for TensorImpl, you will see Python and PythonTLSSnapshot
|
||||
// keys are removed from the passed in DispatchKeySet. Why?
|
||||
// In most constructors for TensorImpl, you will see Python key is removed from
|
||||
// the passed in DispatchKeySet. Why?
|
||||
//
|
||||
// INVARIANT: Python and PythonTLSSnapshot dispatch keys are set iff PyObject for
|
||||
// the Tensor has a nontrivial __torch_dispatch__ implementation.
|
||||
// INVARIANT: Python dispatch key is set iff PyObject for the Tensor has a
|
||||
// nontrivial __torch_dispatch__ implementation.
|
||||
//
|
||||
// When a fresh TensorImpl is created, there is *no* PyObject (this only gets
|
||||
// initialized lazily at the first point in time the Tensor passes into Python).
|
||||
|
|
@ -132,8 +132,8 @@ TensorImpl::TensorImpl(
|
|||
//
|
||||
// In practice, what will happen shortly afterwards is that the TensorImpl
|
||||
// will get its PyObject initialized by Tensor._make_subclass; at this point
|
||||
// the Python and PythonTLSSnapshot dispatch keys will be set and all is well.
|
||||
// The point is to delay the dispatch key setting until that point.
|
||||
// the Python dispatch key will be set and all is well. The point is to delay
|
||||
// the dispatch key setting until that point.
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
TensorImpl::TensorImpl(
|
||||
|
|
@ -552,7 +552,7 @@ void TensorImpl::copy_tensor_metadata_except_version_counter(
|
|||
dest_impl->storage_offset_ = src_impl->storage_offset_;
|
||||
dest_impl->data_type_ = src_impl->data_type_;
|
||||
dest_impl->device_opt_ = src_impl->device_opt_;
|
||||
dest_impl->key_set_ = src_impl->key_set_.remove(DispatchKey::Python).remove(DispatchKey::PythonTLSSnapshot);
|
||||
dest_impl->key_set_ = src_impl->key_set_.remove(DispatchKey::Python);
|
||||
dest_impl->is_contiguous_ = src_impl->is_contiguous_;
|
||||
dest_impl->has_contiguity_ = src_impl->has_contiguity_;
|
||||
dest_impl->is_channels_last_contiguous_ =
|
||||
|
|
|
|||
|
|
@ -117,20 +117,6 @@ class C10_API ExcludeDispatchKeyGuard {
|
|||
DispatchKeySet exclude_;
|
||||
};
|
||||
|
||||
struct C10_API ForceDispatchKeyGuard {
|
||||
public:
|
||||
ForceDispatchKeyGuard(c10::impl::LocalDispatchKeySet key_set) :
|
||||
saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {
|
||||
c10::impl::_force_tls_local_dispatch_key_set(key_set);
|
||||
}
|
||||
~ForceDispatchKeyGuard() {
|
||||
c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_);
|
||||
}
|
||||
|
||||
private:
|
||||
c10::impl::LocalDispatchKeySet saved_keyset_;
|
||||
};
|
||||
|
||||
// Non-RAII API for manipulating the thread-local dispatch state.
|
||||
// Please prefer the RAII API. The non-RAII API may be useful when
|
||||
// the included/excluded state of a given DispatchKey must span
|
||||
|
|
|
|||
|
|
@ -551,16 +551,21 @@ $6 = torch._ops.aten.add_($1, $5)''')
|
|||
self.assertFalse(out.requires_grad)
|
||||
self.assertIsNone(out.grad_fn)
|
||||
|
||||
self.assertTrue(out.elem.requires_grad)
|
||||
self.assertIsNotNone(out.elem.grad_fn)
|
||||
# TODO: this should be True
|
||||
self.assertFalse(out.elem.requires_grad)
|
||||
# TODO: this should be not None
|
||||
self.assertIsNone(out.elem.grad_fn)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "does not require grad"):
|
||||
out.sum().backward()
|
||||
out.backward()
|
||||
|
||||
out.elem.sum().backward()
|
||||
# TODO: this should not raise
|
||||
with self.assertRaisesRegex(RuntimeError, "does not require grad"):
|
||||
out.elem.backward()
|
||||
|
||||
self.assertIsNone(t.grad)
|
||||
self.assertIsNotNone(t.elem.grad)
|
||||
# TODO: this should not be None
|
||||
self.assertIsNone(t.elem.grad)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -27,10 +27,9 @@
|
|||
#include <unordered_set>
|
||||
|
||||
struct DisableTorchDispatch {
|
||||
DisableTorchDispatch() : guard_(c10::DispatchKey::Python),
|
||||
guard_tls_snapshot_(c10::DispatchKey::PythonTLSSnapshot) {}
|
||||
DisableTorchDispatch() : guard_(c10::DispatchKey::Python) {
|
||||
}
|
||||
c10::impl::ExcludeDispatchKeyGuard guard_;
|
||||
c10::impl::ExcludeDispatchKeyGuard guard_tls_snapshot_;
|
||||
};
|
||||
|
||||
PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ def no_dispatch() -> Iterator[None]:
|
|||
# can require gradients if the user asks for it as a constructor kwarg.
|
||||
# - The wrapped Tensor can require gradients. In that case autograd will be tracked
|
||||
# for the wrapped Tensor and the LoggingTensor itself cannot require gradients.
|
||||
# Note that this second one is not possible today as dispatcher exclude keys are not properly reset
|
||||
# WARNING: We allow these two possibilities for testing purposes. You should NEVER use both in a single
|
||||
# test or you might get surprising behavior.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user