pytorch/c10/core/impl/LocalDispatchKeySet.cpp
Edward Yang dd64e738c5 Expunge TensorId from all DispatchKey names. (#36240)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36240

It's annoying, historical, and unnecessary (enum class is already
namespaced).  I did this codemod with:

```
git grep -l 'CPUTensorId' | xargs sed -i 's/CPUTensorId/CPU/g'
git grep -l 'CUDATensorId' | xargs sed -i 's/CUDATensorId/CUDA/g'
git grep -l 'VariableTensorId' | xargs sed -i 's/VariableTensorId/Autograd/g'
git grep -l 'HIPTensorId' | xargs sed -i 's/HIPTensorId/HIP/g'
git grep -l 'MSNPUTensorId' | xargs sed -i 's/MSNPUTensorId/MSNPU/g'
git grep -l 'XLATensorId' | xargs sed -i 's/XLATensorId/XLA/g'
git grep -l 'PrivateUse1_TensorId' | xargs sed -i 's/PrivateUse1_TensorId/PrivateUse1/g'
git grep -l 'PrivateUse2_TensorId' | xargs sed -i 's/PrivateUse2_TensorId/PrivateUse2/g'
git grep -l 'PrivateUse3_TensorId' | xargs sed -i 's/PrivateUse3_TensorId/PrivateUse3/g'
git grep -l 'AutocastTensorId' | xargs sed -i 's/AutocastTensorId/Autocast/g'
git grep -l '_PreAutogradTensorId' | xargs sed -i 's/_PreAutogradTensorId/_PreAutograd/g'
git grep -l 'TESTING_ONLY_GenericWrapperTensorId' | xargs sed -i 's/TESTING_ONLY_GenericWrapperTensorId/TESTING_ONLY_GenericWrapper/g'
git grep -l 'TESTING_ONLY_GenericModeTensorId' | xargs sed -i 's/TESTING_ONLY_GenericModeTensorId/TESTING_ONLY_GenericMode/g'
```

Then I did a git grep for remaining TensorId occurrences, and manually
killed those (mostly in codegen, and some docs that needed updating).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D20929255

Pulled By: ezyang

fbshipit-source-id: dc371b6aa6e6ea7c0a5660137c14debde806a09d
2020-04-13 23:33:44 -07:00

133 lines
4.2 KiB
C++

