mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36240 It's annoying, historical, and unnecessary (enum class is already namespaced). I did this codemod with: ``` git grep -l 'CPUTensorId' | xargs sed -i 's/CPUTensorId/CPU/g' git grep -l 'CUDATensorId' | xargs sed -i 's/CUDATensorId/CUDA/g' git grep -l 'VariableTensorId' | xargs sed -i 's/VariableTensorId/Autograd/g' git grep -l 'HIPTensorId' | xargs sed -i 's/HIPTensorId/HIP/g' git grep -l 'MSNPUTensorId' | xargs sed -i 's/MSNPUTensorId/MSNPU/g' git grep -l 'XLATensorId' | xargs sed -i 's/XLATensorId/XLA/g' git grep -l 'PrivateUse1_TensorId' | xargs sed -i 's/PrivateUse1_TensorId/PrivateUse1/g' git grep -l 'PrivateUse2_TensorId' | xargs sed -i 's/PrivateUse2_TensorId/PrivateUse2/g' git grep -l 'PrivateUse3_TensorId' | xargs sed -i 's/PrivateUse3_TensorId/PrivateUse3/g' git grep -l 'AutocastTensorId' | xargs sed -i 's/AutocastTensorId/Autocast/g' git grep -l '_PreAutogradTensorId' | xargs sed -i 's/_PreAutogradTensorId/_PreAutograd/g' git grep -l 'TESTING_ONLY_GenericWrapperTensorId' | xargs sed -i 's/TESTING_ONLY_GenericWrapperTensorId/TESTING_ONLY_GenericWrapper/g' git grep -l 'TESTING_ONLY_GenericModeTensorId' | xargs sed -i 's/TESTING_ONLY_GenericModeTensorId/TESTING_ONLY_GenericMode/g' ``` Then I did a git grep for remaining TensorId occurrences, and manually killed those (mostly in codegen, and some docs that needed updating). Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Differential Revision: D20929255 Pulled By: ezyang fbshipit-source-id: dc371b6aa6e6ea7c0a5660137c14debde806a09d
115 lines
4.1 KiB
C++
115 lines
4.1 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/DispatchKeySet.h>
|
|
#include <c10/util/Flags.h>
|
|
|
|
// TLS management for DispatchKeySet (the "local" DispatchKeySet(s))
|
|
//
|
|
// This manages two thread-local DispatchKeySets:
|
|
//
|
|
// - The included type set, which adds a tensor type for consideration
|
|
// in dispatch. (For example, you might add Profiling to
|
|
// the included type set to turn on profiling on all tensor operations.)
|
|
//
|
|
// - The excluded type set, which disqualifies a tensor type from dispatch.
|
|
// (For example, after redispatching on variable, we disqualify
|
|
// Autograd so we don't attempt to handle variable again.)
|
|
// (Exclusion wins over inclusion.)
|
|
//
|
|
// NB: Originally, I implemented the excluded type set as storing the inverted
|
|
// set, but TLS is defined to be zero-initialized, so this doesn't actually work
|
|
// (if it's inverted, you want the set to be -1 initialized).
|
|
|
|
namespace c10 {
|
|
namespace impl {
|
|
|
|
C10_DECLARE_bool(disable_variable_dispatch);
|
|
|
|
// POD version of LocalDispatchKeySet. Declared here just so that
|
|
// we can put it in the guards.
|
|
struct C10_API PODLocalDispatchKeySet {
|
|
uint64_t included_;
|
|
uint64_t excluded_;
|
|
|
|
DispatchKeySet included() const {
|
|
return DispatchKeySet(DispatchKeySet::RAW, included_);
|
|
}
|
|
DispatchKeySet excluded() const {
|
|
return DispatchKeySet(DispatchKeySet::RAW, excluded_);
|
|
}
|
|
|
|
void set_included(DispatchKeySet x) {
|
|
included_ = x.raw_repr();
|
|
}
|
|
void set_excluded(DispatchKeySet x) {
|
|
excluded_ = x.raw_repr();
|
|
}
|
|
};
|
|
static_assert(std::is_pod<PODLocalDispatchKeySet>::value, "PODLocalDispatchKeySet must be a POD type.");
|
|
|
|
struct C10_API LocalDispatchKeySet {
|
|
/* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x)
|
|
: included_(x.included()), excluded_(x.excluded()) {}
|
|
DispatchKeySet included_;
|
|
DispatchKeySet excluded_;
|
|
};
|
|
|
|
C10_API LocalDispatchKeySet tls_local_dispatch_key_set();
|
|
|
|
// Internal, use ThreadLocalStateGuard
|
|
C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set);
|
|
|
|
// RAII API for manipulating the thread-local dispatch state.
|
|
|
|
class C10_API IncludeDispatchKeyGuard {
|
|
public:
|
|
IncludeDispatchKeyGuard(DispatchKey);
|
|
IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete;
|
|
IncludeDispatchKeyGuard operator=(const IncludeDispatchKeyGuard&) = delete;
|
|
IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete;
|
|
IncludeDispatchKeyGuard operator=(IncludeDispatchKeyGuard&&) = delete;
|
|
~IncludeDispatchKeyGuard();
|
|
private:
|
|
// A little micro-optimization to save us from tls_get_addr call
|
|
// on destruction
|
|
PODLocalDispatchKeySet* tls_;
|
|
DispatchKey id_;
|
|
bool prev_state_;
|
|
};
|
|
|
|
class C10_API ExcludeDispatchKeyGuard {
|
|
public:
|
|
ExcludeDispatchKeyGuard(DispatchKey);
|
|
ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete;
|
|
ExcludeDispatchKeyGuard operator=(const ExcludeDispatchKeyGuard&) = delete;
|
|
ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete;
|
|
ExcludeDispatchKeyGuard operator=(ExcludeDispatchKeyGuard&&) = delete;
|
|
~ExcludeDispatchKeyGuard();
|
|
private:
|
|
// A little micro-optimization to save us from tls_get_addr call
|
|
// on destruction
|
|
PODLocalDispatchKeySet* tls_;
|
|
DispatchKey id_;
|
|
bool prev_state_;
|
|
};
|
|
|
|
// Non-RAII API for manipulating the thread-local dispatch state.
|
|
// Please prefer the RAII API. The non-RAII API may be useful when
|
|
// the included/excluded state of a given DispatchKey must span
|
|
// many calls from the Python to the C++, so you cannot conveniently
|
|
// use an RAII guard.
|
|
//
|
|
// Example use case: a Python context manager that includes a certain
|
|
// DispatchKey, to ensure ops running under the context manager dispatch
|
|
// through that DispatchKey's registered overrides.
|
|
//
|
|
// The non-RAII API is less efficient than the RAII guards because both the
|
|
// getter and setter will do a tls_getaddr lookup (the RAII struct only needs one!)
|
|
|
|
C10_API bool tls_is_dispatch_key_excluded(DispatchKey x);
|
|
C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state);
|
|
C10_API bool tls_is_dispatch_key_included(DispatchKey x);
|
|
C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state);
|
|
|
|
}} // namespace c10::impl
|