pytorch/c10/core/impl/LocalDispatchKeySet.h
Edward Yang dd64e738c5 Expunge TensorId from all DispatchKey names. (#36240)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36240

It's annoying, historical, and unnecessary (enum class is already
namespaced).  I did this codemod with:

```
git grep -l 'CPUTensorId' | xargs sed -i 's/CPUTensorId/CPU/g'
git grep -l 'CUDATensorId' | xargs sed -i 's/CUDATensorId/CUDA/g'
git grep -l 'VariableTensorId' | xargs sed -i 's/VariableTensorId/Autograd/g'
git grep -l 'HIPTensorId' | xargs sed -i 's/HIPTensorId/HIP/g'
git grep -l 'MSNPUTensorId' | xargs sed -i 's/MSNPUTensorId/MSNPU/g'
git grep -l 'XLATensorId' | xargs sed -i 's/XLATensorId/XLA/g'
git grep -l 'PrivateUse1_TensorId' | xargs sed -i 's/PrivateUse1_TensorId/PrivateUse1/g'
git grep -l 'PrivateUse2_TensorId' | xargs sed -i 's/PrivateUse2_TensorId/PrivateUse2/g'
git grep -l 'PrivateUse3_TensorId' | xargs sed -i 's/PrivateUse3_TensorId/PrivateUse3/g'
git grep -l 'AutocastTensorId' | xargs sed -i 's/AutocastTensorId/Autocast/g'
git grep -l '_PreAutogradTensorId' | xargs sed -i 's/_PreAutogradTensorId/_PreAutograd/g'
git grep -l 'TESTING_ONLY_GenericWrapperTensorId' | xargs sed -i 's/TESTING_ONLY_GenericWrapperTensorId/TESTING_ONLY_GenericWrapper/g'
git grep -l 'TESTING_ONLY_GenericModeTensorId' | xargs sed -i 's/TESTING_ONLY_GenericModeTensorId/TESTING_ONLY_GenericMode/g'
```

Then I did a git grep for remaining TensorId occurrences, and manually
killed those (mostly in codegen, and some docs that needed updating).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D20929255

Pulled By: ezyang

fbshipit-source-id: dc371b6aa6e6ea7c0a5660137c14debde806a09d
2020-04-13 23:33:44 -07:00

115 lines
4.1 KiB
C++

#pragma once
#include <c10/core/DispatchKeySet.h>
#include <c10/util/Flags.h>

#include <cstdint>
#include <type_traits>
// TLS management for DispatchKeySet (the "local" DispatchKeySet(s))
//
// This manages two thread-local DispatchKeySets:
//
// - The included type set, which adds a tensor type for consideration
// in dispatch. (For example, you might add Profiling to
// the included type set to turn on profiling on all tensor operations.)
//
// - The excluded type set, which disqualifies a tensor type from dispatch.
// (For example, after redispatching on variable, we disqualify
// Autograd so we don't attempt to handle variable again.)
// (Exclusion wins over inclusion.)
//
// NB: Originally, I implemented the excluded type set as storing the inverted
// set, but TLS is defined to be zero-initialized, so this doesn't actually work
// (if it's inverted, you want the set to be -1 initialized).
namespace c10 {
namespace impl {
// Declares the boolean flag `disable_variable_dispatch`; only the declaration
// lives in this header, the matching definition is elsewhere.
// NOTE(review): judging by the name, this flag presumably turns off dispatch
// to the variable (Autograd) key -- confirm at the definition site.
C10_DECLARE_bool(disable_variable_dispatch);
// POD version of LocalDispatchKeySet: both sets are stored as their raw
// 64-bit representations rather than as DispatchKeySet objects. Declared here
// just so that the guards can hold a pointer to it.
struct C10_API PODLocalDispatchKeySet {
  // Raw representation of the included set, as produced by
  // DispatchKeySet::raw_repr().
  uint64_t included_;
  // Raw representation of the excluded set, as produced by
  // DispatchKeySet::raw_repr().
  uint64_t excluded_;

  // Rebuilds the included DispatchKeySet from its raw representation.
  DispatchKeySet included() const {
    return DispatchKeySet(DispatchKeySet::RAW, included_);
  }

  // Rebuilds the excluded DispatchKeySet from its raw representation.
  DispatchKeySet excluded() const {
    return DispatchKeySet(DispatchKeySet::RAW, excluded_);
  }

  // Stores the raw representation of x as the included set.
  void set_included(DispatchKeySet x) {
    included_ = x.raw_repr();
  }

