pytorch/c10/core/DispatchKeySet.cpp
Edward Z. Yang d3ca2f21ad [POC] Skip autocast kernels on non-CUDA tensors
We use the same trick as AutogradCUDA but applied to Autocast, and also
introduce a new excluded-by-default global set to ensure that this
is not turned on by default.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: f993baf945
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56644
2021-04-21 18:48:37 -04:00

120 lines
3.7 KiB
C++

#include <c10/core/DispatchKeySet.h>
namespace c10 {
// backend_dispatch_keyset is meant to hold every runtime backend key: the
// keys listed explicitly below plus everything in autogradother_backends.
// It is the set that the alias key DispatchKey::CompositeExplicitAutograd
// expands to.
// NestedTensor is deliberately left out because it is incompatible with some
// kernels (structured kernels, for example) that use the DefaultBackend key.
// NOTE(review): "DefaultBackend" looks like an older name for
// CompositeExplicitAutograd -- confirm and update the wording.
constexpr DispatchKeySet backend_dispatch_keyset =
    DispatchKeySet({
        DispatchKey::CPU,
        DispatchKey::CUDA,
        DispatchKey::XLA,
        DispatchKey::XPU,
        DispatchKey::PrivateUse1,
        DispatchKey::PrivateUse2,
        DispatchKey::PrivateUse3,
        DispatchKey::MLC,
        DispatchKey::Meta,
    }) |
    autogradother_backends;
// Returns true when `t` is a member of backend_dispatch_keyset.
// DispatchKey::Undefined is rejected up front and never reaches has().
bool isBackendDispatchKey(DispatchKey t) {
  if (t == DispatchKey::Undefined) {
    return false;
  }
  return backend_dispatch_keyset.has(t);
}
// math_dispatch_keyset is the union of autograd_dispatch_keyset and
// backend_dispatch_keyset. It is the set that the alias key
// DispatchKey::CompositeImplicitAutograd expands to.
constexpr DispatchKeySet math_dispatch_keyset =
    autograd_dispatch_keyset | backend_dispatch_keyset;
// Expands a dispatch key into the set of runtime keys it stands for.
//   - Autograd                  -> autograd_dispatch_keyset
//   - CompositeImplicitAutograd -> math_dispatch_keyset
//   - CompositeExplicitAutograd -> backend_dispatch_keyset
//   - any other key             -> the singleton set containing that key
// Asserts that `t` is not DispatchKey::Undefined.
DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  if (t == DispatchKey::Autograd) {
    return autograd_dispatch_keyset;
  }
  if (t == DispatchKey::CompositeImplicitAutograd) {
    return math_dispatch_keyset;
  }
  if (t == DispatchKey::CompositeExplicitAutograd) {
    return backend_dispatch_keyset;
  }
  // Not one of the alias keys handled above: the key maps to itself.
  return DispatchKeySet(t);
}
// Maps an autograd key to the set of backend keys associated with it.
// Each per-backend autograd key yields the singleton set of its backend;
// AutogradOther yields autogradother_backends. Any key not listed below
// (i.e. a non-autograd key) yields the empty keyset.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
  // Filled in by the one-to-one cases; the multi-backend and fallback cases
  // return directly from inside the switch.
  DispatchKey backend = DispatchKey::Undefined;
  switch (t) {
    case DispatchKey::AutogradCPU:
      backend = DispatchKey::CPU;
      break;
    case DispatchKey::AutogradCUDA:
      backend = DispatchKey::CUDA;
      break;
    case DispatchKey::AutogradXLA:
      backend = DispatchKey::XLA;
      break;
    case DispatchKey::AutogradMLC:
      backend = DispatchKey::MLC;
      break;
    case DispatchKey::AutogradNestedTensor:
      backend = DispatchKey::NestedTensor;
      break;
    case DispatchKey::AutogradXPU:
      backend = DispatchKey::XPU;
      break;
    case DispatchKey::AutogradPrivateUse1:
      backend = DispatchKey::PrivateUse1;
      break;
    case DispatchKey::AutogradPrivateUse2:
      backend = DispatchKey::PrivateUse2;
      break;
    case DispatchKey::AutogradPrivateUse3:
      backend = DispatchKey::PrivateUse3;
      break;
    case DispatchKey::AutogradOther:
      return autogradother_backends;
    default:
      return DispatchKeySet();
  }
  return DispatchKeySet(backend);
}
// Returns the autocast key(s) associated with backend key `t`.
// Only CUDA currently has a mapping (to AutocastCUDA); every other key
// yields the empty keyset.
DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
  // A CPU -> AutocastCPU mapping is present in the source but disabled:
  // if (t == DispatchKey::CPU) {
  //   return DispatchKeySet(DispatchKey::AutocastCPU);
  // }
  if (t == DispatchKey::CUDA) {
    return DispatchKeySet(DispatchKey::AutocastCUDA);
  }
  return DispatchKeySet();
}
// Returns the two-element set made of DispatchKey::InplaceOrView and the
// autograd key that getAutogradKeyFromBackend() reports for backend `t`.
DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) {
  const DispatchKey autograd_key = getAutogradKeyFromBackend(t);
  return DispatchKeySet({DispatchKey::InplaceOrView, autograd_key});
}
// Returns true when key `k` belongs to the runtime keyset that `alias`
// expands to (see getRuntimeDispatchKeySet). DispatchKey::Undefined is never
// considered included; in that case the alias is not expanded at all.
bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
  if (k == DispatchKey::Undefined) {
    return false;
  }
  return getRuntimeDispatchKeySet(alias).has(k);
}
// Renders a DispatchKeySet as text by streaming it through the
// operator<< overload for DispatchKeySet and returning the buffered result.
std::string toString(DispatchKeySet ts) {
  std::ostringstream buffer;
  buffer << ts;
  return buffer.str();
}
std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
if (ts.empty()) {
os << "DispatchKeySet()";
return os;
}
os << "DispatchKeySet(";
DispatchKey tid;
bool first = true;
while ((tid = ts.highestPriorityTypeId()) != DispatchKey::Undefined) {
if (!first) {
os << ", ";
}
os << tid;
ts = ts.remove(tid);
first = false;
}
os << ")";
return os;
}
}