#include <c10/core/impl/LocalDispatchKeySet.h>
#include <iostream>
namespace c10 {
namespace impl {
// Defines the boolean flag `disable_variable_dispatch` (default: false).
// It is read in this file as FLAGS_disable_variable_dispatch by
// tls_local_dispatch_key_set(); when true, that function adds
// DispatchKey::Autograd to the excluded set before returning it.
C10_DEFINE_bool(disable_variable_dispatch, false, "This flag forcibly disables the Variable code paths from executing, which currently breaks profiling in the process.");
namespace {

/// Backing storage for the included/excluded dispatch key state that every
/// function in this file reads and writes.
#if defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY)
/// thread_local is not supported in the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY
/// build setting, so a plain static instance is used there instead.
static PODLocalDispatchKeySet raw_local_dispatch_key_set;
#else // !defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY)
// NB: POD, zero initialized!
thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
#endif

} // anonymous namespace
/// Returns the current local dispatch key state as a LocalDispatchKeySet.
///
/// Hack until variable performance is fixed: when
/// FLAGS_disable_variable_dispatch is set, DispatchKey::Autograd is added to
/// the excluded set before the state is returned.
///
/// NOTE: that addition is written back into raw_local_dispatch_key_set, i.e.
/// this "getter" mutates the stored state as a side effect. The original
/// author (ezyang) flagged this as looking wrong but could not conveniently
/// test a non-mutating version, so the behavior is intentionally kept as-is.
LocalDispatchKeySet tls_local_dispatch_key_set() {
  auto& state = raw_local_dispatch_key_set;
  if (FLAGS_disable_variable_dispatch) {
    auto excluded_keys = state.excluded();
    state.set_excluded(excluded_keys.add(DispatchKey::Autograd));
  }
  return state;
}
/// Overwrites the stored local dispatch key state with `key_set`.
///
/// Both the included and excluded sets are replaced wholesale by the raw
/// representations of the corresponding members of `key_set`.
void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set) {
  const auto included_bits = key_set.included_.raw_repr();
  const auto excluded_bits = key_set.excluded_.raw_repr();
  raw_local_dispatch_key_set =
      PODLocalDispatchKeySet{included_bits, excluded_bits};
}
// An RAII guard could snapshot and restore the entire state (entire DispatchKeySet) as
// opposed to only snapshotting and restoring the state of its assigned DispatchKey.
// I'm not sure which is better. If only the RAII API is used, the two choices are
// not distinguishable.
//
// However, if the guard chooses to snapshot and restore the entire DispatchKeySet,
// the interaction with the non-RAII API changes. Consider this sequence of events:
// - An RAII guard is declared for a particular DispatchKey, but snapshots the entire
// current DispatchKeySet.
// - A call to the non-RAII API changes the state for a different DispatchKey.
// - The RAII guard goes out of scope, restoring the entire DispatchKeySet it snapshotted
// (which restores the state for its own assigned DispatchKey and wipes out the state
// for the other DispatchKey set by the non-RAII API).
// RAII API
/// Adds `x` to the included set for the lifetime of the guard.
///
/// The guard remembers whether `x` was already included. If it was, or if
/// `x` is DispatchKey::Undefined, prev_state_ is true and neither the
/// constructor nor the destructor touches the stored state.
IncludeDispatchKeyGuard::IncludeDispatchKeyGuard(DispatchKey x)
    : tls_(&raw_local_dispatch_key_set),
      id_(x),
      // NB: Undefined short-circuits to true, which turns the guard into a
      // no-op (and avoids querying the set with Undefined).
      prev_state_(x == DispatchKey::Undefined || tls_->included().has(x)) {
  if (prev_state_) {
    return;
  }
  auto included_keys = tls_->included();
  tls_->set_included(included_keys.add(x));
}
/// Undoes the constructor's change: removes id_ from the included set, but
/// only when the constructor was the one that added it (prev_state_ false).
IncludeDispatchKeyGuard::~IncludeDispatchKeyGuard() {
  if (prev_state_) {
    // Key was already included (or is Undefined): nothing to restore.
    return;
  }
  auto included_keys = tls_->included();
  tls_->set_included(included_keys.remove(id_));
}
/// Adds `x` to the excluded set for the lifetime of the guard.
///
/// The guard remembers whether `x` was already excluded. If it was, or if
/// `x` is DispatchKey::Undefined, prev_state_ is true and neither the
/// constructor nor the destructor touches the stored state.
ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(DispatchKey x)
    : tls_(&raw_local_dispatch_key_set),
      id_(x),
      // NB: Undefined short-circuits to true, which turns the guard into a
      // no-op (and avoids querying the set with Undefined).
      prev_state_(x == DispatchKey::Undefined || tls_->excluded().has(x)) {
  if (prev_state_) {
    return;
  }
  auto excluded_keys = tls_->excluded();
  tls_->set_excluded(excluded_keys.add(x));
}
/// Undoes the constructor's change: removes id_ from the excluded set, but
/// only when the constructor was the one that added it (prev_state_ false).
ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() {
  if (prev_state_) {
    // Key was already excluded (or is Undefined): nothing to restore.
    return;
  }
  auto excluded_keys = tls_->excluded();
  tls_->set_excluded(excluded_keys.remove(id_));
}
// Non-RAII API
// Please prefer using the RAII API. See declarations in LocalDispatchKeySet.h for details.
/// Reports whether `x` is currently in the stored excluded set.
bool tls_is_dispatch_key_excluded(DispatchKey x) {
  const auto excluded_keys = raw_local_dispatch_key_set.excluded();
  return excluded_keys.has(x);
}
/// Makes membership of `x` in the stored excluded set equal `desired_state`.
///
/// If `x` is already in the requested state, the stored set is left
/// untouched; otherwise `x` is added (true) or removed (false).
void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state) {
  auto& state = raw_local_dispatch_key_set;
  const bool already_excluded = state.excluded().has(x);
  if (already_excluded == desired_state) {
    return;
  }
  if (desired_state) {
    state.set_excluded(state.excluded().add(x));
    return;
  }
  state.set_excluded(state.excluded().remove(x));
}
/// Reports whether `x` is currently in the stored included set.
bool tls_is_dispatch_key_included(DispatchKey x) {
  const auto included_keys = raw_local_dispatch_key_set.included();
  return included_keys.has(x);
}
/// Makes membership of `x` in the stored included set equal `desired_state`.
///
/// If `x` is already in the requested state, the stored set is left
/// untouched; otherwise `x` is added (true) or removed (false).
void tls_set_dispatch_key_included(DispatchKey x, bool desired_state) {
  auto& state = raw_local_dispatch_key_set;
  const bool already_included = state.included().has(x);
  if (already_included == desired_state) {
    return;
  }
  if (desired_state) {
    state.set_included(state.included().add(x));
    return;
  }
  state.set_included(state.included().remove(x));
}
}} // namespace c10::impl