  // Stores the raw representation of x as the excluded set.
  void set_excluded(DispatchKeySet x) {
    excluded_ = x.raw_repr();
  }
};

// A POD type is exactly a type that is both trivial and standard-layout.
// Spell the requirement out that way because std::is_pod is deprecated as of
// C++20 and triggers deprecation warnings on newer toolchains.
static_assert(
    std::is_trivial<PODLocalDispatchKeySet>::value &&
        std::is_standard_layout<PODLocalDispatchKeySet>::value,
    "PODLocalDispatchKeySet must be a POD type.");
struct C10_API LocalDispatchKeySet {
/* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x)
: included_(x.included()), excluded_(x.excluded()) {}
DispatchKeySet included_;
DispatchKeySet excluded_;
};
// Returns the local (thread-local, per the comment at the top of this file)
// dispatch key sets as a LocalDispatchKeySet value. Defined out of line.
C10_API LocalDispatchKeySet tls_local_dispatch_key_set();
// Internal, use ThreadLocalStateGuard
// NOTE(review): presumably overwrites both thread-local sets with the ones in
// key_set -- the definition is not in this header, confirm there.
C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set);
// RAII API for manipulating the thread-local dispatch state.
//
// Guard for the *included* set, constructed from the DispatchKey it manages.
// The constructor and destructor are defined out of line.
// NOTE(review): presumably the constructor adds the key to the included set
// and the destructor puts back the state remembered in prev_state_ -- confirm
// against the implementation file.
class C10_API IncludeDispatchKeyGuard {
 public:
  IncludeDispatchKeyGuard(DispatchKey);

  // Neither copyable nor movable. The assignment operators are declared with
  // the canonical reference return type even though they are deleted.
  IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete;
  IncludeDispatchKeyGuard& operator=(const IncludeDispatchKeyGuard&) = delete;
  IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete;
  IncludeDispatchKeyGuard& operator=(IncludeDispatchKeyGuard&&) = delete;

  ~IncludeDispatchKeyGuard();

 private:
  // A little micro-optimization to save us from tls_get_addr call
  // on destruction
  PODLocalDispatchKeySet* tls_;
  // The dispatch key this guard was constructed for.
  DispatchKey id_;
  // NOTE(review): presumably whether id_ was already included when the guard
  // was created -- confirm against the implementation file.
  bool prev_state_;
};
// Guard for the *excluded* set, constructed from the DispatchKey it manages.
// The constructor and destructor are defined out of line.
// NOTE(review): presumably the constructor adds the key to the excluded set
// and the destructor puts back the state remembered in prev_state_ -- confirm
// against the implementation file.
class C10_API ExcludeDispatchKeyGuard {
 public:
  ExcludeDispatchKeyGuard(DispatchKey);

  // Neither copyable nor movable. The assignment operators are declared with
  // the canonical reference return type even though they are deleted.
  ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete;
  ExcludeDispatchKeyGuard& operator=(const ExcludeDispatchKeyGuard&) = delete;
  ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete;
  ExcludeDispatchKeyGuard& operator=(ExcludeDispatchKeyGuard&&) = delete;

  ~ExcludeDispatchKeyGuard();

 private:
  // A little micro-optimization to save us from tls_get_addr call
  // on destruction
  PODLocalDispatchKeySet* tls_;
  // The dispatch key this guard was constructed for.
  DispatchKey id_;
  // NOTE(review): presumably whether id_ was already excluded when the guard
  // was created -- confirm against the implementation file.
  bool prev_state_;
};
// Non-RAII API for manipulating the thread-local dispatch state.
// Please prefer the RAII API. The non-RAII API may be useful when
// the included/excluded state of a given DispatchKey must span
// many calls from Python into C++, so you cannot conveniently
// use an RAII guard.
//
// Example use case: a Python context manager that includes a certain
// DispatchKey, to ensure ops running under the context manager dispatch
// through that DispatchKey's registered overrides.
//
// The non-RAII API is less efficient than the RAII guards because both the
// getter and setter will do a tls_get_addr lookup (the RAII struct only needs one!)

// Getter/setter pair for the excluded state of key x; desired_state is the
// new excluded state to apply.
C10_API bool tls_is_dispatch_key_excluded(DispatchKey x);
C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state);
// Getter/setter pair for the included state of key x; desired_state is the
// new included state to apply.
C10_API bool tls_is_dispatch_key_included(DispatchKey x);
C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state);
}} // namespace c10::impl