free up dispatch key space (in C++) (#72402)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72402

The original PR had an array out-of-bounds access in `DispatchKeyExtractor.cpp` that wasn't caught by ASAN and appeared to manifest only in a subset of Android internal tests. After fixing the OOB access (and adding more asserts), I confirmed that the Android internal test passes.

Reland of D33255193 (20b8653dfa)
ghstack-source-id: 148830728

Test Plan:
Steps to test:

(1) connect to a mobile OD

(2) run `one_world android emulator android-29` in a terminal to start the android emulator

(3) In a separate terminal, run the test: `buck test //fbandroid/instrumentation_tests/com/facebook/pytorch/bi_xray:instrumentation_test -c test.external_runner=tpx -- --regex 'testBIXRayModel.*PyTorchBIXRayInstrumentationTest' --force-remote-execution --run-disabled`

I also ran `buck test fbandroid/mode/dbg //fbandroid/instrumentation_tests/com/facebook/pytorch/bi_xray:instrumentation_test`, which failed before and passed after the PR.

Reviewed By: albanD

Differential Revision: D34034848

fbshipit-source-id: 9677ee2c0a1afd1183896f7055009445712523c5
Author: Brian Hirsh
Date: 2022-02-14 07:53:38 -08:00
Committed by: Facebook GitHub Bot
Parent: 6e986f91a9
Commit: 9ab9b12d35
20 changed files with 1748 additions and 515 deletions

View File

@@ -28,8 +28,7 @@ constexpr auto kFunctorchWrappedTensors = DispatchKeySet({
 constexpr auto kTensorSubclassLike = kFunctorchWrappedTensors | DispatchKeySet({
     DispatchKey::Batched,
-    DispatchKey::SparseCPU,
-    DispatchKey::SparseCUDA,
+    DispatchKey::Sparse,
     DispatchKey::SparseCsrCPU,
     DispatchKey::SparseCsrCUDA,
     DispatchKey::Meta,

View File

@@ -43,7 +43,6 @@ inline bool variable_excluded_from_dispatch() {
   // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
   return true;
 #else
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::impl::tls_local_dispatch_key_set().excluded_.has(DispatchKey::Autograd));
   return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
 #endif
 }

View File

@@ -6,11 +6,52 @@
 namespace c10 {

 void DispatchKeyExtractor::setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough) {
+  // (1) update nonFallthroughKeys_
   if (has_fallthrough) {
     nonFallthroughKeys_ = nonFallthroughKeys_.remove(k);
   } else {
     nonFallthroughKeys_ = nonFallthroughKeys_.add(k);
   }
+  // (2) update nonFallthroughKeysPerBackend_
+  if (isPerBackendFunctionalityKey(toFunctionalityKey(k))) {
+    // This is a per-backend functionality key.
+    // We need to figure out what the current backend is,
+    // and only update the bitset for that backend.
+    // subtracting 1 because the first backend should have index 0 (CPU),
+    // But the enum starts with BackendComponent::InvalidBit.
+    auto backend_idx = static_cast<uint8_t>(toBackendComponent(k)) - 1;
+    TORCH_INTERNAL_ASSERT(backend_idx >= 0 && static_cast<uint8_t>(backend_idx) < nonFallthroughKeysPerBackend_.size());
+    if (has_fallthrough) {
+      nonFallthroughKeysPerBackend_[backend_idx] = nonFallthroughKeysPerBackend_[backend_idx].remove(k);
+    } else {
+      nonFallthroughKeysPerBackend_[backend_idx] = nonFallthroughKeysPerBackend_[backend_idx].add(k);
+    }
+    // Set requiresBitsetPerBackend_ accordingly
+    for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size() - 1)) {
+      if (nonFallthroughKeysPerBackend_[i] != nonFallthroughKeysPerBackend_[i+1]) {
+        requiresBitsetPerBackend_ = true;
+        return;
+      }
+    }
+    requiresBitsetPerBackend_ = false;
+    return;
+  } else {
+    // Otherwise, if a fallthrough is set for a functionality that isn't per backend,
+    // Then we update the fallthrough bitset for EVERY backend.
+    // TODO: we could probably optimize this by only lazily updating these values
+    // the first time that we see requiresBitsetPerBackend_ = true
+    // (which should almost never happen)
+    if (has_fallthrough) {
+      for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
+        nonFallthroughKeysPerBackend_[i] = nonFallthroughKeysPerBackend_[i].remove(k);
+      }
+    } else {
+      for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
+        nonFallthroughKeysPerBackend_[i] = nonFallthroughKeysPerBackend_[i].add(k);
+      }
+    }
+  }
 }

 std::string DispatchKeyExtractor::dumpState() const {
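For concreteness, here is how the new indexing behaves for one per-backend key (a sketch based on the helpers used above; the BackendComponent values come from the enum added further down in this diff):

// Worked example for k = DispatchKey::SparseCPU (a per-backend runtime key):
//   toFunctionalityKey(DispatchKey::SparseCPU) == DispatchKey::Sparse        (a per-backend functionality)
//   toBackendComponent(DispatchKey::SparseCPU) == BackendComponent::CPUBit   (enum value 1; InvalidBit is 0)
//   backend_idx == 1 - 1 == 0
// so only nonFallthroughKeysPerBackend_[0] (the CPU bitset) is updated, and the
// TORCH_INTERNAL_ASSERT above guards the array access.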

View File

@@ -156,15 +156,25 @@ public:
       }
     });
     // Keys that are fallthrough should be skipped
+    if (requiresBitsetPerBackend_) {
+      auto backend_idx = ks.getBackendIndex();
+      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
+    } else {
       return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
     }
+  }

   template<class... Args>
   DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
     auto ks = detail::multi_dispatch_key_set(args...);
     // Keys that are fallthrough should be skipped
+    if (requiresBitsetPerBackend_) {
+      auto backend_idx = ks.getBackendIndex();
+      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
+    } else {
       return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
     }
+  }

   void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);

@@ -193,7 +203,12 @@ private:
   explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
     : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse)
-    , nonFallthroughKeys_(DispatchKeySet::FULL) {}
+    , nonFallthroughKeys_(DispatchKeySet::FULL)
+    , requiresBitsetPerBackend_(false) {
+    for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
+      nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
+    }
+  }

   // this is a bitset that has ones for each argument index which has to be
   // considered for dispatch. This avoids having to iterate over the stack

@@ -205,8 +220,14 @@ private:
   // fallthrough
   c10::utils::bitset dispatch_arg_indices_reverse_;

-  // Set of keys for which the operator does NOT have fallthrough kernel.
+  // Set of functionality keys for which the operator does NOT have fallthrough kernel.
   DispatchKeySet nonFallthroughKeys_;
+  // Set of functionality keys for which the operator does NOT have fallthrough kernel, defined PER BACKEND.
+  // This is only needed if we know that the operator has a different set of fallthroughs defined for some backends.
+  std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
+  // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast path),
+  // or if we need to fall back to the slower path and check nonFallthroughKeysPerBackend_
+  bool requiresBitsetPerBackend_;
 };
 }
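Both call sites above follow the same selection pattern; as a condensed sketch (using only names that appear in this diff):

// If every backend shares the same fallthrough set, the single bitset suffices (fast path);
// otherwise the bitset for this keyset's backend is used (slow path).
DispatchKeySet nonFallthrough = requiresBitsetPerBackend_
    ? nonFallthroughKeysPerBackend_[ks.getBackendIndex()]
    : nonFallthroughKeys_;
return impl::computeDispatchKeySet(ks, nonFallthrough);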

View File

@@ -267,14 +267,15 @@ void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name)
 RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, KernelFunction kernel, std::string debug) {
   std::lock_guard<std::mutex> lock(mutex_);

+  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
   TORCH_CHECK(
-    !backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)].kernel.isValid(),
+    !backendFallbackKernels_[idx].kernel.isValid(),
     "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
-    backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)].debug, ", new registration ", debug
+    backendFallbackKernels_[idx].debug, ", new registration ", debug
   );
   // NB: inferred function schema is always nullptr for fallbacks, as fallbacks
   // cannot be unobxed
-  backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
+  backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));

   for (auto& op : operators_) {
     op.op.updateFallback(*this, dispatchKey);

@@ -288,7 +289,8 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
 void Dispatcher::deregisterFallback_(DispatchKey dispatchKey) {
   std::lock_guard<std::mutex> lock(mutex_);
-  backendFallbackKernels_[static_cast<uint8_t>(dispatchKey)] = {};
+  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
+  backendFallbackKernels_[idx] = {};

   for (auto& op : operators_) {
     op.op.updateFallback(*this, dispatchKey);

View File

@@ -291,7 +291,7 @@ private:
   // Map from namespace to debug string (saying, e.g., where the library was defined)
   ska::flat_hash_map<std::string, std::string> libraries_;

-  std::array<impl::AnnotatedKernel, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> backendFallbackKernels_;
+  std::array<impl::AnnotatedKernel, num_runtime_entries> backendFallbackKernels_;

   std::unique_ptr<detail::RegistrationListenerList> listeners_;
   std::mutex mutex_;

@@ -531,8 +531,7 @@ C10_DISPATCHER_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorH
   detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
   auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
     .template getDispatchKeySetUnboxed<Args...>(args...);
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKeySet.highestPriorityTypeId()));
-  const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet.highestPriorityTypeId());
+  const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
 #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
   // By default, when there're no high-frequency or non-sampled callbacks,
   // RecordFunction is pre-sampled as a perf optimization;

@@ -553,7 +552,7 @@ template<class Return, class... Args>
 inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
   detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
   // do not use RecordFunction on redispatch
-  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet.highestPriorityTypeId());
+  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
   return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
 }

@@ -561,7 +560,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
   // note: this doesn't need the mutex because write operations on the list keep iterators intact.
   const auto& entry = op.operatorDef_->op;
   auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
-  const auto& kernel = entry.lookup(dispatchKeySet.highestPriorityTypeId());
+  const auto& kernel = entry.lookup(dispatchKeySet);
 #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
   bool pre_sampled = false;
   if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {

@@ -593,7 +592,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
 inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
   // note: this doesn't need the mutex because write operations on the list keep iterators intact.
   const auto& entry = op.operatorDef_->op;
-  const auto& kernel = entry.lookup(dispatchKeySet.highestPriorityTypeId());
+  const auto& kernel = entry.lookup(dispatchKeySet);
   return kernel.callBoxed(op, dispatchKeySet, stack);
 }

View File

@@ -283,7 +283,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
   }

   // 3. Backend fallback
-  auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
+  auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
   if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) {
     return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"};
   }

@@ -299,10 +299,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
 // or alias keys and their associated keysets).
 // This function should be considered a private helper for updateDispatchTable_()
 void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
-  const auto dispatch_ix = c10::getDispatchTableIndexForDispatchKey(dispatch_key);
-  if (C10_UNLIKELY(dispatch_ix == -1)) {
-    return;
-  }
+  const auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key);
   dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
   dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
 }

@@ -329,8 +326,12 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
   }
   // Note [Refresh Runtime Autograd entries in dispatchTable_]
   // Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
+  // In theory, we should only have to check if the given runtime key has "dense" functionality,
+  // e.g. DispatchKey::CPU (which is composed of DispatchKey::Dense and BackendComponent::CPUBit).
+  // However, there are some backends that should be included in this set that don't have the dense key set.
+  // E.g. DispatchKey::Meta, DispatchKey::ORT.
   if (c10::isBackendDispatchKey(dispatch_key)) {
-    DispatchKey autograd_key = getAutogradKeyFromBackend(dispatch_key);
+    DispatchKey autograd_key = getAutogradKeyFromBackend(toBackendComponent(dispatch_key));
     updateDispatchTableEntry_(dispatcher, autograd_key);
   }
 }

@@ -357,8 +358,9 @@ void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher)
   // catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either CompositeExplicitAutograd
   // or CompositeImplicitAutograd alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, CompositeImplicitAutograd)
   // should return true, it returns false because Undefined cannot be represented in a DispatchKeySet.
-  for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
-    updateDispatchTable_(dispatcher, static_cast<DispatchKey>(iter));
+  updateDispatchTable_(dispatcher, DispatchKey::Undefined);
+  for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
+    updateDispatchTable_(dispatcher, k);
   }
 }

@@ -371,9 +373,10 @@ void OperatorEntry::checkInvariants() const {
   for (const auto& kv : kernels_) {
     TORCH_INTERNAL_ASSERT(kv.second.size() > 0, dumpState());
   }
-  for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
-    auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), static_cast<DispatchKey>(iter));
-    TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[iter]),
+  for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
+    auto expected_k = computeDispatchTableEntry(c10::Dispatcher::singleton(), k);
+    auto idx = getDispatchTableIndexForDispatchKey(k);
+    TORCH_INTERNAL_ASSERT(expected_k._equalsBoxedAndUnboxed(dispatchTable_[idx]),
       "Canonical state\n~~~~~~~~~~~\n", dumpState(), "\n\n"
       "Computed table:\n~~~~~~~~~~~\n", dumpComputedTable());
   }

@@ -384,7 +387,8 @@ std::string OperatorEntry::listAllDispatchKeys() const {
   str << "[";
   bool has_kernels = false;
-  for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
+  for (auto k : DispatchKeySet(DispatchKeySet::FULL)) {
+    auto iter = getDispatchTableIndexForDispatchKey(k);
     if (!dispatchTable_[iter].isValid()) {
       continue;
     }

@@ -443,8 +447,12 @@ void OperatorEntry::reportError(DispatchKey dispatchKey) const {
 // updateDispatchTableFull_ would update the dispatch table to be)
 std::string OperatorEntry::dumpComputedTable() const {
   std::ostringstream oss;
-  for (uint8_t i = 0; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); i++) {
-    auto k = static_cast<DispatchKey>(i);
+  // Need to handle Undefined separately, because its a runtime key that can't be represented
+  // in a DispatchKeySet.
+  std::vector<DispatchKey> runtime_keys = {DispatchKey::Undefined};
+  for (auto k : DispatchKeySet(DispatchKeySet::FULL)) runtime_keys.push_back(k);
+  for (auto k : runtime_keys) {
     auto kernel_prov = computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);
     if (kernel_prov.first.kernel.isValid()) {
       oss << toString(k) << ": "

View File

@@ -173,11 +173,8 @@ public:
   [[noreturn]] void reportError(DispatchKey dispatchKey) const;

-  const KernelFunction& lookup(DispatchKey k) const {
-    const auto idx = getDispatchTableIndexForDispatchKey(k);
-    if (C10_UNLIKELY(idx == -1)) {
-      reportError(k);
-    }
+  const KernelFunction& lookup(DispatchKeySet ks) const {
+    const auto idx = ks.getDispatchTableIndexForDispatchKeySet();
     const auto& kernel = dispatchTable_[idx];
     // A valid kernel *always* has a boxed kernel and *may* have an
     // unboxed kernel. However, we typically do unboxed calls in at::

@@ -187,7 +184,7 @@ public:
     // in the common case.
     if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
       if (!kernel.isValid()) {
-        reportError(k);
+        reportError(ks.highestPriorityTypeId());
       }
     }
     return kernel;

@@ -211,7 +208,7 @@ private:
   OperatorName name_;
   c10::optional<AnnotatedSchema> schema_;

-  std::array<KernelFunction, c10::getDispatchTableIndexForDispatchKey(DispatchKey::NumDispatchKeys)> dispatchTable_;
+  std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
   DispatchKeyExtractor dispatchKeyExtractor_;
   // kernels_ stores all registered kernels for the corresponding dispatch key
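The table index is now derived from the whole DispatchKeySet rather than from a single DispatchKey. Conceptually (a sketch only; getDispatchTableIndexForDispatchKeySet itself is not shown in this diff, and highestFunctionalityKey / slot_offset_for are hypothetical helpers used purely for illustration):

// Every non-per-backend functionality owns exactly one slot in dispatchTable_;
// a per-backend functionality owns num_backends consecutive slots, and the
// concrete slot is selected by the keyset's highest-priority backend bit.
uint64_t sketch_dispatch_table_index(DispatchKeySet ks) {
  DispatchKey functionality = ks.highestFunctionalityKey();   // hypothetical helper
  uint64_t base = slot_offset_for(functionality);             // hypothetical helper
  if (isPerBackendFunctionalityKey(functionality)) {
    return base + ks.getBackendIndex();                       // backend picks the slot
  }
  return base;
}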

View File

@@ -591,7 +591,7 @@ TEST(OperatorRegistrationTest, AutogradBackendOverridesAutogradKernel) {
 void LazyBackendsAutogradOverridesAutogradKernel(DispatchKey key) {
   auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
-    .kernel<decltype(nonautograd_kernel), &nonautograd_kernel>(c10::getAutogradKeyFromBackend(key))
+    .kernel<decltype(nonautograd_kernel), &nonautograd_kernel>(c10::getAutogradKeyFromBackend(toBackendComponent(key)))
     .kernel<decltype(autograd_kernel), &autograd_kernel>(DispatchKey::Autograd));

   auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});

@@ -1791,22 +1791,22 @@ TEST(NewOperatorRegistrationTest, dispatchAutogradPrecedence) {
 TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther) {
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  bool sparsecpu_called, math_called = false;
+  bool fpga_called, math_called = false;
   auto m = MAKE_TORCH_LIBRARY(test);
-  m.def("fn", torch::dispatch(c10::DispatchKey::SparseCPU, [&](const Tensor& x) { sparsecpu_called = true; return x; }));
+  m.def("fn", torch::dispatch(c10::DispatchKey::FPGA, [&](const Tensor& x) { fpga_called = true; return x; }));
   m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });

   auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
   ASSERT_TRUE(op.has_value());

   {
-    callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU));
-    ASSERT_TRUE(sparsecpu_called);
+    callOp(*op, dummyTensor(c10::DispatchKey::FPGA));
+    ASSERT_TRUE(fpga_called);
   }

   {
     expectThrows<c10::Error>([&] {
-      callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true));
+      callOp(*op, dummyTensor(c10::DispatchKey::FPGA, /*requires_grad=*/true));
     }, "test::fn has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther.");
   }
 }

@@ -1849,18 +1849,15 @@ TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) {
   }

   {
-    // TODO(#43908): currently this will fallthrough AutogradPrivateUse1 then call catchall kernel
-    // at AutogradCPU, while backend extenders are indeed expecting to call PrivateUse1 kernel.
-    // This confusing behavior is caused by we registering fallthrough as backend fallback for
-    // Autograd keys. Note users could always work around this by registering the same kernel to
-    // AutogradPrivateUse1 as shown below until we support it.
     auto op = Dispatcher::singleton().findOp({"test::fn", ""});
     ASSERT_TRUE(op.has_value());
     catchall_called = false;
+    privateuse1_called = false;
     callOp(*op,
            dummyTensor(c10::DispatchKey::PrivateUse1, /*requires_grad=*/true),
            dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
-    ASSERT_TRUE(catchall_called);
+    ASSERT_FALSE(catchall_called);
+    ASSERT_TRUE(privateuse1_called);
   }

   m.impl("fn", c10::DispatchKey::AutogradPrivateUse1, [&](const Tensor& x, const Tensor& y) { privateuse1_called = true; return x; });

@@ -1876,6 +1873,27 @@ TEST(NewOperatorRegistrationTest, dispatchMultipleTensors) {
   }
 }

+TEST(NewOperatorRegistrationTest, registerCompositeImplicitAutogradWithCPUKernel_andCallAutogradOtherKernel_callsComposite) {
+  bool math_called = false;
+  bool cpu_called = false;
+  auto m = MAKE_TORCH_LIBRARY(test);
+  m.def("fn(Tensor dummy) -> Tensor");
+  m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { cpu_called = true; return x; });
+  m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; });
+
+  auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
+  ASSERT_TRUE(op.has_value());
+
+  {
+    math_called = cpu_called = false;
+    // Meta should redispatch to the AutogradOther backend,
+    // which the composite kernel should be registered to.
+    callOp(*op, dummyTensor(c10::DispatchKey::Meta, /*requires_grad=*/true));
+    ASSERT_TRUE(math_called);
+    ASSERT_FALSE(cpu_called);
+  }
+}
+
 TEST(NewOperatorRegistrationTest, dispatchMultiple) {
   bool cpu_called = false;
   bool cuda_called = false;

View File

@@ -1,14 +1,47 @@
 #include <c10/core/DispatchKey.h>
+#include <c10/core/DispatchKeySet.h>
 #include <unordered_map>

 namespace c10 {
const char* toString(BackendComponent t) {
switch (t) {
case BackendComponent::CPUBit:
return "CPUBit";
case BackendComponent::CUDABit:
return "CUDABit";
case BackendComponent::HIPBit:
return "HIPBit";
case BackendComponent::XLABit:
return "XLABit";
case BackendComponent::LazyBit:
return "LazyBit";
case BackendComponent::XPUBit:
return "XPUBit";
case BackendComponent::MLCBit:
return "MLCBit";
case BackendComponent::HPUBit:
return "HPUBit";
case BackendComponent::VEBit:
return "VEBit";
case BackendComponent::PrivateUse1Bit:
return "PrivateUse1Bit";
case BackendComponent::PrivateUse2Bit:
return "PrivateUse2Bit";
case BackendComponent::PrivateUse3Bit:
return "PrivateUse3Bit";
case BackendComponent::InvalidBit:
return "InvalidBit";
default:
return "UNKNOWN_BACKEND_BIT";
}
}
 const char* toString(DispatchKey t) {
   switch (t) {
     case DispatchKey::Undefined:
       return "Undefined";
     case DispatchKey::CPU:
       return "CPU";
     case DispatchKey::CUDA:

@@ -101,8 +134,6 @@ const char* toString(DispatchKey t) {
       return "AutogradMLC";
     case DispatchKey::AutogradHPU:
       return "AutogradHPU";
-    case DispatchKey::AutogradNestedTensor:
-      return "AutogradNestedTensor";
     case DispatchKey::AutogradPrivateUse1:
       return "AutogradPrivateUse1";
     case DispatchKey::AutogradPrivateUse2:

@@ -111,6 +142,8 @@ const char* toString(DispatchKey t) {
       return "AutogradPrivateUse3";
     case DispatchKey::AutogradOther:
       return "AutogradOther";
+    case DispatchKey::AutogradNestedTensor:
+      return "AutogradNestedTensor";
     case DispatchKey::ZeroTensor:
       return "ZeroTensor";

@@ -168,6 +201,15 @@ const char* toString(DispatchKey t) {
     case DispatchKey::FuncTorchBatched:
       return "FuncTorchBatched";
case DispatchKey::Dense:
return "Dense";
case DispatchKey::Quantized:
return "Quantized";
case DispatchKey::Sparse:
return "Sparse";
case DispatchKey::AutogradFunctionality:
return "AutogradFunctionality";
     default:
       return "UNKNOWN_TENSOR_TYPE_ID";
   }

@@ -176,76 +218,37 @@ const char* toString(DispatchKey t) {
 std::ostream& operator<<(std::ostream& str, DispatchKey rhs) {
   return str << toString(rhs);
 }

+std::ostream& operator<<(std::ostream& str, BackendComponent rhs) {
+  return str << toString(rhs);
+}

-// for a given backend key, return the associated autograd key.
-// for non-backend keys, return AutogradOther as a default.
-// Note: it's convenient and fast to return a default here rather than (say)
-// returning an optional<DispatchKey>, or throwing. But it makes callers
-// responsible for either a) enforcing the invariant that only backend keys
-// be passed as arguments, or b) interpreting our return value carefully.
-//
-DispatchKey getAutogradKeyFromBackend(DispatchKey t) {
-  switch (t) {
-    case DispatchKey::CPU:
-      return DispatchKey::AutogradCPU;
-    case DispatchKey::XPU:
-      return DispatchKey::AutogradXPU;
-    case DispatchKey::CUDA:
-      return DispatchKey::AutogradCUDA;
-    case DispatchKey::XLA:
-      return DispatchKey::AutogradXLA;
-    case DispatchKey::Lazy:
-      return DispatchKey::AutogradLazy;
-    case DispatchKey::MLC:
-      return DispatchKey::AutogradMLC;
-    case DispatchKey::HPU:
-      return DispatchKey::AutogradHPU;
-    case DispatchKey::NestedTensor:
-      return DispatchKey::AutogradNestedTensor;
-    case DispatchKey::PrivateUse1:
-      return DispatchKey::AutogradPrivateUse1;
-    case DispatchKey::PrivateUse2:
-      return DispatchKey::AutogradPrivateUse2;
-    case DispatchKey::PrivateUse3:
-      return DispatchKey::AutogradPrivateUse3;
-    default:
-      return DispatchKey::AutogradOther;
-  }
-}
+DispatchKey getAutogradKeyFromBackend(BackendComponent k) {
+  // We want this to return an autograd key. We're relying on the fact that
+  // getAutogradRelatedKeySetFromBackend returns an autograd key +
+  // ADInplaceOrView, and autograd has higher precedence. The core mapping from
+  // backend -> autograd key lives in `getAutogradRelatedKeySetFromBackend`
+  // instead of here for performance. `getAutogradRelatedKeySetFromBackend` is a
+  // hotpath function, and we want to make sure that it doesn't have to
+  // construct any DispatchKeySets at runtime.
+  return getAutogradRelatedKeySetFromBackend(k).highestPriorityTypeId();
+}
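The backend-to-autograd-key mapping itself is unchanged; for instance (expected behavior, inferred from the switch statement removed above):

// getAutogradKeyFromBackend(BackendComponent::CPUBit)         -> DispatchKey::AutogradCPU
// getAutogradKeyFromBackend(BackendComponent::PrivateUse1Bit) -> DispatchKey::AutogradPrivateUse1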
 c10::DispatchKey parseDispatchKey(const std::string& k) {
   static std::unordered_map<std::string, c10::DispatchKey> key_map = {
       {"Undefined", c10::DispatchKey::Undefined},
-      {"CPU", c10::DispatchKey::CPU},
-      {"CUDA", c10::DispatchKey::CUDA},
-      {"HIP", c10::DispatchKey::HIP},
+      {"Dense", c10::DispatchKey::Dense},
       {"FPGA", c10::DispatchKey::FPGA},
       {"ORT", c10::DispatchKey::ORT},
-      {"XLA", c10::DispatchKey::XLA},
-      {"MLC", c10::DispatchKey::MLC},
       {"Vulkan", c10::DispatchKey::Vulkan},
       {"Metal", c10::DispatchKey::Metal},
-      {"XPU", c10::DispatchKey::XPU},
-      {"HPU", c10::DispatchKey::HPU},
       {"VE", c10::DispatchKey::VE},
-      {"Lazy", c10::DispatchKey::Lazy},
       {"Meta", c10::DispatchKey::Meta},
-      {"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
-      {"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA},
-      {"QuantizedXPU", c10::DispatchKey::QuantizedXPU},
+      {"Quantized", c10::DispatchKey::Quantized},
       {"CustomRNGKeyId", c10::DispatchKey::CustomRNGKeyId},
       {"MkldnnCPU", c10::DispatchKey::MkldnnCPU},
-      {"SparseCPU", c10::DispatchKey::SparseCPU},
-      {"SparseCUDA", c10::DispatchKey::SparseCUDA},
-      {"SparseHIP", c10::DispatchKey::SparseHIP},
-      {"SparseXPU", c10::DispatchKey::SparseXPU},
-      {"SparseVE", c10::DispatchKey::SparseVE},
+      {"Sparse", c10::DispatchKey::Sparse},
       {"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
       {"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
-      {"NestedTensor", c10::DispatchKey::NestedTensor},
-      {"PrivateUse1", c10::DispatchKey::PrivateUse1},
-      {"PrivateUse2", c10::DispatchKey::PrivateUse2},
-      {"PrivateUse3", c10::DispatchKey::PrivateUse3},
       {"BackendSelect", c10::DispatchKey::BackendSelect},
       {"Python", c10::DispatchKey::Python},
       {"Named", c10::DispatchKey::Named},

@@ -256,17 +259,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
        c10::DispatchKey::FuncTorchDynamicLayerBackMode},
       {"ADInplaceOrView", c10::DispatchKey::ADInplaceOrView},
       {"AutogradOther", c10::DispatchKey::AutogradOther},
-      {"AutogradCPU", c10::DispatchKey::AutogradCPU},
-      {"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
-      {"AutogradXLA", c10::DispatchKey::AutogradXLA},
-      {"AutogradLazy", c10::DispatchKey::AutogradLazy},
-      {"AutogradXPU", c10::DispatchKey::AutogradXPU},
-      {"AutogradMLC", c10::DispatchKey::AutogradMLC},
-      {"AutogradHPU", c10::DispatchKey::AutogradHPU},
+      {"AutogradFunctionality", c10::DispatchKey::AutogradFunctionality},
       {"AutogradNestedTensor", c10::DispatchKey::AutogradNestedTensor},
-      {"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1},
-      {"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2},
-      {"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3},
       {"Tracer", c10::DispatchKey::Tracer},
       {"AutocastCPU", c10::DispatchKey::AutocastCPU},
       {"AutocastCUDA", c10::DispatchKey::AutocastCUDA},

@@ -280,6 +274,41 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       {"TESTING_ONLY_GenericWrapper",
        c10::DispatchKey::TESTING_ONLY_GenericWrapper},
       {"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
{"CPU", c10::DispatchKey::CPU},
{"CUDA", c10::DispatchKey::CUDA},
{"HIP", c10::DispatchKey::HIP},
{"XLA", c10::DispatchKey::XLA},
{"MLC", c10::DispatchKey::MLC},
{"XPU", c10::DispatchKey::XPU},
{"HPU", c10::DispatchKey::HPU},
{"Lazy", c10::DispatchKey::Lazy},
{"NestedTensor", c10::DispatchKey::NestedTensor},
{"PrivateUse1", c10::DispatchKey::PrivateUse1},
{"PrivateUse2", c10::DispatchKey::PrivateUse2},
{"PrivateUse3", c10::DispatchKey::PrivateUse3},
{"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
{"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA},
{"QuantizedXPU", c10::DispatchKey::QuantizedXPU},
{"SparseCPU", c10::DispatchKey::SparseCPU},
{"SparseCUDA", c10::DispatchKey::SparseCUDA},
{"SparseHIP", c10::DispatchKey::SparseHIP},
{"SparseXPU", c10::DispatchKey::SparseXPU},
{"SparseVE", c10::DispatchKey::SparseVE},
{"AutogradCPU", c10::DispatchKey::AutogradCPU},
{"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
{"AutogradXLA", c10::DispatchKey::AutogradXLA},
{"AutogradLazy", c10::DispatchKey::AutogradLazy},
{"AutogradXPU", c10::DispatchKey::AutogradXPU},
{"AutogradMLC", c10::DispatchKey::AutogradMLC},
{"AutogradHPU", c10::DispatchKey::AutogradHPU},
{"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1},
{"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2},
{"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3},
{"Autograd", c10::DispatchKey::Autograd}, {"Autograd", c10::DispatchKey::Autograd},
{"CompositeImplicitAutograd", {"CompositeImplicitAutograd",
c10::DispatchKey::CompositeImplicitAutograd}, c10::DispatchKey::CompositeImplicitAutograd},

View File

@@ -9,20 +9,98 @@
 namespace c10 {
// Semantically, each value of BackendComponent identifies a "backend" for our
// dispatch. Some functionalities that we may dispatch to are allowed to
// register different handlers for each backend. The BackendComponent is then
// used to figure out which backend implementation to dispatch to.
// In implementation terms, the backend component identifies a specific "bit" in
// a DispatchKeySet. The bits in the DispatchKeySet are split between the bottom
// ~12 "BackendComponent" bits, while the remaining upper bits are assigned to
// functionalities. When we encounter a functionality bit that is known to be
// customizeable per-backend, then we also look at the lower BackendComponent
// bits and take the highest bit to determine which backend's implementation to
// use.
enum class BackendComponent : uint8_t {
// A "backend" is colloquially used to refer to handlers for dispatch
// which actually implement the numerics of an operation in question.
//
// Due to the nature of the enum, these backends are specified in
// an ordered way, but for most backends this order is not semantically
// meaningful (e.g., it's valid to reorder these backends without changing
// semantics). The only situation when backend ordering is meaningful
// is when the backend participates in multiple dispatch with another
// backend; e.g., CPU and CUDA (cuda must have higher priority).
// These keys don't correspond to individual kernels.
// Instead, they represent the backends that are allowed to override specific
// pieces of functionality:
// - dense kernels (e.g. DispatchKey::CPU)
// - sparse kernels (e.g. DispatchKey::SparseCPU)
// - quantized kernels (e.g. DispatchKey::QuantizedCPU)
// - autograd kernels (e.g. DispatchKey::AutogradCPU)
// We reserve space in the runtime operator table for this full cross product
// of
// [backends in this enum] x [keys below that are explicitly marked as having
// per-backend functionality]
InvalidBit = 0,
CPUBit,
CUDABit,
HIPBit,
XLABit,
MLCBit,
XPUBit,
HPUBit,
VEBit,
LazyBit,
PrivateUse1Bit,
PrivateUse2Bit,
PrivateUse3Bit,
// Define an alias to represent end of backend dispatch keys.
// If you add new backend keys after PrivateUse3, please also update it here.
// (But you shouldn't: private use keys should have higher precedence than
// all built-in keys)
EndOfBackendKeys = PrivateUse3Bit,
};
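To make the split concrete (a sketch; toFunctionalityKey and toBackendComponent are the helpers used elsewhere in this diff):

// A per-backend runtime key factors into a functionality plus a backend component:
//   toFunctionalityKey(DispatchKey::SparseCUDA) == DispatchKey::Sparse
//   toBackendComponent(DispatchKey::SparseCUDA) == BackendComponent::CUDABit
// In a DispatchKeySet the backend bits occupy the low bits and the functionality
// bits sit above them, so a single per-backend runtime key contributes two bits.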
 // Semantically, a dispatch key identifies a possible "level" in our
-// dispatch, for which a handler may be registered. Traditional
-// backends like CPU and CUDA get dispatch keys; however, so do
-// "wrapping" layers like Variable (for autograd handling).
+// dispatch, for which a handler may be registered. Each handler corresponds
+// to a type of functionality.
 //
 // In implementation terms, the dispatch key identifies a specific "bit" in a
 // DispatchKeySet. Higher bit indexes get handled by dispatching first (because
 // we "count leading zeros" when we extract the highest priority dispatch
 // key.)
 //
// Note [DispatchKey Classification]
// This enum actually contains several types of keys, which are explained
// in more detail further down:
// (1) non-customizable backends (e.g. FPGA)
// (2) non-customizable functionalities (e.g. Functionalize)
// (3) functionalized that are customizable per backend (e.g. Dense, Sparse,
// AutogradFunctionality) (4) per-backend instances of customizable
// functionalities (e.g. CPU, SparseCPU, AutogradCPU) (5) alias keys (e.g.
// CompositeImplicitAutograd)
//
// Of the categories above, it's important to note:
// (a) which keys are assigned individual bits in a DispatchKeySet
// (b) which keys are assigned individual slots in the runtime operator table
// ("Runtime keys")
//
// (1), (2) and (3) all get their own dedicated bits in the DispatchKeySet.
// (1), (2) and (4) all get their own dedicated slots in the runtime operator
// table.
// See Note [DispatchKeySet Internal Representation] for more details.
//
 // NOTE: Keep the list in sync with `DispatchKey` in tools/codegen/model.py
-enum class DispatchKey : uint8_t {
+enum class DispatchKey : uint16_t {
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
-  // This is not a "real" tensor id, but it exists to give us a "nullopt"
+  // This is not a "real" functionality, but it exists to give us a "nullopt"
   // element we can return for cases when a DispatchKeySet contains no elements.
   // You can think a more semantically accurate definition of DispatchKey is:
   //

@@ -38,24 +116,31 @@ enum class DispatchKey : uint8_t {
   // this will get eliminated, but for now it's convenient)
   CatchAll = Undefined,

-  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ BACKENDS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
-  // A "backend" is colloquially used to refer to handlers for dispatch
-  // which actually implement the numerics of an operation in question.
+  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Functionality Keys ~~~~~~~~~~~~~~~~~~~~~~ //
+  // Every value in the enum (up to EndOfFunctionalityKeys)
+  // corresponds to an individual "functionality" that can be dispatched to.
+  // This is represented in the DispatchKeySet by assigning each of these enum
+  // values
+  // to each of the remaining (64 - len(BackendComponent)) bits.
   //
-  // Due to the nature of the enum, these backends are specified in
-  // an ordered way, but for most backends this order is not semantically
-  // meaningful (e.g., it's valid to reorder these backends without changing
-  // semantics). The only situation when backend ordering is meaningful
-  // is when the backend participates in multiple dispatch with another
-  // backend; e.g., CPU and SparseCPU (sparse must have
-  // higher priority).
+  // Most of these functionalities have a single handler assigned to them,
+  // making them "runtime keys".
+  // That map to a single slot in the runtime operator table.
+  //
+  // A few functionalities are allowed to be customizable per backend.
+  // See [Note: Per-Backend Functionality Dispatch Keys] for details.

+  // See [Note: Per-Backend Functionality Dispatch Keys]
+  Dense,

+  // Below are non-extensible backends.
+  // These are backends that currently don't have their own overrides for
+  // Autograd/Sparse/Quantized kernels,
+  // and we therefore don't waste space in the runtime operator table allocating
+  // space for them.
+  // If any of these backends ever need to customize, e.g., Autograd, then we'll
+  // need to add a DispatchKey::*Bit for them.

-  // Here are backends which you think of as traditionally specifying
-  // how to implement operations on some device.
-  CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp
-  CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp
-  HIP, // NB: I think this is not actually used, due to Note [Masquerading as
-       // CUDA]
   FPGA, // Xilinx support lives out of tree at
         // https://gitlab.com/pytorch-complex/vitis_kernels
@@ -67,14 +152,8 @@ enum class DispatchKey : uint8_t {
   // - aten/src/ATen/test/extension_backend_test.cpp
   ORT,

-  XLA, // lives out of tree at https://github.com/pytorch/xla
-  MLC, // lives out of tree at https://github.com/pytorch/MLCompute
   Vulkan,
   Metal,
-  XPU, // For out of tree Intel's heterogeneous computing plug-in
-  HPU, // For out of tree & closed source integration of HPU / Habana
-  VE, // For out of tree & closed source integration of SX-Aurora / NEC
-  Lazy, // For lazy tensor backends

   // A meta tensor is a tensor without any data associated with it. (They
   // have also colloquially been referred to as tensors on the "null" device).

@@ -83,11 +162,8 @@ enum class DispatchKey : uint8_t {
   // tensor with the output shape and dtype, but wouldn't actually add anything.
   Meta,

-  // Here are backends which specify more specialized operators
-  // based on the dtype of the tensor.
-  QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp
-  QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp
-  QuantizedXPU, // For out of tree Intel's heterogeneous computing plug-in
+  // See [Note: Per-Backend Functionality Dispatch Keys]
+  Quantized,

   // This backend is to support custom RNGs; it lets you go
   // to a different kernel if you pass in a generator that is not a

@@ -106,31 +182,29 @@ enum class DispatchKey : uint8_t {
   // the corresponding dense tensors, and must be handled before them.
   MkldnnCPU, // registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp
   // NB: not to be confused with MKLDNN, which is Caffe2 only

-  SparseCPU, // registered at build/aten/src/ATen/RegisterSparseCPU.cpp
-  SparseCUDA, // registered at build/aten/src/ATen/RegisterSparseCUDA.cpp
-  SparseHIP, // TODO: I think this is not actually used, due to Note
-             // [Masquerading as CUDA]
-  SparseXPU, // For out of tree Intel's heterogeneous computing plug-in
-  SparseVE, // For out of tree & closed source integration of SX-Aurora / NEC
+  // See [Note: Per-Backend Functionality Dispatch Keys]
+  Sparse,

   SparseCsrCPU,
   SparseCsrCUDA,
// Note [Non-Customizable Backend Keys]
// Every key above here is considered a "non-customizable backend".
// These are backends that will work correctly with autograd, but
// but currently don't require separate implementations
// for autograd sparse or quantized kernels.
// Any new backends that don't need to be customized should go above here.
// If an existing backend needs to e.g. override autograd, then we can
// consider promoting it into the "BackendComponent" enum
//
// For all intents and purposes from the perspective of DispatchKeySet,
// "non-customizable backend" keys are treated the same way
// as other functionality keys
EndOfNonCustomizableBackends = SparseCsrCUDA,
   NestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor

-  // Here are reserved backends for user-defined backends, see Note [Private use
-  // DispatchKey]
-  // To see some example about how to use this, check out ORT
-  PrivateUse1,
-  PrivateUse2,
-  PrivateUse3,
-
-  // Define an alias key to represent end of backend dispatch keys.
-  // If you add new backend keys after PrivateUse3, please also update it here.
-  // (But you shouldn't: private use keys should have higher precedence than
-  // all built-in keys)
-  EndOfBackendKeys = PrivateUse3,

   // In some situations, it is not immediately obvious what the correct
   // backend for function is, because the function in question doesn't
   // have any "tensor" arguments. In this case, a BackendSelect function
@@ -233,20 +307,18 @@ enum class DispatchKey : uint8_t {
   // AutogradOther key. We can add specific autograd key for those backends
   // upon request.
   AutogradOther,

-  AutogradCPU,
-  AutogradCUDA,
-  AutogradXLA,
-  AutogradLazy,
-  AutogradXPU,
-  AutogradMLC,
-  AutogradHPU,
-  AutogradNestedTensor, // lives out of tree at
+  // See [Note: Per-Backend Functionality Dispatch Keys]
+  AutogradFunctionality,
+
+  // NestedTensor is an example of something that isn't a "real backend"
+  // (because it mostly consists of redispatching kernels)
+  // but it would like to override autograd functionality in C++.
+  // We can handle cases like this by adding an extra functionality key
+  // exclusively for handling autograd for NestedTensor.
+  // lives out of tree at
   // https://github.com/pytorch/nestedtensor
-  // Here are some reserved pre-autograd keys for user-defined backends, see
-  // Note [Private use DispatchKey]
-  AutogradPrivateUse1,
-  AutogradPrivateUse2,
-  AutogradPrivateUse3,
+  AutogradNestedTensor,

   Tracer,

@@ -299,9 +371,100 @@ enum class DispatchKey : uint8_t {
   TESTING_ONLY_GenericMode,

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
-  NumDispatchKeys, // Sentinel, end of runtime keys.
+  EndOfFunctionalityKeys, // End of functionality keys.
// ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ //
// Here are backends which you think of as traditionally specifying
// how to implement operations on some device.
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
StartOfDenseBackends,
CPU, // registered at build/aten/src/ATen/RegisterCPU.cpp
CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp
HIP, // NB: I think this is not actually used, due to Note [Masquerading as
// CUDA]
XLA, // lives out of tree at https://github.com/pytorch/xla
MLC, // lives out of tree at https://github.com/pytorch/MLCompute
XPU, // For out of tree Intel's heterogeneous computing plug-in
HPU, // For out of tree & closed source integration of HPU / Habana
VE, // For out of tree & closed source integration of SX-Aurora / NEC
Lazy, // For lazy tensor backends
// Here are reserved backends for user-defined backends, see Note [Private use
// DispatchKey]
// To see some example about how to use this, check out ORT
PrivateUse1,
PrivateUse2,
PrivateUse3,
EndOfDenseBackends = PrivateUse3,
// ~~~~~~~~~~~~~~ "Quantized" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~ //
// keys starting with an _ are not currently used,
// but are needed to ensure that every backend is indexed correctly.
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
StartOfQuantizedBackends,
QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp
QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp
_QuantizedHIP,
_QuantizedXLA,
_QuantizedMLC,
QuantizedXPU, // For out of tree Intel's heterogeneous computing plug-in
_QuantizedHPU,
_QuantizedVE,
_QuantizedLazy,
_QuantizedPrivateUse1,
_QuantizedPrivateUse2,
_QuantizedPrivateUse3,
EndOfQuantizedBackends = _QuantizedPrivateUse3,
// ~~~~~~~~~~~~~~ "Sparse" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~ //
// keys starting with an _ are not currently used,
// but are needed to ensure that every backend is indexed correctly.
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
StartOfSparseBackends,
SparseCPU, // registered at build/aten/src/ATen/RegisterSparseCPU.cpp
SparseCUDA, // registered at build/aten/src/ATen/RegisterSparseCUDA.cpp
SparseHIP, // TODO: I think this is not actually used, due to Note
// [Masquerading as CUDA]
_SparseXLA,
_SparseMLC,
SparseXPU, // For out of tree Intel's heterogeneous computing plug-in
_SparseHPU,
SparseVE, // For out of tree & closed source integration of SX-Aurora / NEC
_SparseLazy,
_SparsePrivateUse1,
_SparsePrivateUse2,
_SparsePrivateUse3,
EndOfSparseBackends = _SparsePrivateUse3,
// ~~~~~~~~~~~~~~ "Autograd" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~ //
// keys starting with an _ are not currently used,
// but are needed to ensure that every backend is indexed correctly.
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
StartOfAutogradBackends,
AutogradCPU,
AutogradCUDA,
_AutogradHIP,
AutogradXLA,
AutogradMLC,
AutogradXPU,
AutogradHPU,
_AutogradVE,
AutogradLazy,
// Here are some reserved pre-autograd keys for user-defined backends, see
// Note [Private use DispatchKey]
AutogradPrivateUse1,
AutogradPrivateUse2,
AutogradPrivateUse3,
EndOfAutogradBackends = AutogradPrivateUse3,
// If we add a new per-backend functionality key that has higher priority
// than Autograd, then this key should be updated.
EndOfRuntimeBackendKeys = EndOfAutogradBackends,
   // ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ //
+  // Note [Alias Dispatch Keys]
   // Alias dispatch keys are synthetic dispatch keys which map to multiple
   // runtime dispatch keys. Alisa keys have precedence, but they are always
   // lower precedence than runtime keys. You can register a kernel to an

@@ -321,6 +484,7 @@ enum class DispatchKey : uint8_t {
   // Define an alias key to represent end of alias dispatch keys.
   // If you add new alias keys after Autograd, please also update it here.
+  StartOfAliasKeys = Autograd,
   EndOfAliasKeys = CompositeExplicitAutograd, //

   // ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
@ -360,54 +524,83 @@ enum class DispatchKey : uint8_t {
// built-in autograd formulas for operators are not appropriate. // built-in autograd formulas for operators are not appropriate.
static_assert( static_assert(
static_cast<uint8_t>(DispatchKey::NumDispatchKeys) < 64, (static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) +
"DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries"); static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys)) <= 64,
"The BackendComponent and DispatchKey enums (below EndOfFunctionalityKeys)"
" both map to backend and functionality bits"
" into a 64-bit bitmask; you must have less than 64 total entries between them");
// Check if a DispatchKey is an alias mapping to other runtime keys.
constexpr bool isAliasDispatchKey(DispatchKey k) {
return k >= DispatchKey::StartOfAliasKeys && k <= DispatchKey::EndOfAliasKeys;
}
// [Note: Per-Backend Functionality Dispatch Keys]
// Check if a DispatchKey is a per-backend functionality key
// Any functionalities that can be customized per-backend should be added here.
// These keys correspond to functionalities that can be customized individually
// per backend. While they only take up one bit in the `DispatchKeySet` bitset,
// they map to (# backends) slots in the operator table.
// Each of these keys also has a separate set of "runtime keys" in the dispatch
// key enum, per backend, which *do* map to the individual operator table slots.
// For example, the "Sparse" key maps to an individual bit in the
// DispatchKeySet, while `SparseCPU`, `SparseCUDA`, etc all map to individual
// slots in the runtime operator table.
constexpr bool isPerBackendFunctionalityKey(DispatchKey k) {
if (k == DispatchKey::Dense || k == DispatchKey::Quantized ||
k == DispatchKey::Sparse || k == DispatchKey::AutogradFunctionality) {
return true;
} else {
return false;
}
}
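// A minimal illustrative sketch, assuming the enum values declared above:
// the per-backend "building block" functionality keys qualify, while their
// runtime per-backend instances and ordinary functionality keys do not.
static_assert(
    isPerBackendFunctionalityKey(DispatchKey::Sparse),
    "Sparse is a per-backend functionality (building block) key");
static_assert(
    !isPerBackendFunctionalityKey(DispatchKey::SparseCPU),
    "SparseCPU is a runtime instance, not a building block key");
static_assert(
    !isPerBackendFunctionalityKey(DispatchKey::BackendSelect),
    "BackendSelect is not customizable per backend");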
// Note that this includes Undefined in the total count.
// BUT EndOfFunctionalityKeys is its own (placeholder) key.
// e.g. Undefined=0, Dense=1, Sparse=2, EndOfFunctionalityKeys=3.
// In the above example, there are 3 total functionality keys.
constexpr uint8_t num_functionality_keys =
static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys);
// Note [No More Than 16 Backends]
// Search for this note to find places in the code where the "no more than 16
// backends" invariant is baked in.
static_assert(
static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) <= 16,
"BackendComponent currently only supports <= 16 backends. If we really need to extend this, \
there are a few places where this invariant is baked in");
constexpr uint8_t numPerBackendFunctionalityKeys() {
uint8_t count = 0;
for (uint8_t k = 0; k <= num_functionality_keys; ++k) {
if (isPerBackendFunctionalityKey(static_cast<DispatchKey>(k)))
++count;
}
return count;
}
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS) #if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
/** // See [Note: Trimmed Mobile Dispatch Keys]
* The method below maps the dispatch key in the enum DispatchKey to an constexpr uint8_t num_backends = 1; // Only CPU
* integer index in the dispatchTable_ array in OperatorEntry. The array constexpr uint16_t num_runtime_entries = 8;
* is trimmed for mobile to reduce peak memory usage since it's
* unnecessary to reserve additional space for dispatch keys that will
* never be used on mobile.
*/
C10_API constexpr int getDispatchTableIndexForDispatchKey(DispatchKey dk) {
switch (dk) {
case DispatchKey::Undefined:
return 0;
case DispatchKey::CPU:
return 1;
case DispatchKey::QuantizedCPU:
return 2;
case DispatchKey::SparseCPU:
return 3;
case DispatchKey::BackendSelect:
return 4;
case DispatchKey::ADInplaceOrView:
return 5;
case DispatchKey::AutogradOther:
return 6;
case DispatchKey::AutogradCPU:
return 7;
case DispatchKey::NumDispatchKeys: // Sentinel, end of runtime keys.
return 8;
default:
return -1;
}
}
#else #else
/** constexpr uint8_t num_backends =
* For the server use-case, make this a simple pass-through. static_cast<uint8_t>(BackendComponent::EndOfBackendKeys);
*/ constexpr uint16_t num_runtime_entries = num_functionality_keys +
C10_API constexpr int getDispatchTableIndexForDispatchKey(DispatchKey dk) { (numPerBackendFunctionalityKeys() * (num_backends - 1));
return static_cast<int>(dk);
}
#endif #endif
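// Worked example with hypothetical numbers: with 4 per-backend functionality
// keys and 13 backends, a flat one-bit-per-runtime-key scheme would spend
// 4 * 13 = 52 bits on those keys alone, while this scheme spends only
// 4 + 13 = 17 bits and pays for the cross product in operator-table slots
// instead: num_runtime_entries grows by 4 * (13 - 1) = 48 extra entries.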
C10_API const char* toString(DispatchKey); // See Note [No More Than 16 Backends]
C10_API std::ostream& operator<<(std::ostream&, DispatchKey); constexpr uint16_t full_backend_mask =
(static_cast<uint16_t>(1) << num_backends) - 1;
C10_API DispatchKey getAutogradKeyFromBackend(DispatchKey t); C10_API const char* toString(DispatchKey);
C10_API const char* toString(BackendComponent);
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
C10_API std::ostream& operator<<(std::ostream&, BackendComponent);
C10_API DispatchKey getAutogradKeyFromBackend(BackendComponent k);
// Parses a string into a dispatch key. // Parses a string into a dispatch key.
// If the string cannot be correctly parsed, throws an exception. // If the string cannot be correctly parsed, throws an exception.
@ -420,10 +613,86 @@ C10_API c10::DispatchKey parseDispatchKey(const std::string& k);
// torch::dispatch(torch::kCPU, ...) is also valid. // torch::dispatch(torch::kCPU, ...) is also valid.
constexpr DispatchKey kAutograd = DispatchKey::Autograd; constexpr DispatchKey kAutograd = DispatchKey::Autograd;
// Check if a DispatchKey is an alias mapping to other runtime keys. // See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
inline bool isAliasDispatchKey(DispatchKey k) { // This function relies on the invariant that the dispatch keys between
return k > DispatchKey::NumDispatchKeys && k <= DispatchKey::EndOfAliasKeys; // StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
// in the same order as `BackendComponent`.
constexpr BackendComponent toBackendComponent(DispatchKey k) {
if (k >= DispatchKey::StartOfDenseBackends &&
k <= DispatchKey::EndOfDenseBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfDenseBackends));
} else if (
k >= DispatchKey::StartOfQuantizedBackends &&
k <= DispatchKey::EndOfQuantizedBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends));
} else if (
k >= DispatchKey::StartOfSparseBackends &&
k <= DispatchKey::EndOfSparseBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfSparseBackends));
} else if (
k >= DispatchKey::StartOfAutogradBackends &&
k <= DispatchKey::EndOfAutogradBackends) {
return static_cast<BackendComponent>(
static_cast<uint8_t>(k) -
static_cast<uint8_t>(DispatchKey::StartOfAutogradBackends));
} else {
return BackendComponent::InvalidBit;
} }
}
constexpr DispatchKey toFunctionalityKey(DispatchKey k) {
if (k <= DispatchKey::EndOfFunctionalityKeys) {
return k;
} else if (k <= DispatchKey::EndOfDenseBackends) {
return DispatchKey::Dense;
} else if (k <= DispatchKey::EndOfQuantizedBackends) {
return DispatchKey::Quantized;
} else if (k <= DispatchKey::EndOfSparseBackends) {
return DispatchKey::Sparse;
} else if (k <= DispatchKey::EndOfAutogradBackends) {
return DispatchKey::AutogradFunctionality;
} else {
return DispatchKey::Undefined;
}
}
// Given (DispatchKey::Dense, DispatchKey::CUDABit), returns DispatchKey::CUDA
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// This function relies on the invariant that the dispatch keys between
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
// in the same order as `BackendComponent`.
constexpr DispatchKey toRuntimePerBackendFunctionalityKey(
DispatchKey functionality_k,
BackendComponent backend_k) {
if (functionality_k == DispatchKey::Dense) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfDenseBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::Sparse) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfSparseBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::Quantized) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends) +
static_cast<uint8_t>(backend_k));
}
if (functionality_k == DispatchKey::AutogradFunctionality) {
return static_cast<DispatchKey>(
static_cast<uint8_t>(DispatchKey::StartOfAutogradBackends) +
static_cast<uint8_t>(backend_k));
}
return DispatchKey::Undefined;
}
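// A small sanity sketch of how these helpers compose, assuming the
// per-backend ordering invariant described above holds:
static_assert(
    toFunctionalityKey(DispatchKey::SparseCUDA) == DispatchKey::Sparse,
    "toFunctionalityKey should strip the backend");
static_assert(
    toBackendComponent(DispatchKey::SparseCUDA) == BackendComponent::CUDABit,
    "toBackendComponent should extract the backend");
static_assert(
    toRuntimePerBackendFunctionalityKey(
        DispatchKey::Sparse, BackendComponent::CUDABit) ==
        DispatchKey::SparseCUDA,
    "functionality + backend should round-trip to the runtime key");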
} // namespace c10 } // namespace c10
namespace torch { namespace torch {
@ -1,37 +1,29 @@
#include <c10/core/DispatchKeySet.h> #include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>
namespace c10 { namespace c10 {
// backend_dispatch_keyset should include all runtime backend keys. // backend_dispatch_keyset includes all dispatch keys that map to backends.
// Alias key DispatchKey::CompositeExplicitAutograd maps to // Alias key DispatchKey::CompositeExplicitAutograd maps to
// backend_dispatch_keyset NestedTensor has been explicitly removed due to // backend_dispatch_keyset
// incompatibility with some kernels, such as structured kernels, that use the constexpr DispatchKeySet backend_dispatch_keyset =
// DefaultBackend key. autogradother_backends | DispatchKeySet(DispatchKey::Dense);
constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
DispatchKeySet({
DispatchKey::CPU,
DispatchKey::CUDA,
DispatchKey::XLA,
DispatchKey::Lazy,
DispatchKey::XPU,
DispatchKey::PrivateUse1,
DispatchKey::PrivateUse2,
DispatchKey::PrivateUse3,
DispatchKey::MLC,
DispatchKey::HPU,
DispatchKey::ORT,
DispatchKey::Meta,
});
bool isBackendDispatchKey(DispatchKey t) { bool isBackendDispatchKey(DispatchKey t) {
return t != DispatchKey::Undefined return t != DispatchKey::Undefined
// See Note [No Alias Keys in DispatchKeySet] // See Note [No Alias Keys in DispatchKeySet]
&& !isAliasDispatchKey(t) && backend_dispatch_keyset.has(t); && !isAliasDispatchKey(t)
// Note [NestedTensor Not Included in Backend Keys]
// NestedTensor has been explicitly removed from the "backend keyset" due
// to incompatibility with some kernels, so we don't want it to be
// included in CompositeImplicitAutograd or CompositeExplicitAutograd
// kernels.
&& t != DispatchKey::NestedTensor && backend_dispatch_keyset.has(t);
} }
// math_dispatch_keyset contains all keys in backend_dispatch_keyset and // math_dispatch_keyset contains all keys in backend_dispatch_keyset and
// autograd_dispatch_keyset Alias key DispatchKey::CompositeImplicitAutograd // autograd_dispatch_keyset Alias key DispatchKey::CompositeImplicitAutograd
// maps to math_dispatch_keyset. // maps to [math_dispatch_keyset x full_backend_mask]
constexpr DispatchKeySet math_dispatch_keyset = constexpr DispatchKeySet math_dispatch_keyset =
backend_dispatch_keyset | autograd_dispatch_keyset; backend_dispatch_keyset | autograd_dispatch_keyset;
@ -39,7 +31,12 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
switch (t) { switch (t) {
case DispatchKey::Autograd: case DispatchKey::Autograd:
return autograd_dispatch_keyset; // See Note [autograd_dispatch_keyset Does Not Include Backend Bits]
// That's why we OR it with a mask of the backend bits here.
// getRuntimeDispatchKeySet() expects to return a keyset of runtime
// dispatch keys, like AutogradCPU, but that requires having backend bits.
return autograd_dispatch_keyset |
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
case DispatchKey::CompositeImplicitAutograd: case DispatchKey::CompositeImplicitAutograd:
return math_dispatch_keyset; return math_dispatch_keyset;
case DispatchKey::CompositeExplicitAutograd: case DispatchKey::CompositeExplicitAutograd:
@ -53,11 +50,13 @@ bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) {
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
switch (t) { switch (t) {
case DispatchKey::Autograd: case DispatchKey::Autograd:
return autograd_dispatch_keyset.has(k); return autograd_dispatch_keyset.has(toFunctionalityKey(k));
case DispatchKey::CompositeImplicitAutograd: case DispatchKey::CompositeImplicitAutograd:
return math_dispatch_keyset.has(k); // See Note [NestedTensor Not Included in Backend Keys]
return k != DispatchKey::NestedTensor && math_dispatch_keyset.has(k);
case DispatchKey::CompositeExplicitAutograd: case DispatchKey::CompositeExplicitAutograd:
return backend_dispatch_keyset.has(k); // See Note [NestedTensor Not Included in Backend Keys]
return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k);
default: default:
return t == k; return t == k;
} }
@ -79,8 +78,6 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
return DispatchKeySet(DispatchKey::MLC); return DispatchKeySet(DispatchKey::MLC);
case DispatchKey::AutogradHPU: case DispatchKey::AutogradHPU:
return DispatchKeySet(DispatchKey::HPU); return DispatchKeySet(DispatchKey::HPU);
case DispatchKey::AutogradNestedTensor:
return DispatchKeySet(DispatchKey::NestedTensor);
case DispatchKey::AutogradXPU: case DispatchKey::AutogradXPU:
return DispatchKeySet(DispatchKey::XPU); return DispatchKeySet(DispatchKey::XPU);
case DispatchKey::AutogradPrivateUse1: case DispatchKey::AutogradPrivateUse1:
@ -96,23 +93,6 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
} }
} }
DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
switch (t) {
case DispatchKey::CPU:
return DispatchKeySet(DispatchKey::AutocastCPU);
case DispatchKey::CUDA:
case DispatchKey::XLA:
return DispatchKeySet(DispatchKey::AutocastCUDA);
default:
return DispatchKeySet();
}
}
DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) {
return DispatchKeySet(
{DispatchKey::ADInplaceOrView, getAutogradKeyFromBackend(t)});
}
bool isIncludedInAlias(DispatchKey k, DispatchKey alias) { bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
return k != DispatchKey::Undefined && runtimeDispatchKeySetHas(alias, k); return k != DispatchKey::Undefined && runtimeDispatchKeySetHas(alias, k);
} }
@ -129,18 +109,167 @@ std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
return os; return os;
} }
os << "DispatchKeySet("; os << "DispatchKeySet(";
DispatchKey tid;
bool first = true; bool first = true;
while ((tid = ts.highestPriorityTypeId()) != DispatchKey::Undefined) { for (auto k : ts) {
if (!first) { if (!first) {
os << ", "; os << ", ";
} }
os << tid; os << k;
ts = ts.remove(tid);
first = false; first = false;
} }
os << ")"; os << ")";
return os; return os;
} }
DispatchKeySet::iterator& DispatchKeySet::iterator::operator++() {
TORCH_INTERNAL_ASSERT(next_functionality_ >= num_backends);
TORCH_INTERNAL_ASSERT(next_functionality_ <= iterator::end_iter_mask_val);
TORCH_INTERNAL_ASSERT(next_backend_ <= num_backends);
// Create a masked version of the set representation to ignore previous
// keys that we've iterated through.
uint64_t masked_functionality_bits =
llvm::maskTrailingZeros<uint64_t>(next_functionality_) & *data_ptr_;
uint64_t masked_backend_bits =
llvm::maskTrailingZeros<uint64_t>(next_backend_) & full_backend_mask &
*data_ptr_;
uint64_t first_functionality_idx =
llvm::findFirstSet(masked_functionality_bits);
uint64_t first_backendcomponent_idx = llvm::findFirstSet(masked_backend_bits);
// If there are no keys, set to end iterator value
if (first_functionality_idx == std::numeric_limits<uint64_t>::max() ||
next_functionality_ == iterator::end_iter_mask_val) {
// Set up state to be the same as end()
next_functionality_ = iterator::end_iter_mask_val;
current_dispatchkey_idx_ = iterator::end_iter_key_val;
next_backend_ = 0;
current_backendcomponent_idx_ = iterator::end_iter_key_val;
return *this;
}
// The +1 is because of DispatchKey::Undefined and
// BackendComponent::InvalidBit
auto new_next_functionality = first_functionality_idx + 1;
auto new_backendcomponent_idx = first_backendcomponent_idx + 1;
// and the -num_backends is because the first <num_backends> bits in the
// keyset are not Dispatch Keys.
auto next_dispatchkey_idx = new_next_functionality - num_backends;
// If the current functionality bit is a per-backend bit, we need special
// handling
if (isPerBackendFunctionalityKey(
static_cast<DispatchKey>(next_dispatchkey_idx))) {
// case 1: if the current backend is undefined, then there is no valid
// backend instance of this functionality key so we can skip it.
if (first_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
// increment the functionality mask so we skip the current functionality
// bit on the next increment.
next_functionality_ = new_next_functionality;
++(*this);
return *this;
}
// Otherwise, at this point we know what the current backend and
// functionality bits are.
current_dispatchkey_idx_ = next_dispatchkey_idx;
current_backendcomponent_idx_ = new_backendcomponent_idx;
// Next, we need to set up the masks for the next increment.
uint64_t next_backendcomponent_bits =
llvm::maskTrailingZeros<uint64_t>(first_backendcomponent_idx + 1) &
full_backend_mask & *data_ptr_;
uint64_t next_backendcomponent_idx =
llvm::findFirstSet(next_backendcomponent_bits);
if (next_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
// case 2: the current backend is valid, but there is not another backend
// in the keyset. In this case, we need to bump the functionality mask and
// reset the backend mask for the next increment
next_functionality_ = new_next_functionality;
next_backend_ = 0;
} else {
// case 3: we have another backend to iterate over. We want to iterate
// over the same functionality bit next time, but a different backend bit.
next_backend_ = first_backendcomponent_idx + 1;
}
} else {
// Functionality bits that aren't per backend are simpler to handle. We can
// ignore the backend bits.
TORCH_INTERNAL_ASSERT(next_backend_ == 0);
current_dispatchkey_idx_ = next_dispatchkey_idx;
next_functionality_ = new_next_functionality;
}
return *this;
}
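// A usage sketch of the iteration semantics (illustrative; assumes the
// standard CPU and CUDA backend bits):
//   DispatchKeySet ks({DispatchKey::CPU, DispatchKey::AutogradCUDA});
//   std::vector<DispatchKey> out;
//   for (DispatchKey k : ks) {
//     out.push_back(k);
//   }
//   // out == {CPU, CUDA, AutogradCPU, AutogradCUDA}: the two backend bits
//   // are expanded for each per-backend functionality (Dense first, then
//   // AutogradFunctionality), so CUDA and AutogradCPU appear even though
//   // they were never added explicitly.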
std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks() {
std::array<FunctionalityOffsetAndMask, num_functionality_keys>
offsets_and_masks;
// manually set the first entry, which corresponds to Undefined.
offsets_and_masks[0] = FunctionalityOffsetAndMask(0, 0);
// loop through every functionality key (aside from Undefined).
for (const auto functionality_idx : c10::irange(1, num_functionality_keys)) {
// functionality_idx should be Dense -> 1, ...
auto prev_offset_and_mask = offsets_and_masks[functionality_idx - 1];
auto k = static_cast<DispatchKey>(functionality_idx);
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
// [Note: Trimmed Mobile Dispatch Keys]
uint16_t mask = 0;
uint16_t offset = 0;
    switch (k) {
      case DispatchKey::Undefined:
        offset = 0;
        break;
      case DispatchKey::CPU:
        offset = 1;
        break;
      case DispatchKey::QuantizedCPU:
        offset = 2;
        break;
      case DispatchKey::SparseCPU:
        offset = 3;
        break;
      case DispatchKey::BackendSelect:
        offset = 4;
        break;
      case DispatchKey::ADInplaceOrView:
        offset = 5;
        break;
      case DispatchKey::AutogradOther:
        offset = 6;
        break;
      case DispatchKey::AutogradCPU:
        offset = 7;
        break;
      default:
        // All other keys which are unsupported on mobile will get sent
        // to the undefined kernel, causing them to error.
        offset = 0;
    }
offsets_and_masks[functionality_idx] =
FunctionalityOffsetAndMask(offset, 0);
}
#else
// If the previous functionality was not per-backend, then we can just
// increment the previous offset. Otherwise, the next offset =
// previous_offset + num_backends.
auto next_offset = prev_offset_and_mask.offset +
(prev_offset_and_mask.mask == 0 ? 1 : num_backends);
// the mask is used in the runtime index calculation to find the offset of
// the backend. For non-per-backend functionalities, this offset should
// always be 0. Otherwise, we need to get the index of the backend (which we
// can do using a backend mask).
auto next_mask = isPerBackendFunctionalityKey(k) ? full_backend_mask : 0;
offsets_and_masks[functionality_idx] =
FunctionalityOffsetAndMask(next_offset, next_mask);
}
// Sanity check that the computed offset index of the last functionality key
// is correct. This assumes that the highest priority functionality key is not
// per backend.
TORCH_INTERNAL_ASSERT(
offsets_and_masks[num_functionality_keys - 1].offset ==
(num_runtime_entries - 1),
"num_runtime_entries: ",
num_runtime_entries,
"last_offset: ",
offsets_and_masks[num_functionality_keys - 1].offset);
#endif
return offsets_and_masks;
}
} // namespace c10 } // namespace c10
@ -1,5 +1,4 @@
#pragma once #pragma once
#include <c10/core/DispatchKey.h> #include <c10/core/DispatchKey.h>
#include <c10/util/Exception.h> #include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h> #include <c10/util/Metaprogramming.h>
@ -8,29 +7,147 @@
namespace c10 { namespace c10 {
struct FunctionalityOffsetAndMask {
// empty constructor shouldn't be used; only needed to initialize
// the array before populating it.
FunctionalityOffsetAndMask() {}
FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask)
: offset(offset), mask(mask) {}
// This needs to be big enough to cover the size of the operator table.
uint16_t offset;
// See Note [No More Than 16 Backends]
// This mask needs to be big enough to mask all of the backend bits.
// We probably don't ever want to have more than 16 backend bits, so uint16_t
// should be enough.
uint16_t mask;
};
static_assert(
c10::num_runtime_entries < 65536,
"The dispatcher currently only supports up to 2^16 runtime entries");
C10_API std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks();
C10_ALWAYS_INLINE static const std::
array<FunctionalityOffsetAndMask, num_functionality_keys>&
offsetsAndMasks() {
static auto offsets_and_masks_ = initializeFunctionalityOffsetsAndMasks();
return offsets_and_masks_;
}
// A representation of a set of DispatchKeys. A DispatchKeySet contains both
// "functionality" bits and "backend bits", and every tensor holds its own
// DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the
// keyset on every input tensor, oring them together, and dispatching to a
// specific piece of functionality. The functionality bits are *ordered*. When
// multiple functionality bits are set, we use the highest priority
// functionality. Similarly, multiple backend bits can theoretically be set if
// you call an operator with multiple tensors from different devices (e.g. CPU
// and CUDA), although support for mixed device dispatch is limited (the only
// kernels that gracefully handle mixed device inputs for now are cuda kernels
// that take in a scalar cpu tensor).
// A representation of a set of DispatchKeys. A tensor may have multiple // A representation of a set of DispatchKeys. A tensor may have multiple
// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the // tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
// DispatchKeySet specifies what type ids apply. The internal representation is // DispatchKeySet specifies what type ids apply. The internal representation is
// as a 64-bit bit set (this means only 64 tensor type ids are supported). // as a 64-bit bit set (this means only 64 tensor type ids are supported).
// //
// Note that DispatchKeys are ordered; thus, we can ask questions like "what is // As mentioned above, DispatchKeys are ordered; thus, we can ask questions like
// the highest priority DispatchKey in the set"? (The set itself is not // "what is the highest priority DispatchKey in the set"? (The set itself is
// ordered; two sets with the same ids will always have the ids ordered in the // not ordered; two sets with the same ids will always have the ids ordered in
// same way.) // the same way.)
// //
// At the moment, there are no nontrivial uses of this set; tensors are always // Note [DispatchKeySet Internal Representation]
// singletons. In the near future, this set will represent variable? + tensor // Internally, dispatch keys are packed into 64-bit DispatchKeySet objects
// type id. In the far future, it will be requires grad? + profiling? + // that get passed around at runtime.
// tracing? + lazy? + tensor type id. // However, there isn't necessarily a 1-to-1 mapping between bits in the keyset
// and individual dispatch keys.
// //
// (The difference between variable and requires grad, is that // First: why do we have this distinction, and why not map every dispatch key
// there are currently three states a tensor can be: // directly to a bit? This is mostly because we have several types of
// 1. Not a variable // functionalities that different backends would like to customize. For example,
// 2. Variable with requires_grad=False // we have:
// 3. Variable with requires_grad=True // - "Dense": CPU, CUDA, XLA, ... (~12 keys)
// Eventually, we want to kill state (1), and only dispatch to autograd // - "Sparse": SparseCPU, SparseCUDA, ...
// handling code if one of the inputs requires grad.) // - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
// - "Autograd": AutogradCPU, AutogradCUDA, Autograd XLA, ...
// The problem is that total number of keys grows quadratically with [#
// backends] x [# functionalities], making it very difficult to map each key
// directly to a bit in a bitset without dramatically increasing the size of the
// bitset over time.
// //
// The two enums (BackendComponent and DispatchKey) can be divided roughly into
// 5 categories.
//
// (1) "Building block" keys
// (a) backends: Everything in the BackendComponent enum (e.g. CPUBit,
// CUDABit)
// (b) functionalities: (per-backend) functionality-bit DispatchKeys
// (e.g. AutogradFunctionality, Sparse, Dense)
// (2) "Runtime" keys
// (a) "non-customizable backends" (e.g. FPGA)
// (b) "non-customizable functionalities" (e.g. Functionalize)
// (c) "per-backend instances of customizable functionalities" (e.g. CPU,
// SparseCPU, AutogradCPU)
// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys])
//
// (1) Building block keys always correspond to individual bits in a
// DispatchKeySet. They can also be combined in a DispatchKeySet to form actual
// runtime keys. e.g.
//     auto dense_cpu_ks = DispatchKeySet(DispatchKey::Dense) |
//         DispatchKeySet(BackendComponent::CPUBit);
//     // The keyset has the runtime dense-cpu key.
//     dense_cpu_ks.has(DispatchKey::CPU);
//     // And it contains the building block keys too.
//     dense_cpu_ks.has_backend(BackendComponent::CPUBit);
// dense_cpu_ks.has(DispatchKey::Dense);
//
// Not every backend and not every functionality counts as a "building block
// key". This is mostly to give us more levers to pull in the design space.
// Backend keys and functionality keys that count as "building blocks" will
// contribute to a full cross product of functionality that can be overridden.
//
// For example, right now we have at least 12 "backend" building blocks (CPU,
// CUDA, XLA, ...) and at least 4 "functionality" building blocks (Dense,
// Sparse, Quantized, AutogradFunctionality, ...). These keys together allow
// every dispatcher operator to be customized in up to 12*4 different ways. Each
// of those requires a slot in the operator table of every dispatcher operator.
// Not every piece of functionality necessarily needs to be customizable
// per-backend, and not every backend necessarily needs to be able to customize
// every type of functionality.
//
//
// (2) Every runtime key corresponds directly to a slot in an operator's runtime
// dispatch table, and you can directly register kernels to a runtime dispatch
// key.
//
// For per-backend functionalities like "Dense" or "AutogradFunctionality",
// you can think of the corresponding runtime dispatch keys as "instances" of
// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all
// runtime instances of the "Dense" building block key.
// (2a) and (2b) are represented identically in the DispatchKeySet logic:
// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
// customizable per backend.
// In order to do so, we'd need to promote it to a per-backend functionality
// "building block" key.
// - non-customizable backends (e.g. FPGA) can NOT customize existing
// functionality like Sparse, Autograd, etc.
// In order to do so, we'd need to promote it to a backend "building block"
// key.
//
// In both cases, these keys directly correspond to runtime slots in the
// operator table.
//
//
// (3) "Alias" keys
// See Note [Alias Dispatch Keys]
//
// Final note: for anyone making future changes to the Dispatcher +
// DispatchKeySet internals, there's a closed PR with a basic
// python-implementation of the Dispatcher that might be useful in quickly
// testing out and validating changes. See it at
// https://github.com/pytorch/pytorch/pull/68743
// An undefined tensor is one with an empty tensor type set. // An undefined tensor is one with an empty tensor type set.
class DispatchKeySet final { class DispatchKeySet final {
public: public:
@ -41,29 +158,146 @@ class DispatchKeySet final {
// NB: default constructor representation as zero is MANDATORY as // NB: default constructor representation as zero is MANDATORY as
// use of DispatchKeySet in TLS requires this. // use of DispatchKeySet in TLS requires this.
constexpr DispatchKeySet() : repr_(0) {} constexpr DispatchKeySet() : repr_(0) {}
constexpr DispatchKeySet(Full) constexpr DispatchKeySet(Full)
: repr_(std::numeric_limits<decltype(repr_)>::max()) {} : repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {}
constexpr DispatchKeySet(FullAfter, DispatchKey t) constexpr DispatchKeySet(FullAfter, DispatchKey t)
// LSB after t are OK, but not t itself. // LSB after t are OK, but not t itself.
: repr_((1ULL << (static_cast<uint8_t>(t) - 1)) - 1) {} // "functionalities" have a notion of ordering (e.g. Autograd > Sparse >
// Quantized > Dense). But backends don't really have an ordering.
// Therefore, we're enforcing that FullAfter can only be used on
// "functionality" keys.
: repr_(
(1ULL
<< (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
1)) -
1) {}
// Public version of DispatchKeySet(uint64_t) API; external users // Public version of DispatchKeySet(uint64_t) API; external users
// must be explicit when they do this! // must be explicit when they do this!
constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {} constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {}
explicit constexpr DispatchKeySet(DispatchKey t)
: repr_( constexpr explicit DispatchKeySet(BackendComponent k) {
t == DispatchKey::Undefined if (k == BackendComponent::InvalidBit) {
repr_ = 0;
} else {
repr_ = 1ULL << (static_cast<uint8_t>(k) - 1);
}
}
constexpr explicit DispatchKeySet(DispatchKey k) {
if (k == DispatchKey::Undefined) {
// Case 1: handle Undefined specifically
repr_ = 0;
} else if (k <= DispatchKey::EndOfFunctionalityKeys) {
// Case 2: handle "functionality-only" keys
// These keys have a functionality bit set, but no backend bits
// These can technically be either:
// - valid runtime keys (e.g. DispatchKey::AutogradOther,
// DispatchKey::FuncTorchBatched, etc)
// - "building block" keys that aren't actual runtime keys (e.g.
// DispatchKey::Dense or Sparse)
uint64_t functionality_val = 1ULL
<< (num_backends + static_cast<uint8_t>(k) - 1);
repr_ = functionality_val;
} else if (k <= DispatchKey::EndOfRuntimeBackendKeys) {
// Case 3: "runtime" keys that have a functionality bit AND a backend bit.
// First compute which bit to flip for the functionality.
auto functionality_k = toFunctionalityKey(k);
// The - 1 is because Undefined is technically a "functionality" that
// doesn't show up in the bitset. So e.g. Dense is technically the second
// functionality, but the lowest functionality bit.
uint64_t functionality_val = 1ULL
<< (num_backends + static_cast<uint8_t>(functionality_k) - 1);
// then compute which bit to flip for the backend
// Case 4a: handle the runtime instances of "per-backend functionality"
// keys For example, given DispatchKey::CPU, we should set:
// - the Dense functionality bit
// - the CPUBit backend bit
// first compute which bit to flip for the backend
auto backend_k = toBackendComponent(k);
uint64_t backend_val = backend_k == BackendComponent::InvalidBit
? 0 ? 0
: 1ULL << (static_cast<uint8_t>(t) - 1)) {} : 1ULL << (static_cast<uint8_t>(backend_k) - 1);
explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks) repr_ = functionality_val + backend_val;
: repr_(0) { } else {
// At this point, we should have covered every case except for alias keys.
// Technically it would be possible to add alias dispatch keys to a
// DispatchKeySet, but the semantics are a little confusing and this
// currently isn't needed anywhere.
repr_ = 0;
}
}
constexpr uint64_t keys_to_repr(std::initializer_list<DispatchKey> ks) {
uint64_t repr = 0;
for (auto k : ks) { for (auto k : ks) {
repr_ |= DispatchKeySet(k).repr_; repr |= DispatchKeySet(k).repr_;
} }
return repr;
} }
constexpr uint64_t backend_bits_to_repr(
std::initializer_list<BackendComponent> ks) {
uint64_t repr = 0;
for (auto k : ks) {
repr |= DispatchKeySet(k).repr_;
}
return repr;
}
explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
: repr_(keys_to_repr(ks)) {}
explicit constexpr DispatchKeySet(std::initializer_list<BackendComponent> ks)
// Note: for some reason, putting this logic directly in the constructor
// appears to fail to compile on CUDA 10.1.
// See an example internal failure at
// https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr
: repr_(backend_bits_to_repr(ks)) {}
// Test if a DispatchKey is in the set // Test if a DispatchKey is in the set
bool inline has(DispatchKey t) const { inline bool has(DispatchKey t) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined);
return static_cast<bool>(repr_ & DispatchKeySet(t).repr_); return has_all(DispatchKeySet(t));
}
constexpr bool has_backend(BackendComponent t) const {
return has_all(DispatchKeySet(t));
}
// Test if a DispatchKey is in the set
// Given a DispatchKeySet of functionality keys and (potentially) backend
// keys, tests if all of them are in the current set.
constexpr bool has_all(DispatchKeySet ks) const {
return static_cast<bool>((repr_ & ks.repr_) == ks.repr_);
}
// Given a DispatchKeySet of functionality keys and (potentially) backend
// keys, tests if any of them are in the current set. This could technically
// be pretty easily implemented using has(). It is strictly a perf
// optimization though. There are many places in the code base where we want
// to test for multiple functionality keys together. HOWEVER, runtime
// per-backend functionality keys aren't allowed to be used with this
// function, because you can end up with weird results. e.g.
// DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU))
// would return true.
inline bool has_any(DispatchKeySet ks) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
// Either there are no backend bits in the input keyset
((ks.repr_ & full_backend_mask) == 0) ||
// or there are no per-backend-functionality bits
// See [Note: Per-Backend Functionality Dispatch Keys]
((ks &
DispatchKeySet({
DispatchKey::Dense,
DispatchKey::Quantized,
DispatchKey::Sparse,
DispatchKey::AutogradFunctionality,
})
.repr_) == 0));
return static_cast<bool>((repr_ & ks.repr_) != 0);
} }
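  // e.g. (illustrative, assuming the enums above): a keyset built from the
  // runtime key CUDA already contains both of its building blocks:
  //   constexpr auto cuda_ks = DispatchKeySet(DispatchKey::CUDA);
  //   static_assert(cuda_ks.has_all(DispatchKeySet(DispatchKey::Dense)), "");
  //   static_assert(cuda_ks.has_backend(BackendComponent::CUDABit), "");
  //   static_assert(!cuda_ks.has_all(DispatchKeySet(DispatchKey::Sparse)), "");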
// Test if DispatchKeySet is a superset of ks. // Test if DispatchKeySet is a superset of ks.
bool isSupersetOf(DispatchKeySet ks) const { bool isSupersetOf(DispatchKeySet ks) const {
@ -74,31 +308,64 @@ class DispatchKeySet final {
return DispatchKeySet(repr_ | other.repr_); return DispatchKeySet(repr_ | other.repr_);
} }
// Perform set intersection // Perform set intersection
DispatchKeySet operator&(DispatchKeySet other) const { constexpr DispatchKeySet operator&(DispatchKeySet other) const {
return DispatchKeySet(repr_ & other.repr_); return DispatchKeySet(repr_ & other.repr_);
} }
// Compute the set difference self - other // Compute the set difference self - other,
// but ONLY for the functionality keys.
// Any backend bits set on self will remain unchanged.
// See Note [Removing keys from DispatchKeySet Only Affects Functionality
// Keys]
DispatchKeySet operator-(DispatchKeySet other) const { DispatchKeySet operator-(DispatchKeySet other) const {
return DispatchKeySet(repr_ & ~other.repr_); return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_));
} }
// Compute self ^ other // Compute self ^ other
constexpr DispatchKeySet operator^(DispatchKeySet other) const { constexpr DispatchKeySet operator^(DispatchKeySet other) const {
return DispatchKeySet(repr_ ^ other.repr_); return DispatchKeySet(repr_ ^ other.repr_);
} }
// Perform set equality
bool operator==(DispatchKeySet other) const { bool operator==(DispatchKeySet other) const {
return repr_ == other.repr_; return repr_ == other.repr_;
} }
bool operator!=(DispatchKeySet other) const {
return repr_ != other.repr_;
}
// Add a DispatchKey to the DispatchKey set. Does NOT mutate, // Add a DispatchKey to the DispatchKey set. Does NOT mutate,
// returns the extended DispatchKeySet! // returns the extended DispatchKeySet!
C10_NODISCARD DispatchKeySet add(DispatchKey t) const { C10_NODISCARD DispatchKeySet add(DispatchKey t) const {
return *this | DispatchKeySet(t); return *this | DispatchKeySet(t);
} }
// Remove a DispatchKey from the DispatchKey set. This is C10_NODISCARD DispatchKeySet add(DispatchKeySet ks) const {
// generally not an operation you should be doing (it's return *this | ks;
// used to implement operator<<) }
C10_NODISCARD constexpr DispatchKeySet remove(DispatchKey t) const {
return DispatchKeySet(repr_ & ~DispatchKeySet(t).repr_); // Remove a DispatchKey from the DispatchKey set.
// This is generally not an operation you should be doing
// (it's used to implement the printing overload, operator<<)
//
// Note [Removing keys from DispatchKeySet Only Affects Functionality Keys]
// Only functionality bits are allowed to be removed from a keyset.
// For now, we're only allowing removal of "functionality bits" from the
// keyset, which is specifically needed by the fallthrough key calculation
// logic. Why is removing backend bits problematic? Consider this example:
//
// DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA,
// DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA)
// DispatchKeySet([DispatchKey.CPU,
// DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA)
//
// What do we want to happen?
// Technically, we'd like it to be true that after removal,
// the first keyset still has the CUDA dispatch key while the second doesn't.
// Unfortunately there's no way to represent that, because the two keysets are
// represented the same way internally: functionality bits: Autograd, Dense
// backend bits: CPU, CUDA
//
// Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd"
// bit from the bitset.
constexpr DispatchKeySet remove(DispatchKey t) const {
return DispatchKeySet(
repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask));
} }
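  // An illustrative sketch of the semantics above (assuming the CPU and CUDA
  // bits):
  //   auto ks = DispatchKeySet({DispatchKey::CPU, DispatchKey::AutogradCUDA});
  //   auto after = ks.remove(DispatchKey::AutogradCUDA);
  //   after.has(DispatchKey::CPU);          // true
  //   after.has(DispatchKey::CUDA);         // still true: backend bits survive
  //   after.has(DispatchKey::AutogradCPU);  // false: the Autograd bit is gone
  //   after.has(DispatchKey::AutogradCUDA); // false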
// Is the set empty? (AKA undefined tensor) // Is the set empty? (AKA undefined tensor)
bool empty() const { bool empty() const {
@ -107,22 +374,78 @@ class DispatchKeySet final {
uint64_t raw_repr() { uint64_t raw_repr() {
return repr_; return repr_;
} }
// Return the type id in this set with the highest priority (i.e.,
// is the largest in the DispatchKey enum). Intuitively, this DispatchKey highestFunctionalityKey() const {
// type id is the one that should handle dispatch (assuming there auto functionality_idx = indexOfHighestBit();
// aren't any further exclusions or inclusions). // This means that none of the functionality bits were set.
DispatchKey highestPriorityTypeId() const { if (functionality_idx < num_backends)
// TODO: If I put Undefined as entry 64 and then adjust the return DispatchKey::Undefined;
// singleton constructor to shift from the right, we can get rid of the // The first num_backend bits in the keyset don't correspond to real
// subtraction here. It's modestly more complicated to get right so I // dispatch keys.
// didn't do it for now. return static_cast<DispatchKey>(functionality_idx - num_backends);
return static_cast<DispatchKey>(64 - llvm::countLeadingZeros(repr_));
} }
DispatchKey highestPriorityBackendTypeId() const { // This is similar to toBackendComponent(DispatchKey), but less restrictive.
return (*this & // toBackendComponent() errors out if the key that it was passed has no
((1ULL << static_cast<uint8_t>(DispatchKey::EndOfBackendKeys)) - 1)) // backend bits, which is useful for error checking. We need a version of that
.highestPriorityTypeId(); // here that can also handle "fake" backends like FPGA, because they need to
// map to the AutogradOther key. For those backends, we return
// BackendComponent::InvalidBit.
BackendComponent highestBackendKey() const {
// mask to mask out functionality bits
auto backend_idx =
DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit();
// all zeros across the backend bits means that no backend bits are set.
if (backend_idx == 0)
return BackendComponent::InvalidBit;
return static_cast<BackendComponent>(backend_idx);
}
// returns the DispatchKey of highest priority in the set.
DispatchKey highestPriorityTypeId() const {
auto functionality_k = highestFunctionalityKey();
if (isPerBackendFunctionalityKey(functionality_k)) {
return toRuntimePerBackendFunctionalityKey(
functionality_k, highestBackendKey());
}
return functionality_k;
}
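  // e.g. (illustrative, assuming the CPU and CUDA bits above):
  //   DispatchKeySet({DispatchKey::CPU, DispatchKey::AutogradCUDA})
  //       .highestPriorityTypeId() == DispatchKey::AutogradCUDA
  // because AutogradFunctionality is the highest functionality bit in the set
  // and CUDA is the highest backend bit.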
// Returns the index of the most-significant bit in the keyset.
// This is used as part of the calculation of the operator table index to get:
// - the highest "functionality" bit in the keyset.
// - the highest "backend" bit in the keyset.
uint8_t indexOfHighestBit() const {
return 64 - llvm::countLeadingZeros(repr_);
}
// returns the index in the operator table of the highest priority key in the
// keyset. Note that we could in theory implement this using
// highestPriorityTypeId(), but this code is very hotpath and we can do it
// faster without it.
uint64_t getDispatchTableIndexForDispatchKeySet() const {
auto functionality_idx =
DispatchKeySet(repr_ >> num_backends).indexOfHighestBit();
auto offset_and_mask = offsetsAndMasks()[functionality_idx];
// Mask the functionality bits out first, then right-shift by 1.
// right-shifting by 1 because everything is zero-indexed.
// E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should
// give us an offset of 1, etc.
auto backend_idx =
DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit();
return offset_and_mask.offset + backend_idx;
}
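  // Worked example (with hypothetical offsets): if Dense has offset 0 and a
  // mask covering the backend bits, a keyset whose highest key is CPU (the
  // lowest backend bit) lands in slot 0 + 0, and one whose highest key is
  // CUDA lands in slot 0 + 1. For a non-per-backend functionality the stored
  // mask is 0, so the backend term is always 0 and the slot is just that
  // functionality's offset.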
// returns the "index" of the highest priority backend in the keyset.
// This is pretty similar to highestBackendKey(), but:
// - It's hotpath code (part of the runtime bitset calculation)
// - It returns an integer index, not an enum value
// - Everything is shifted to the right by 1.
// BackendComponent::InvalidBit is technically the lowest enum value,
// but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2,
// etc.
uint64_t getBackendIndex() const {
return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit();
} }
private: private:
@ -130,42 +453,47 @@ class DispatchKeySet final {
uint64_t repr_ = 0; uint64_t repr_ = 0;
public: public:
// STL iterator for DispatchKeySet. Iterates through all DispatchKeys in the // STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys
// set. The iterator is only invalidated by the destruction of the underlying // in the set. The iterator is only invalidated by the destruction of the
// DispatchKeySet as the iterator stores a pointer to the raw representation // underlying DispatchKeySet as the iterator stores a pointer to the raw
// of the DispatchKeySet. // representation of the DispatchKeySet. Note: When we encounter a per-backend
// functionality (e.g. Dense or Sparse), we will iterate through EVERY backend
// in the keyset, for that functionality. For example, if the next
// functionality key to iterate over is Autograd, and the backend bits in the
// keyset correspond to [BackendComponent::CPUBit, BackendComponent::CUDABit],
// then the next two keys we return will be DispatchKey::AutogradCPU,
// DispatchKey::AutogradCUDA (CPU first because it has lower precedence than
// CUDA in DispatchKey.h).
class iterator { class iterator {
public: public:
using self_type = iterator; using self_type = iterator;
using iterator_category = std::input_iterator_tag; using iterator_category = std::input_iterator_tag;
using value_type = DispatchKey; using value_type = DispatchKey;
using difference_type = ptrdiff_t; using difference_type = ptrdiff_t;
// final mask value should mask out the entire keyset
static const uint8_t end_iter_mask_val =
num_backends + num_functionality_keys;
// final key value should be the last DispatchKey
static const uint8_t end_iter_key_val = num_functionality_keys;
explicit iterator(const uint64_t* data_ptr, uint8_t i = 0) // current_dispatchkey_idx_ will iterate through all functionality bits.
: data_ptr_(data_ptr), i_(i) { // current_backendcomponent_idx_ will iterate through all backend bits.
explicit iterator(
const uint64_t* data_ptr,
uint8_t next_functionality = num_backends,
uint8_t next_backend = 0)
: data_ptr_(data_ptr),
next_functionality_(next_functionality),
next_backend_(next_backend),
// These are in an invalid state at construction time, and set by the
// first increment call
current_dispatchkey_idx_(end_iter_key_val),
current_backendcomponent_idx_(end_iter_key_val) {
// Go to the first key in the set // Go to the first key in the set
++(*this); ++(*this);
} }
self_type& operator++() { C10_API self_type& operator++();
TORCH_INTERNAL_ASSERT(
i_ <= static_cast<uint8_t>(DispatchKey::NumDispatchKeys));
// Create a masked version of the set representation to ignore previous
// keys that we've iterated through.
uint64_t masked_data = llvm::maskTrailingZeros<uint64_t>(i_) & *data_ptr_;
uint64_t firstKeyIndex = llvm::findFirstSet(masked_data);
// If there are no keys, set to end iterator value
if (firstKeyIndex == std::numeric_limits<uint64_t>::max() ||
i_ == static_cast<uint8_t>(DispatchKey::NumDispatchKeys)) {
i_ = static_cast<uint8_t>(DispatchKey::NumDispatchKeys);
return *this;
}
i_ = static_cast<uint8_t>(firstKeyIndex) + 1;
return *this;
}
self_type operator++(int) { self_type operator++(int) {
self_type previous_iterator = *this; self_type previous_iterator = *this;
@ -174,18 +502,50 @@ class DispatchKeySet final {
} }
bool operator==(const self_type& rhs) const { bool operator==(const self_type& rhs) const {
return i_ == rhs.i_; return next_functionality_ == rhs.next_functionality_ &&
current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ &&
next_backend_ == rhs.next_backend_ &&
current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_;
} }
bool operator!=(const self_type& rhs) const { bool operator!=(const self_type& rhs) const {
return i_ != rhs.i_; return next_functionality_ != rhs.next_functionality_ ||
current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ ||
next_backend_ != rhs.next_backend_ ||
current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_;
} }
DispatchKey operator*() const { DispatchKey operator*() const {
return static_cast<DispatchKey>(i_); auto functionality_key =
static_cast<DispatchKey>(current_dispatchkey_idx_);
if (isPerBackendFunctionalityKey(functionality_key)) {
auto next_key = toRuntimePerBackendFunctionalityKey(
functionality_key,
static_cast<BackendComponent>(current_backendcomponent_idx_));
// We expect all of the Dense, Sparse, Quantized, and Autograd keys to
// be ordered the same way with respect to their backends
TORCH_INTERNAL_ASSERT(
toBackendComponent(next_key) ==
static_cast<BackendComponent>(current_backendcomponent_idx_),
"Tried to map functionality key ",
toString(functionality_key),
" and backend bit ",
toString(
static_cast<BackendComponent>(current_backendcomponent_idx_)),
" to a runtime key, but ended up with ",
toString(next_key),
". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.",
" Please double check that enum for inconsistencies.");
return next_key;
} else {
return functionality_key;
}
} }
private: private:
const uint64_t* data_ptr_; const uint64_t* data_ptr_;
uint8_t i_; uint8_t next_functionality_;
uint8_t next_backend_;
uint8_t current_dispatchkey_idx_;
uint8_t current_backendcomponent_idx_;
}; };
public: public:
@ -195,31 +555,35 @@ class DispatchKeySet final {
return iterator(&repr_); return iterator(&repr_);
} }
// We do not need to iterate beyond NumDispatchKeys so we will treat this as // We do not need to iterate beyond EndOfFunctionalityKeys so we will treat
// the end iterator. NumDispatchKeys will always be strictly less than 64. // this as the end iterator.
iterator end() const { iterator end() const {
return iterator(&repr_, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)); return iterator(&repr_, iterator::end_iter_mask_val);
} }
}; };
C10_API std::string toString(DispatchKeySet); C10_API std::string toString(DispatchKeySet);
C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet); C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet);
// autograd_dispatch_keyset should include all runtime autograd keys. C10_API inline uint64_t getDispatchTableIndexForDispatchKey(DispatchKey k) {
// Alias key DispatchKey::Autograd maps to autograd_dispatch_keyset. return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet();
}
// Alias key DispatchKey::Autograd maps to
// (autograd_dispatch_keyset x full_backend_mask)
// NB: keys in this set also get associated with CompositeImplicitAutograd // NB: keys in this set also get associated with CompositeImplicitAutograd
//
// Note [autograd_dispatch_keyset Does Not Include Backend Bits]
// We don't want to include any backend bits (BackendComponent::CPUBit, etc)
// directly in autograd_dispatch_keyset.
// Why? keysets like autograd_dispatch_keyset are commonly used to remove
// autograd keys from a DispatchKeySet throughout the code base. However, you
// are only allowed to remove functionality bits from a keyset, not backend
// bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality
// Keys] for details. To be consistent and avoid confusion, we're explicitly
// setting up autograd_dispatch_keyset to not have any backend bits.
constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
DispatchKey::AutogradCPU, DispatchKey::AutogradFunctionality,
DispatchKey::AutogradCUDA,
DispatchKey::AutogradXLA,
DispatchKey::AutogradLazy,
DispatchKey::AutogradNestedTensor,
DispatchKey::AutogradMLC,
DispatchKey::AutogradHPU,
DispatchKey::AutogradXPU,
DispatchKey::AutogradPrivateUse1,
DispatchKey::AutogradPrivateUse2,
DispatchKey::AutogradPrivateUse3,
DispatchKey::AutogradOther, DispatchKey::AutogradOther,
}); });
@ -244,25 +608,28 @@ constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
// backend dispatch keys that map to DispatchKey::AutogradOther // backend dispatch keys that map to DispatchKey::AutogradOther
// NB: keys in this set also get associated with CompositeImplicitAutograd // NB: keys in this set also get associated with CompositeImplicitAutograd
constexpr DispatchKeySet autogradother_backends = DispatchKeySet( constexpr DispatchKeySet autogradother_backends =
{DispatchKey::HIP, DispatchKeySet(
DispatchKey::VE, // HIP and VE aren't in this list: they now have their own backend bits
DispatchKey::FPGA, // which means that they can now have their own Autograd keys.
// Technically, HIP will now redispatch to its own custom AutogradHIP
// slot in the runtime table.
{DispatchKey::FPGA,
DispatchKey::ORT, DispatchKey::ORT,
DispatchKey::Vulkan, DispatchKey::Vulkan,
DispatchKey::Metal, DispatchKey::Metal,
DispatchKey::QuantizedCPU,
DispatchKey::QuantizedCUDA,
DispatchKey::CustomRNGKeyId,
DispatchKey::MkldnnCPU,
DispatchKey::SparseCPU,
DispatchKey::SparseCUDA,
DispatchKey::SparseHIP,
DispatchKey::SparseVE,
DispatchKey::SparseXPU,
DispatchKey::SparseCsrCPU, DispatchKey::SparseCsrCPU,
DispatchKey::SparseCsrCUDA, DispatchKey::SparseCsrCUDA,
DispatchKey::Meta}); DispatchKey::CustomRNGKeyId,
DispatchKey::MkldnnCPU,
DispatchKey::Meta,
// Sparse and Quantized backends also live here.
DispatchKey::Sparse,
DispatchKey::Quantized})
// Including the backend bits because this keyset is used during op
// registration, which requires looping over all runtime autogradother
// backend keys.
| DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
// The set of dispatch keys that come after autograd // The set of dispatch keys that come after autograd
// n.b. this relies on the fact that AutogradOther is currently the lowest // n.b. this relies on the fact that AutogradOther is currently the lowest
@ -292,6 +659,36 @@ constexpr DispatchKeySet after_func_keyset =
// away with it by explicitly removing the key here. // away with it by explicitly removing the key here.
c10::DispatchKey::ADInplaceOrView); c10::DispatchKey::ADInplaceOrView);
constexpr DispatchKeySet backend_bitset_mask =
DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1);
constexpr auto inplace_or_view_ks =
DispatchKeySet(DispatchKey::ADInplaceOrView);
constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy);
constexpr auto autograd_mlc_ks = DispatchKeySet(DispatchKey::AutogradMLC);
constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU);
constexpr auto autograd_privateuse1_ks =
DispatchKeySet(DispatchKey::AutogradPrivateUse1);
constexpr auto autograd_privateuse2_ks =
DispatchKeySet(DispatchKey::AutogradPrivateUse2);
constexpr auto autograd_privateuse3_ks =
DispatchKeySet(DispatchKey::AutogradPrivateUse3);
constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther);
struct OpTableOffsetAndMask {
uint16_t offset;
uint16_t backend_mask;
};
static_assert(
num_backends <= 16,
"Right now we expect the number of backends not to exceed 16. In the (unlikely) event"
" that this changes, the size of OpTableOffsetAndMask::backend_mask needs to be increased too.");
// true if t is a backend dispatch key // true if t is a backend dispatch key
C10_API bool isBackendDispatchKey(DispatchKey t); C10_API bool isBackendDispatchKey(DispatchKey t);
@ -307,10 +704,53 @@ C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k);
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t); C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);
// Returns a DispatchKeySet of autograd related keys mapped to backend. // Returns a DispatchKeySet of autograd related keys mapped to backend.
C10_API DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t); // for a given backend key, use the associated autograd key.
// for non-backend keys, use AutogradOther as a default.
// Note: it's convenient and fast to return a default here rather than (say)
// returning an optional<DispatchKey>, or throwing. But it makes callers
// responsible for either a) enforcing the invariant that only backend keys
// be passed as arguments, or b) interpreting our return value carefully.
inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
switch (t) {
case BackendComponent::CPUBit:
return inplace_or_view_ks | autograd_cpu_ks;
case BackendComponent::XPUBit:
return inplace_or_view_ks | autograd_xpu_ks;
case BackendComponent::CUDABit:
return inplace_or_view_ks | autograd_cuda_ks;
case BackendComponent::XLABit:
return inplace_or_view_ks | autograd_xla_ks;
case BackendComponent::LazyBit:
return inplace_or_view_ks | autograd_lazy_ks;
case BackendComponent::MLCBit:
return inplace_or_view_ks | autograd_mlc_ks;
case BackendComponent::HPUBit:
return inplace_or_view_ks | autograd_hpu_ks;
case BackendComponent::PrivateUse1Bit:
return inplace_or_view_ks | autograd_privateuse1_ks;
case BackendComponent::PrivateUse2Bit:
return inplace_or_view_ks | autograd_privateuse2_ks;
case BackendComponent::PrivateUse3Bit:
return inplace_or_view_ks | autograd_privateuse3_ks;
default:
return inplace_or_view_ks | autograd_other_ks;
}
}
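Because the mapping now takes a `BackendComponent` instead of a runtime `DispatchKey`, a caller holding a tensor's key set can pull out its backend bit and OR in the matching autograd keys, the same way the `TensorImpl` constructor below does for autocast. A minimal sketch; `with_autograd` is a hypothetical helper:
#include <c10/core/DispatchKeySet.h>

// Hypothetical helper: enrich a key set with the autograd keys that belong to
// its highest-priority backend (e.g. Dense|CPUBit gains ADInplaceOrView and
// AutogradCPU per the switch above).
c10::DispatchKeySet with_autograd(c10::DispatchKeySet key_set) {
  auto backend = key_set.highestBackendKey();  // e.g. BackendComponent::CPUBit
  return key_set | c10::getAutogradRelatedKeySetFromBackend(backend);
}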
// Returns a DispatchKeySet of autocast related keys mapped to backend. // Returns a DispatchKeySet of autocast related keys mapped to backend.
C10_API DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t); inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
switch (t) {
case BackendComponent::CPUBit:
return autocast_cpu_ks;
case BackendComponent::CUDABit:
case BackendComponent::XLABit:
return autocast_cuda_ks;
default:
return DispatchKeySet();
}
}
// This API exists because we have a use case for checking // This API exists because we have a use case for checking
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined) // getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)

View File

@ -190,7 +190,7 @@ TensorImpl::TensorImpl(
// TODO: be more explicit about the full key set at call sites so we // TODO: be more explicit about the full key set at call sites so we
// don't have to keep recomputing it here // don't have to keep recomputing it here
DispatchKey k = key_set.highestPriorityBackendTypeId(); auto k = key_set.highestBackendKey();
key_set = key_set | getAutocastRelatedKeySetFromBackend(k); key_set = key_set | getAutocastRelatedKeySetFromBackend(k);

View File

@ -838,10 +838,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
bool is_sparse() const { bool is_sparse() const {
// NB: This method is not virtual and avoid dispatches for performance // NB: This method is not virtual and avoid dispatches for performance
// reasons. // reasons.
return key_set_.has(DispatchKey::SparseCPU) || return key_set_.has(DispatchKey::Sparse);
key_set_.has(DispatchKey::SparseCUDA) ||
key_set_.has(DispatchKey::SparseHIP) ||
key_set_.has(DispatchKey::SparseXPU);
} }
// Whether a tensor is sparse COO or not. Use is_sparse_csr for checking CSR // Whether a tensor is sparse COO or not. Use is_sparse_csr for checking CSR
@ -854,9 +851,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
bool is_quantized() const { bool is_quantized() const {
// NB: This method is not virtual and avoid dispatches for performance // NB: This method is not virtual and avoid dispatches for performance
// reasons. // reasons.
return key_set_.has(DispatchKey::QuantizedCPU) || return key_set_.has(DispatchKey::Quantized);
key_set_.has(DispatchKey::QuantizedCUDA) ||
key_set_.has(DispatchKey::QuantizedXPU);
} }
bool is_meta() const { bool is_meta() const {
@ -868,53 +863,46 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
bool is_cpu() const { bool is_cpu() const {
// NB: This method is not virtual and avoid dispatches for performance // NB: This method is not virtual and avoid dispatches for performance
// reasons. // reasons.
return key_set_.has(DispatchKey::CPU) || return key_set_.has_backend(BackendComponent::CPUBit) ||
key_set_.has(DispatchKey::SparseCPU) ||
key_set_.has(DispatchKey::SparseCsrCPU) || key_set_.has(DispatchKey::SparseCsrCPU) ||
key_set_.has(DispatchKey::QuantizedCPU) ||
key_set_.has(DispatchKey::MkldnnCPU); key_set_.has(DispatchKey::MkldnnCPU);
} }
bool is_cuda() const { bool is_cuda() const {
// NB: This method is not virtual and avoid dispatches for performance // NB: This method is not virtual and avoid dispatches for performance
// reasons. // reasons.
return key_set_.has(DispatchKey::CUDA) || return key_set_.has_backend(BackendComponent::CUDABit) ||
key_set_.has(DispatchKey::SparseCUDA) || key_set_.has(DispatchKey::SparseCsrCUDA);
key_set_.has(DispatchKey::SparseCsrCUDA) ||
key_set_.has(DispatchKey::QuantizedCUDA);
} }
bool is_xpu() const { bool is_xpu() const {
// NB: This method is not virtual and avoid dispatches for performance // NB: This method is not virtual and avoid dispatches for performance
// reasons. // reasons.
return key_set_.has(DispatchKey::XPU) || return key_set_.has_backend(BackendComponent::XPUBit);
key_set_.has(DispatchKey::SparseXPU) ||
key_set_.has(DispatchKey::QuantizedXPU);
} }
bool is_xla() const { bool is_xla() const {
return key_set_.has(DispatchKey::XLA); return key_set_.has_backend(BackendComponent::XLABit);
} }
bool is_hpu() const { bool is_hpu() const {
return key_set_.has(DispatchKey::HPU); return key_set_.has_backend(BackendComponent::HPUBit);
} }
bool is_lazy() const { bool is_lazy() const {
return key_set_.has(DispatchKey::Lazy); return key_set_.has_backend(BackendComponent::LazyBit);
} }
bool is_hip() const { bool is_hip() const {
// NB: This method is not virtual and avoid dispatches for performance // NB: This method is not virtual and avoid dispatches for performance
// reasons. // reasons.
return key_set_.has(DispatchKey::HIP) || return key_set_.has_backend(BackendComponent::HIPBit);
key_set_.has(DispatchKey::SparseHIP);
} }
bool is_ve() const { bool is_ve() const {
// NB: This method is not virtual and avoid dispatches for performance // NB: This method is not virtual and avoid dispatches for performance
// reasons. // reasons.
return key_set_.has(DispatchKey::VE) || key_set_.has(DispatchKey::SparseVE); return key_set_.has_backend(BackendComponent::VEBit);
} }
bool is_mkldnn() const { bool is_mkldnn() const {
@ -1548,13 +1536,22 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
*/ */
inline bool has_compatible_shallow_copy_type(DispatchKeySet from) { inline bool has_compatible_shallow_copy_type(DispatchKeySet from) {
auto is_dense = [](DispatchKeySet ts) { auto is_dense = [](DispatchKeySet ts) {
return ts.has(DispatchKey::CPU) || ts.has(DispatchKey::CUDA) || constexpr auto dense_backends = DispatchKeySet(
ts.has(DispatchKey::HIP) || ts.has(DispatchKey::XPU); {BackendComponent::CPUBit,
BackendComponent::CUDABit,
BackendComponent::HIPBit,
BackendComponent::XPUBit});
constexpr auto dense_k = DispatchKeySet(DispatchKey::Dense);
return ts.has_any(dense_k) && ts.has_any(dense_backends);
}; };
auto is_sparse = [](DispatchKeySet ts) { auto is_sparse = [](DispatchKeySet ts) {
return ts.has(DispatchKey::SparseCPU) || constexpr auto sparse_backends = DispatchKeySet(
ts.has(DispatchKey::SparseCUDA) || ts.has(DispatchKey::SparseHIP) || {BackendComponent::CPUBit,
ts.has(DispatchKey::SparseXPU); BackendComponent::CUDABit,
BackendComponent::HIPBit,
BackendComponent::XPUBit});
constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse);
return ts.has_any(sparse_k) && ts.has_any(sparse_backends);
}; };
return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) || return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) ||
(is_sparse(key_set_) && is_sparse(from)); (is_sparse(key_set_) && is_sparse(from));

View File

@ -3,25 +3,163 @@
#include <unordered_set> #include <unordered_set>
#include <c10/core/DispatchKeySet.h> #include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>
using namespace c10; using namespace c10;
// This test exists not to be comprehensive, but to more clearly show
// what the semantics of DispatchKeySet are.
TEST(DispatchKeySet, ShowSemantics) {
// the "CPU" dispatch key is an instance of a per-backend-functionality key.
// It corresponds to "dense" functionality, "CPU" backend.
// This means that it gets a dense functionality bit, and a cpu backend bit
// set.
auto undefined_set = DispatchKeySet();
auto dense_cpu_set = DispatchKeySet(DispatchKey::CPU);
ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));
auto dense_lazy_set = DispatchKeySet(DispatchKey::Lazy);
ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Dense));
ASSERT_TRUE(dense_lazy_set.has_backend(BackendComponent::LazyBit));
ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Lazy));
// You can think of "Dense/Sparse", and "CPUBit/CUDABit", as "building block"
// dispatch keys. You are allowed to directly create keysets out of them!
auto dense_cpu_set_from_building_blocks = DispatchKeySet(DispatchKey::Dense) |
DispatchKeySet(BackendComponent::CPUBit);
ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));
ASSERT_EQ(dense_cpu_set, dense_cpu_set_from_building_blocks);
// Similarly, the AutogradCUDA key gets 2 bits in the keyset:
// The "Autograd" functionality bit, and the "CUDA" backend bit
auto autograd_cuda = DispatchKeySet(DispatchKey::AutogradCUDA);
ASSERT_TRUE(autograd_cuda.has(DispatchKey::AutogradFunctionality));
ASSERT_TRUE(autograd_cuda.has_backend(BackendComponent::CUDABit));
// Because DispatchKeySet uses a condensed internal representation, you cannot
  // use it to represent the FULL cross product of backends and functionalities.
  // For example:
auto autograd_dense_cpu_cuda = DispatchKeySet(
{DispatchKey::AutogradFunctionality,
DispatchKey::Dense,
DispatchKey::CUDA,
DispatchKey::CPU});
auto fpga = DispatchKeySet(DispatchKey::FPGA);
auto fpga_and_cpu = DispatchKeySet({DispatchKey::FPGA, DispatchKey::CPU});
// this keyset has all of the building block keys:
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradFunctionality));
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::Dense));
ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CUDABit));
ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CPUBit));
// and it also has the "runtime" keys that correspond to the full
// cross-product of functionality
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCPU));
  ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCUDA));
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CPU));
ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CUDA));
// This means that there's no way to represent a keyset with, say, only
// Autograd CUDA + Dense CPU. Instead, you should think of a keyset as
// inheriting the full set of functionalities + backends of its keys. This
// means that the below keysets are all indistinguishable from each other.
ASSERT_EQ(
autograd_dense_cpu_cuda,
DispatchKeySet(
{DispatchKey::AutogradCUDA,
DispatchKey::AutogradCPU,
DispatchKey::CUDA,
DispatchKey::CPU}));
ASSERT_EQ(
autograd_dense_cpu_cuda,
DispatchKeySet({DispatchKey::AutogradCUDA, DispatchKey::CPU}));
ASSERT_EQ(
autograd_dense_cpu_cuda,
DispatchKeySet({DispatchKey::CUDA, DispatchKey::AutogradCPU}));
// ~~~~~~~~~~ DispatchKeySet iterators ~~~~~~~~~~~
// Iterators allow you to iterate individually through the DispatchKey's in a
// DispatchKeySet
auto empty_set = DispatchKeySet();
auto t1 = empty_set.begin();
auto t2 = empty_set.end();
ASSERT_EQ(*empty_set.begin(), *empty_set.end());
// However, only keys that correspond to actual runtime indices of kernels in
// the operator table show up when you iterate through a keyset. i.e.
// DispatchKey::Dense, and BackendComponent::CPUBit won't show up in an
// iterator.
auto dense_cpu_iter = dense_cpu_set.begin();
ASSERT_EQ(*dense_cpu_iter++, DispatchKey::CPU);
ASSERT_EQ(*dense_cpu_iter, *dense_cpu_set.end());
auto autograd_dense_cpu_cuda_iter = autograd_dense_cpu_cuda.begin();
ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CPU);
ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CUDA);
ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCPU);
ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCUDA);
ASSERT_EQ(*autograd_dense_cpu_cuda_iter, *autograd_dense_cpu_cuda.end());
// But other "functionality bits" that are not defined per-backend DO get
// their own slots in the operator table.
auto mixed_keyset = DispatchKeySet(BackendComponent::CPUBit) |
DispatchKeySet(
{DispatchKey::FPGA, // runtime key
DispatchKey::Functionalize, // runtime key
DispatchKey::Dense}); // NOT a runtime key
auto mixed_iter = mixed_keyset.begin();
ASSERT_EQ(*mixed_iter++, DispatchKey::CPU);
ASSERT_EQ(*mixed_iter++, DispatchKey::FPGA);
ASSERT_EQ(*mixed_iter++, DispatchKey::Functionalize);
ASSERT_EQ(*mixed_iter, *mixed_keyset.end());
}
TEST(DispatchKeySet, Empty) { TEST(DispatchKeySet, Empty) {
DispatchKeySet empty_set; DispatchKeySet empty_set;
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); for (uint8_t i = 0;
i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
i++) { i++) {
auto tid = static_cast<DispatchKey>(i); auto tid = static_cast<DispatchKey>(i);
if (tid == DispatchKey::Undefined)
continue;
ASSERT_FALSE(empty_set.has(tid)); ASSERT_FALSE(empty_set.has(tid));
} }
ASSERT_TRUE(empty_set.empty()); ASSERT_TRUE(empty_set.empty());
DispatchKeySet empty_set2; DispatchKeySet empty_set2;
ASSERT_TRUE(empty_set == empty_set2); ASSERT_TRUE(empty_set == empty_set2);
ASSERT_EQ(empty_set.highestPriorityTypeId(), DispatchKey::Undefined);
} }
TEST(DispatchKeySet, Singleton) { // This covers all keys that correspond to a single backend bit, e.g.
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); // BackendComponent::CPUBit. Even though these are NOT runtime keys, we still
i++) { // allow adding them directly to a keyset
TEST(DispatchKeySet, SingletonBackendComponent) {
for (const auto i : c10::irange(1, num_backends)) {
auto tid = static_cast<DispatchKey>(i);
DispatchKeySet sing(tid);
ASSERT_EQ(sing, sing);
ASSERT_EQ(sing, DispatchKeySet().add(tid));
ASSERT_EQ(sing, sing.add(tid));
ASSERT_EQ(sing, sing | sing);
ASSERT_FALSE(sing.empty());
ASSERT_TRUE(sing.has(tid));
}
}
// This covers all keys that correspond to a single functionality bit:
// - runtime, not-per-backend functionality keys, e.g.
// DispatchKey::FuncTorchBatched
// - runtime, "fake backend" keys, e.g. DispatchKey::FPGA
// - NOT-runtime, per-backend functionality keys, e.g. DispatchKey::Dense
// Even though it's not a runtime key, we still allow adding it directly to a
// keyset.
TEST(DispatchKeySet, SingletonFunctionalityKeys) {
for (const auto i : c10::irange(1, num_functionality_keys)) {
auto tid = static_cast<DispatchKey>(i); auto tid = static_cast<DispatchKey>(i);
DispatchKeySet sing(tid); DispatchKeySet sing(tid);
ASSERT_EQ(sing, sing); ASSERT_EQ(sing, sing);
@ -30,47 +168,145 @@ TEST(DispatchKeySet, Singleton) {
ASSERT_EQ(sing, sing | sing); ASSERT_EQ(sing, sing | sing);
ASSERT_FALSE(sing.empty()); ASSERT_FALSE(sing.empty());
ASSERT_TRUE(sing.has(tid)); ASSERT_TRUE(sing.has(tid));
ASSERT_EQ(sing.highestPriorityTypeId(), tid);
ASSERT_EQ(sing.remove(tid), DispatchKeySet()); ASSERT_EQ(sing.remove(tid), DispatchKeySet());
} }
} }
TEST(DispatchKeySet, Doubleton) { // This covers runtime keys that are per-backend,
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); // and take up more than one bit in a DispatchKeySet. They take up one
// functionality bit + one backend bit. e.g. CPU, CUDA, SparseCPU, SparseCUDA,
// AutogradCPU, AutogradCUDA
TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) {
for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
i++) {
auto tid = static_cast<DispatchKey>(i);
// Skip these because they aren't real keys.
if (tid == DispatchKey::StartOfDenseBackends ||
tid == DispatchKey::StartOfSparseBackends ||
tid == DispatchKey::StartOfQuantizedBackends ||
tid == DispatchKey::StartOfAutogradBackends) {
continue;
}
DispatchKeySet sing(tid);
ASSERT_EQ(sing, sing);
ASSERT_EQ(sing, DispatchKeySet().add(tid));
ASSERT_EQ(sing, sing.add(tid));
ASSERT_EQ(sing, sing | sing);
ASSERT_FALSE(sing.empty());
ASSERT_TRUE(sing.has(tid));
auto functionality_key = toFunctionalityKey(tid);
auto backend_key = toBackendComponent(tid);
// These two sets should be equivalent:
// DispatchKeySet(DispatchKey::CPU)
// DispatchKeySet({DispatchKey::Dense, BackendComponent::CPUBit})
auto expected_ks =
DispatchKeySet(functionality_key) | DispatchKeySet(backend_key);
ASSERT_EQ(sing, expected_ks);
// These two sets should be equivalent:
// DispatchKeySet(DispatchKey::CPU).remove(DispatchKey::Dense)
// DispatchKeySet(BackendComponent::CPUBit)
expected_ks = DispatchKeySet(toBackendComponent(tid));
ASSERT_EQ(sing.remove(tid), expected_ks);
}
}
TEST(DispatchKeySet, DoubletonPerBackend) {
for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
i++) { i++) {
for (uint8_t j = i + 1; for (uint8_t j = i + 1;
j < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); j <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
j++) { j++) {
ASSERT_LT(i, j); ASSERT_LT(i, j);
auto tid1 = static_cast<DispatchKey>(i); auto tid1 = static_cast<DispatchKey>(i);
auto tid2 = static_cast<DispatchKey>(j); auto tid2 = static_cast<DispatchKey>(j);
auto doub = DispatchKeySet(tid1).add(tid2);
ASSERT_EQ(doub, DispatchKeySet(tid1) | DispatchKeySet(tid2)); // Skip these because they aren't real keys.
ASSERT_TRUE(doub.has(tid1)); if (tid1 == DispatchKey::StartOfDenseBackends ||
ASSERT_TRUE(doub.has(tid2)); tid1 == DispatchKey::StartOfSparseBackends ||
ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j tid1 == DispatchKey::StartOfQuantizedBackends ||
tid1 == DispatchKey::StartOfAutogradBackends)
continue;
if (tid2 == DispatchKey::StartOfDenseBackends ||
tid2 == DispatchKey::StartOfSparseBackends ||
tid2 == DispatchKey::StartOfQuantizedBackends ||
tid2 == DispatchKey::StartOfAutogradBackends)
continue;
auto backend1 = toBackendComponent(tid1);
auto backend2 = toBackendComponent(tid2);
auto functionality1 = toFunctionalityKey(tid1);
auto functionality2 = toFunctionalityKey(tid2);
auto combined = DispatchKeySet({tid1, tid2});
// The combined set has the backend bits
ASSERT_TRUE(combined.has_backend(backend1));
ASSERT_TRUE(combined.has_backend(backend2));
      // and it has the functionality bits
ASSERT_TRUE(combined.has(functionality1));
ASSERT_TRUE(combined.has(functionality2));
// and it has the original two runtime keys
ASSERT_TRUE(combined.has(tid1));
ASSERT_TRUE(combined.has(tid2));
// Add all of the keys in the keyset to a real set
std::unordered_set<DispatchKey> visited_keys;
auto iter = combined.begin();
while (*iter != *combined.end()) {
visited_keys.insert(*iter);
++iter;
}
std::unordered_set<DispatchKey> expected_keys;
expected_keys.insert(
toRuntimePerBackendFunctionalityKey(functionality1, backend1));
expected_keys.insert(
toRuntimePerBackendFunctionalityKey(functionality1, backend2));
expected_keys.insert(
toRuntimePerBackendFunctionalityKey(functionality2, backend1));
expected_keys.insert(
toRuntimePerBackendFunctionalityKey(functionality2, backend2));
ASSERT_EQ(expected_keys, visited_keys);
if (backend1 == backend2 || functionality1 == functionality2) {
// We have two runtime keys, with either the same backend or the same
// per-backend functionalities. E.g. {AutogradCUDA, CUDA} or
        // {AutogradCPU, AutogradCUDA}. There should be 2 total runtime keys in
// this set.
ASSERT_EQ(2, visited_keys.size());
} else {
// since i and j are different keys, they should not have the same
// functionality and backend
ASSERT_TRUE(backend1 != backend2 && functionality1 != functionality2);
// We have two runtime keys, that have different backends + per-backend
// functionalities. So we should expect the full cross product of
// runtime keys to be in the set. e.g. if i = AutogradCUDA, and j = CPU,
// then combined = {AutogradCUDA, AutogradCPU, CUDA, CPU}
ASSERT_EQ(4, visited_keys.size());
}
} }
} }
} }
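The per-backend machinery exercised above boils down to a round trip between a runtime key and its (functionality, backend) building blocks. A minimal sketch using the same helpers the tests call; the expected values in the comments are inferred from the asserts above:
#include <c10/core/DispatchKeySet.h>

bool round_trip_sketch() {
  using namespace c10;
  auto k = DispatchKey::SparseCUDA;
  auto functionality = toFunctionalityKey(k);  // expected: DispatchKey::Sparse
  auto backend = toBackendComponent(k);        // expected: BackendComponent::CUDABit
  // Recombining the two building blocks should yield the original runtime key;
  // this is the relationship DoubletonPerBackend relies on when it builds
  // expected_keys.
  return toRuntimePerBackendFunctionalityKey(functionality, backend) == k;
}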
TEST(DispatchKeySet, Full) { TEST(DispatchKeySet, Full) {
DispatchKeySet full(DispatchKeySet::FULL); DispatchKeySet full(DispatchKeySet::FULL);
for (uint8_t i = 1; i < static_cast<uint8_t>(DispatchKey::NumDispatchKeys); for (const auto i : c10::irange(1, num_functionality_keys)) {
i++) {
auto tid = static_cast<DispatchKey>(i); auto tid = static_cast<DispatchKey>(i);
ASSERT_TRUE(full.has(tid)); ASSERT_TRUE(full.has(tid));
} }
ASSERT_FALSE(full.has(DispatchKey::EndOfFunctionalityKeys));
} }
TEST(DispatchKeySet, IteratorBasicOps) { TEST(DispatchKeySet, IteratorBasicOps) {
DispatchKeySet empty_set; DispatchKeySet empty_set;
DispatchKeySet full_set(DispatchKeySet::FULL); DispatchKeySet full_set(DispatchKeySet::FULL);
DispatchKeySet mutated_set = empty_set.add(static_cast<DispatchKey>(1)); DispatchKeySet mutated_set = empty_set.add(DispatchKey::CPU);
// Constructor + Comparison // Constructor + Comparison
ASSERT_EQ(*empty_set.begin(), DispatchKey::NumDispatchKeys); ASSERT_EQ(*empty_set.begin(), DispatchKey::EndOfFunctionalityKeys);
ASSERT_EQ(*empty_set.end(), DispatchKey::NumDispatchKeys); ASSERT_EQ(*empty_set.end(), DispatchKey::EndOfFunctionalityKeys);
ASSERT_EQ(*mutated_set.begin(), static_cast<DispatchKey>(1)); ASSERT_EQ(*mutated_set.begin(), DispatchKey::CPU);
ASSERT_TRUE(empty_set.begin() == empty_set.end()); ASSERT_TRUE(empty_set.begin() == empty_set.end());
ASSERT_TRUE(full_set.begin() != full_set.end()); ASSERT_TRUE(full_set.begin() != full_set.end());
@ -90,16 +326,37 @@ TEST(DispatchKeySet, IteratorEmpty) {
ASSERT_EQ(i, 0); ASSERT_EQ(i, 0);
} }
TEST(DispatchKeySet, IteratorCrossProduct) {
// The iterator should return all runtime keys in the set,
// including the cross product of {backends} x {functionalities}
auto ks =
DispatchKeySet({BackendComponent::CPUBit, BackendComponent::CUDABit}) |
DispatchKeySet(
{DispatchKey::Dense,
DispatchKey::FPGA,
DispatchKey::AutogradFunctionality});
auto iter = ks.begin();
// iterate through dense backends first.
ASSERT_EQ(DispatchKey::CPU, *(iter++));
ASSERT_EQ(DispatchKey::CUDA, *(iter++));
// FPGA doesn't have a backend bit, so it isn't included in the cross product.
ASSERT_EQ(DispatchKey::FPGA, *(iter++));
  // iterate through the autograd keys later.
ASSERT_EQ(DispatchKey::AutogradCPU, *(iter++));
ASSERT_EQ(DispatchKey::AutogradCUDA, *(iter++));
}
TEST(DispatchKeySet, IteratorFull) { TEST(DispatchKeySet, IteratorFull) {
DispatchKeySet full_set(DispatchKeySet::FULL); DispatchKeySet full_set(DispatchKeySet::FULL);
uint8_t i = 0; uint8_t i = 0;
for (const auto& it : full_set) { for (const auto& it : full_set) {
i++; i++;
ASSERT_TRUE(it == static_cast<DispatchKey>(i));
ASSERT_TRUE(it != DispatchKey::NumDispatchKeys);
} }
ASSERT_EQ(i, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) - 1); // Total # of runtime entries includes an entry for DispatchKey::Undefined,
// which is not included when iterating through the DispatchKeySet.
ASSERT_EQ(i, num_runtime_entries - 1);
} }
TEST(DispatchKeySet, IteratorRangeFull) { TEST(DispatchKeySet, IteratorRangeFull) {
@ -108,41 +365,61 @@ TEST(DispatchKeySet, IteratorRangeFull) {
for (DispatchKey dispatch_key : full_set) { for (DispatchKey dispatch_key : full_set) {
i++; i++;
ASSERT_TRUE(dispatch_key == static_cast<DispatchKey>(i));
} }
ASSERT_EQ(i, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) - 1); // Total # of runtime entries includes an entry for DispatchKey::Undefined,
} // which is not included when iterating through the DispatchKeySet.
ASSERT_EQ(i, num_runtime_entries - 1);
TEST(DispatchKeySet, SpecificKeys) {
DispatchKeySet keyset({
static_cast<DispatchKey>(0), // Undefined should be ignored
static_cast<DispatchKey>(4),
static_cast<DispatchKey>(10),
static_cast<DispatchKey>(15),
});
std::unordered_set<DispatchKey> visited_keys;
for (DispatchKey key : keyset) {
visited_keys.insert(key);
}
ASSERT_EQ(visited_keys.size(), 3);
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(4)) != visited_keys.end());
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(10)) != visited_keys.end());
ASSERT_TRUE(
visited_keys.find(static_cast<DispatchKey>(15)) != visited_keys.end());
} }
TEST(DispatchKeySet, FailAtEndIterator) { TEST(DispatchKeySet, FailAtEndIterator) {
DispatchKeySet full_set(DispatchKeySet::FULL); DispatchKeySet full_set(DispatchKeySet::FULL);
uint64_t raw_repr = full_set.raw_repr(); uint64_t raw_repr = full_set.raw_repr();
// doesn't throw
DispatchKeySet::iterator(&raw_repr, num_backends + num_functionality_keys);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
EXPECT_THROW( EXPECT_THROW(
DispatchKeySet::iterator( DispatchKeySet::iterator(
&raw_repr, static_cast<uint8_t>(DispatchKey::NumDispatchKeys) + 1), &raw_repr, num_backends + num_functionality_keys + 1),
c10::Error); c10::Error);
} }
TEST(DispatchKeySet, TestKeyOrderingInvariants) {
for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
i++) {
auto k = static_cast<DispatchKey>(i);
// Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// The DispatchKey enum includes all of the runtime keys for
// Dense/Sparse/Quantized/Autograd, (e.g. CPU, CUDA, SparseCPU, SparseCUDA,
// AutogradCPU, AutogradCUDA, etc). And we expect the ordering of those keys
// to be the same as the ordering of the backends in the `BackendComponent`
// enum. This makes several utilities in `DispatchKey.h` and
// `DispatchKeySet.h` significantly easier to implement. The purpose of the
// test is to assert (through CI) that this invariant is maintained.
//
// The only way that we can really check this invariant is by
// comparing the string names of each enum.
// We only really care about the ordering for "real" keys that are actually
// used, which we expect to be able to print properly. This saves us from
// having to enumerate the full set of possible runtime keys in
// DispatchKey::toString(). It also relies on toString() being implemented
// correctly.
auto functionality_str = std::string(toString(k));
if (functionality_str == "UNKNOWN_TENSOR_TYPE_ID")
continue;
auto computed_backend_k = toBackendComponent(k);
auto computed_backend_str = std::string(toString(computed_backend_k));
    // Strip the trailing "Bit", e.g. turn "CPUBit" into "CPU"
computed_backend_str =
computed_backend_str.substr(0, computed_backend_str.size() - 3);
ASSERT_TRUE(
functionality_str.find(computed_backend_str) != std::string::npos)
<< "DispatchKey invariant broken! Found a key that is not ordered correctly"
<< " with its backend bit. key = " << toString(k) << ", " << k
<< ", computed backend = " << toString(computed_backend_k);
}
}

View File

@ -532,8 +532,8 @@ AutogradXLA: fn_math [math kernel]
lambda m: m.def_("foo(Tensor x) -> Tensor"), lambda m: m.def_("foo(Tensor x) -> Tensor"),
# m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x }) # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd", debug="fn_math"), lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd", debug="fn_math"),
# m.impl("foo", torch::kQuantizedCPU, [](const Tensor & x) { return x }) # m.impl("foo", torch::kFPGA, [](const Tensor & x) { return x })
lambda m: m.impl_t_t("foo", "QuantizedCPU", debug="fn_quantizedcpu"), lambda m: m.impl_t_t("foo", "FPGA", debug="fn_fpga"),
]) ])
state, table = result.state, result.table state, table = result.state, result.table
self.assertExpectedInline(state, '''\ self.assertExpectedInline(state, '''\
@ -541,12 +541,12 @@ name: test::foo
schema: test::foo(Tensor x) -> (Tensor) schema: test::foo(Tensor x) -> (Tensor)
debug: registered at /dev/null:0 debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA alias analysis kind: FROM_SCHEMA
QuantizedCPU: fn_quantizedcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] FPGA: fn_fpga :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''') ''')
# computed dispatch table is too big, so we only check on a few entries we're interested in. # computed dispatch table is too big, so we only check on a few entries we're interested in.
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',)) extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('FPGA',))
self.assertExpectedInline(extracted_table, '''\ self.assertExpectedInline(extracted_table, '''\
Undefined: fn_math [math kernel] Undefined: fn_math [math kernel]
@ -557,7 +557,7 @@ AutogradOther: ambiguous_autogradother [ambiguous autogradother]
AutogradCPU: fn_math [math kernel] AutogradCPU: fn_math [math kernel]
AutogradCUDA: fn_math [math kernel] AutogradCUDA: fn_math [math kernel]
AutogradXLA: fn_math [math kernel] AutogradXLA: fn_math [math kernel]
QuantizedCPU: fn_quantizedcpu [kernel] FPGA: fn_fpga [kernel]
''') ''')
def test_computed_table_with_cpu_defaultbackend(self): def test_computed_table_with_cpu_defaultbackend(self):
@ -616,7 +616,7 @@ CompositeExplicitAutograd[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0
''') ''')
# computed dispatch table is too big, so we only check on a few entries we're interested in. # computed dispatch table is too big, so we only check on a few entries we're interested in.
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',)) extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('FPGA',))
self.assertExpectedInline(extracted_table, '''\ self.assertExpectedInline(extracted_table, '''\
Undefined: fn_defaultbackend [default backend kernel] Undefined: fn_defaultbackend [default backend kernel]
@ -627,7 +627,7 @@ AutogradOther: fn_autograd [autograd kernel]
AutogradCPU: fn_autograd [autograd kernel] AutogradCPU: fn_autograd [autograd kernel]
AutogradCUDA: fn_autograd [autograd kernel] AutogradCUDA: fn_autograd [autograd kernel]
AutogradXLA: fn_autograd [autograd kernel] AutogradXLA: fn_autograd [autograd kernel]
QuantizedCPU: fn_defaultbackend [default backend kernel] FPGA: fn_defaultbackend [default backend kernel]
''') ''')
def test_computed_table_with_cpu_autograd_math_defaultbackend(self): def test_computed_table_with_cpu_autograd_math_defaultbackend(self):
@ -808,7 +808,7 @@ key kernel
CPU fn_CPU [kernel] CPU fn_CPU [kernel]
XLA fn_XLA [kernel] XLA fn_XLA [kernel]
Lazy fn_Lazy [kernel] Lazy fn_Lazy [kernel]
QuantizedCPU fn_CompositeImplicitAutograd [math kernel] FPGA fn_CompositeImplicitAutograd [math kernel]
AutogradOther fn_CompositeImplicitAutograd [math kernel] AutogradOther fn_CompositeImplicitAutograd [math kernel]
AutogradCPU fallthrough [backend fallback] AutogradCPU fallthrough [backend fallback]
AutogradXLA fallthrough [backend fallback] AutogradXLA fallthrough [backend fallback]
@ -829,7 +829,7 @@ key kernel
CPU fn_CPU [kernel] CPU fn_CPU [kernel]
XLA fn_XLA [kernel] XLA fn_XLA [kernel]
Lazy fn_Lazy [kernel] Lazy fn_Lazy [kernel]
QuantizedCPU fn_CompositeImplicitAutograd [math kernel] FPGA fn_CompositeImplicitAutograd [math kernel]
AutogradOther fn_CompositeImplicitAutograd [math kernel] AutogradOther fn_CompositeImplicitAutograd [math kernel]
AutogradCPU fn_AutogradCPU [kernel] AutogradCPU fn_AutogradCPU [kernel]
AutogradXLA fallthrough [backend fallback] AutogradXLA fallthrough [backend fallback]
@ -864,7 +864,7 @@ key kernel
CPU fn_CPU [kernel] CPU fn_CPU [kernel]
XLA fn_XLA [kernel] XLA fn_XLA [kernel]
Lazy fn_Lazy [kernel] Lazy fn_Lazy [kernel]
QuantizedCPU fn_CompositeExplicitAutograd [default backend kernel] FPGA fn_CompositeExplicitAutograd [default backend kernel]
AutogradOther fallthrough [backend fallback] AutogradOther fallthrough [backend fallback]
AutogradCPU fn_AutogradCPU [kernel] AutogradCPU fn_AutogradCPU [kernel]
AutogradXLA fallthrough [backend fallback] AutogradXLA fallthrough [backend fallback]
@ -889,7 +889,7 @@ CompositeExplicitAutograd[alias] fn_CompositeExplicitAutograd
def test_autogradother(self): def test_autogradother(self):
dispatcher = PythonDispatcher() dispatcher = PythonDispatcher()
dispatcher.register(["CPU", "QuantizedCPU", "CompositeImplicitAutograd"]) dispatcher.register(["CPU", "FPGA", "CompositeImplicitAutograd"])
self.assertExpectedInline( self.assertExpectedInline(
dispatcher.dispatchTable(), dispatcher.dispatchTable(),
'''\ '''\
@ -900,7 +900,7 @@ key kernel
CPU fn_CPU [kernel] CPU fn_CPU [kernel]
XLA fn_CompositeImplicitAutograd [math kernel] XLA fn_CompositeImplicitAutograd [math kernel]
Lazy fn_CompositeImplicitAutograd [math kernel] Lazy fn_CompositeImplicitAutograd [math kernel]
QuantizedCPU fn_QuantizedCPU [kernel] FPGA fn_FPGA [kernel]
AutogradOther ambiguous_autogradother [ambiguous autogradother] AutogradOther ambiguous_autogradother [ambiguous autogradother]
AutogradCPU fallthrough [backend fallback] AutogradCPU fallthrough [backend fallback]
AutogradXLA fn_CompositeImplicitAutograd [math kernel] AutogradXLA fn_CompositeImplicitAutograd [math kernel]
@ -915,8 +915,8 @@ AutogradLazy fn_CompositeImplicitAutograd [math kernel]
Registered Kernels Registered Kernels
key kernel key kernel
--------------------------- ---------------------------
FPGA fn_FPGA
CPU fn_CPU CPU fn_CPU
QuantizedCPU fn_QuantizedCPU
CompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd CompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd
''' '''
) )

View File

@ -3410,21 +3410,21 @@ class TestSparseOneOff(TestCase):
def test_cuda_from_cpu(self): def test_cuda_from_cpu(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
"backend of indices \\(CUDA\\) must match backend of values \\(CPU\\)"): "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
torch.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(), torch.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(),
torch.randn(4, 4, 4), torch.randn(4, 4, 4),
[3, 4, 4]) [3, 4, 4])
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
"backend of indices \\(CUDA\\) must match backend of values \\(CPU\\)"): "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
torch.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(), torch.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(),
torch.randn(4, 4, 4, 0), torch.randn(4, 4, 4, 0),
[3, 4, 4, 0]) [3, 4, 4, 0])
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
"backend of indices \\(CUDA\\) must match backend of values \\(CPU\\)"): "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"):
torch.sparse.FloatTensor(torch.LongTensor(1, 0).cuda(), torch.sparse.FloatTensor(torch.LongTensor(1, 0).cuda(),
torch.randn(0, 4, 4, 0), torch.randn(0, 4, 4, 0),
[0, 4, 4, 0]) [0, 4, 4, 0])

View File

@ -48,58 +48,66 @@ class DispatchKey(Enum):
Undefined = 0 Undefined = 0
CatchAll = Undefined CatchAll = Undefined
CPU = auto() Dense = auto()
CUDA = auto()
HIP = auto()
FPGA = auto() FPGA = auto()
ORT = auto() ORT = auto()
XLA = auto()
Lazy = auto()
Vulkan = auto() Vulkan = auto()
Metal = auto() Metal = auto()
XPU = auto()
MKLDNN = auto() MKLDNN = auto()
OpenGL = auto() OpenGL = auto()
OpenCL = auto() OpenCL = auto()
IDEEP = auto() IDEEP = auto()
QuantizedCPU = auto() Quantized = auto()
QuantizedCUDA = auto()
QuantizedXPU = auto()
CustomRNGKeyId = auto() CustomRNGKeyId = auto()
MkldnnCPU = auto() MkldnnCPU = auto()
SparseCPU = auto() Sparse = auto()
SparseCUDA = auto()
SparseCsrCPU = auto() SparseCsrCPU = auto()
SparseCsrCUDA = auto() SparseCsrCUDA = auto()
SparseHIP = auto()
SparseXPU = auto()
NestedTensor = auto()
PrivateUse1 = auto()
PrivateUse2 = auto()
PrivateUse3 = auto()
EndOfBackendKeys = PrivateUse3
ZeroTensor = auto() ZeroTensor = auto()
Meta = auto() Meta = auto()
BackendSelect = auto() BackendSelect = auto()
Named = auto() Named = auto()
AutogradOther = auto() AutogradOther = auto()
AutogradCPU = auto() AutogradFunctionality = auto()
AutogradCUDA = auto()
AutogradXLA = auto()
AutogradLazy = auto()
AutogradNestedTensor = auto() AutogradNestedTensor = auto()
AutogradXPU = auto()
AutogradPrivateUse1 = auto()
AutogradPrivateUse2 = auto()
AutogradPrivateUse3 = auto()
Tracer = auto() Tracer = auto()
Autocast = auto() Autocast = auto()
Batched = auto() Batched = auto()
VmapMode = auto() VmapMode = auto()
TESTING_ONLY_GenericWrapper = auto() TESTING_ONLY_GenericWrapper = auto()
TESTING_ONLY_GenericMode = auto() TESTING_ONLY_GenericMode = auto()
NumDispatchKeys = auto() EndOfFunctionalityKeys = TESTING_ONLY_GenericMode
CPU = auto()
CUDA = auto()
HIP = auto()
XLA = auto()
Lazy = auto()
XPU = auto()
NestedTensor = auto()
PrivateUse1 = auto()
PrivateUse2 = auto()
PrivateUse3 = auto()
QuantizedCPU = auto()
QuantizedCUDA = auto()
QuantizedXPU = auto()
SparseCPU = auto()
SparseCUDA = auto()
SparseHIP = auto()
SparseXPU = auto()
AutogradCPU = auto()
AutogradCUDA = auto()
AutogradXLA = auto()
AutogradLazy = auto()
AutogradXPU = auto()
AutogradPrivateUse1 = auto()
AutogradPrivateUse2 = auto()
AutogradPrivateUse3 = auto()
Autograd = auto() Autograd = auto()
CompositeImplicitAutograd = auto() CompositeImplicitAutograd = auto()
CompositeExplicitAutograd = auto() CompositeExplicitAutograd = auto()

View File

@ -15,9 +15,9 @@ keys for a single example of each use case. These use cases are listed below:
- CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference & - CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference &
autograd kernel in pytorch core library. autograd kernel in pytorch core library.
E.g. CPU, CUDA E.g. CPU, CUDA
- QuantizedCPU/AutogradOther: represents in-tree backends which we usually have backend specific - FPGA/AutogradOther: represents in-tree backends which we usually have backend specific
inference kernels, but they share the same autograd kernel specified in AutogradOther. inference kernels, but they share the same autograd kernel specified in AutogradOther.
E.g. QuantizedCPU, QuantizedCUDA E.g. FPGA, SparseCsrCPU
- XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd - XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd
kernel defined in pytorch core library. Backend owner is responsible for registering both kernel defined in pytorch core library. Backend owner is responsible for registering both
inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support. inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support.
@ -53,7 +53,7 @@ class PythonDispatcher:
name = "foo" name = "foo"
runtime_keys = [ runtime_keys = [
"CPU", "AutogradCPU", "CPU", "AutogradCPU",
"QuantizedCPU", "AutogradOther", "FPGA", "AutogradOther",
"XLA", "AutogradXLA", "XLA", "AutogradXLA",
"Lazy", "AutogradLazy", "Lazy", "AutogradLazy",
] ]