mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49412 FLAGS_disable_variable_dispatch had to go, but it looks like the only user was some benchmarks anyway. ghstack-source-id: 118669590 Test Plan: Small (order of 0.1% improvement) on Internal benchmarks. Wait for GitHub CI since this was reverted before due to CI break Reviewed By: ezyang Differential Revision: D25547962 fbshipit-source-id: 58424b1da230fdc5d27349af762126a5512fce43
130 lines
4.8 KiB
C++
130 lines
4.8 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/DispatchKeySet.h>
|
|
#include <c10/util/Flags.h>
|
|
|
|
// TLS management for DispatchKeySet (the "local" DispatchKeySet(s))
|
|
//
|
|
// This manages two thread-local DispatchKeySets:
|
|
//
|
|
// - The included type set, which adds a tensor type for consideration
|
|
// in dispatch. (For example, you might add Profiling to
|
|
// the included type set to turn on profiling on all tensor operations.)
|
|
//
|
|
// - The excluded type set, which disqualifies a tensor type from dispatch.
|
|
// (For example, after redispatching on variable, we disqualify
|
|
// Autograd so we don't attempt to handle variable again.)
|
|
// (Exclusion wins over inclusion.)
|
|
//
|
|
// NB: Originally, I implemented the excluded type set as storing the inverted
|
|
// set, but TLS is defined to be zero-initialized, so this doesn't actually work
|
|
// (if it's inverted, you want the set to be -1 initialized).
|
|
|
|
namespace c10 {
|
|
namespace impl {
|
|
|
|
// POD version of LocalDispatchKeySet. Declared here just so that
|
|
// we can put it in the guards.
|
|
struct C10_API PODLocalDispatchKeySet {
|
|
uint64_t included_;
|
|
uint64_t excluded_;
|
|
|
|
DispatchKeySet included() const {
|
|
return DispatchKeySet(DispatchKeySet::RAW, included_);
|
|
}
|
|
DispatchKeySet excluded() const {
|
|
return DispatchKeySet(DispatchKeySet::RAW, excluded_);
|
|
}
|
|
|
|
void set_included(DispatchKeySet x) {
|
|
included_ = x.raw_repr();
|
|
}
|
|
void set_excluded(DispatchKeySet x) {
|
|
excluded_ = x.raw_repr();
|
|
}
|
|
};
|
|
// Compile-time check that PODLocalDispatchKeySet is a POD type.
// std::is_pod is deprecated as of C++20; a type is POD exactly when it is both
// trivial and standard-layout, so those two traits are checked instead.
static_assert(
    std::is_trivial<PODLocalDispatchKeySet>::value &&
        std::is_standard_layout<PODLocalDispatchKeySet>::value,
    "PODLocalDispatchKeySet must be a POD type.");
|
|
|
|
struct C10_API LocalDispatchKeySet {
|
|
/* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x)
|
|
: included_(x.included()), excluded_(x.excluded()) {}
|
|
DispatchKeySet included_;
|
|
DispatchKeySet excluded_;
|
|
};
|
|
|
|
// thread_local variables cannot be C10_API on Windows.
|
|
#ifdef _MSC_VER
|
|
C10_API LocalDispatchKeySet tls_local_dispatch_key_set();
|
|
#else // _MSC_VER
|
|
/// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting,
|
|
/// thread_local is not supported.
|
|
#ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY
|
|
extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
|
|
#else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY)
|
|
extern C10_API PODLocalDispatchKeySet raw_local_dispatch_key_set;
|
|
#endif
|
|
|
|
inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() {
|
|
// Don't let people fiddle with the thread_local directly just
|
|
// because they include this header.
|
|
return raw_local_dispatch_key_set;
|
|
}
|
|
#endif // _MSC_VER
|
|
|
|
// Internal, use ThreadLocalStateGuard
|
|
// Takes the desired LocalDispatchKeySet by value. Only the declaration is
// visible in this header; presumably the definition overwrites the local
// included/excluded sets with key_set -- TODO confirm against the .cpp.
C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set);
|
|
|
|
// RAII API for manipulating the thread-local dispatch state.
|
|
|
|
class C10_API IncludeDispatchKeyGuard {
|
|
public:
|
|
IncludeDispatchKeyGuard(DispatchKeySet);
|
|
IncludeDispatchKeyGuard(DispatchKey k) : IncludeDispatchKeyGuard(DispatchKeySet(k)) {}
|
|
IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete;
|
|
IncludeDispatchKeyGuard operator=(const IncludeDispatchKeyGuard&) = delete;
|
|
IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete;
|
|
IncludeDispatchKeyGuard operator=(IncludeDispatchKeyGuard&&) = delete;
|
|
~IncludeDispatchKeyGuard();
|
|
private:
|
|
// A little micro-optimization to save us from tls_get_addr call
|
|
// on destruction
|
|
PODLocalDispatchKeySet* tls_;
|
|
DispatchKeySet include_;
|
|
};
|
|
|
|
class C10_API ExcludeDispatchKeyGuard {
|
|
public:
|
|
ExcludeDispatchKeyGuard(DispatchKeySet);
|
|
ExcludeDispatchKeyGuard(DispatchKey k) : ExcludeDispatchKeyGuard(DispatchKeySet(k)) {}
|
|
ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete;
|
|
ExcludeDispatchKeyGuard operator=(const ExcludeDispatchKeyGuard&) = delete;
|
|
ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete;
|
|
ExcludeDispatchKeyGuard operator=(ExcludeDispatchKeyGuard&&) = delete;
|
|
~ExcludeDispatchKeyGuard();
|
|
private:
|
|
// A little micro-optimization to save us from tls_get_addr call
|
|
// on destruction
|
|
PODLocalDispatchKeySet* tls_;
|
|
DispatchKeySet exclude_;
|
|
};
|
|
|
|
// Non-RAII API for manipulating the thread-local dispatch state.
|
|
// Please prefer the RAII API. The non-RAII API may be useful when
// the included/excluded state of a given DispatchKey must span
// many calls from Python into C++, so you cannot conveniently
// use an RAII guard.
|
|
//
|
|
// Example use case: a Python context manager that includes a certain
|
|
// DispatchKey, to ensure ops running under the context manager dispatch
|
|
// through that DispatchKey's registered overrides.
|
|
//
|
|
// The non-RAII API is less efficient than the RAII guards because both the
// getter and setter will do a tls_get_addr lookup (the RAII struct only needs one!).
|
|
|
|
// Queries: report the state of DispatchKey x in the local included / excluded
// set. Only declarations are visible in this header.
C10_API bool tls_is_dispatch_key_included(DispatchKey x);
C10_API bool tls_is_dispatch_key_excluded(DispatchKey x);

// Mutators: set the state of DispatchKey x in the local included / excluded
// set to desired_state (presumably true adds the key and false removes it --
// confirm against the definitions).
C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state);
C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state);
|
|
|
|
}} // namespace c10::impl
|