// Mirror of https://github.com/zebrajr/pytorch.git
// (synced 2025-12-06 12:20:52 +01:00).
//
// We use the same trick as AutogradCUDA but apply it to Autocast. We also
// introduce a new global set of keys that are excluded by default, so that
// this is not turned on by default.
//
// Signed-off-by: Edward Z. Yang <ezyang@fb.com>
// ghstack-source-id: f993baf945
// Pull Request resolved: https://github.com/pytorch/pytorch/pull/56644
// 120 lines, 3.7 KiB, C++
#include <c10/core/DispatchKeySet.h>
|
|
|
|
namespace c10 {
|
|
|
|
// backend_dispatch_keyset should include all runtime backend keys.
|
|
// Alias key DispatchKey::CompositeExplicitAutograd maps to backend_dispatch_keyset
|
|
// NestedTensor has been explicitly removed due to incompatibility with some
|
|
// kernels, such as structured kernels, that use the DefaultBackend key.
|
|
constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
|
|
DispatchKeySet({
|
|
DispatchKey::CPU,
|
|
DispatchKey::CUDA,
|
|
DispatchKey::XLA,
|
|
DispatchKey::XPU,
|
|
DispatchKey::PrivateUse1,
|
|
DispatchKey::PrivateUse2,
|
|
DispatchKey::PrivateUse3,
|
|
DispatchKey::MLC,
|
|
DispatchKey::Meta,
|
|
});
|
|
|
|
// Returns true iff `t` is one of the runtime backend keys in
// backend_dispatch_keyset. DispatchKey::Undefined is rejected up front and
// never counts as a backend key.
bool isBackendDispatchKey(DispatchKey t) {
  if (t == DispatchKey::Undefined) {
    return false;
  }
  return backend_dispatch_keyset.has(t);
}
|
|
|
|
// math_dispatch_keyset contains all keys in backend_dispatch_keyset and autograd_dispatch_keyset
|
|
// Alias key DispatchKey::CompositeImplicitAutograd maps to math_dispatch_keyset.
|
|
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | autograd_dispatch_keyset;
|
|
|
|
// Expands a dispatch key into the set of runtime keys it stands for.
//   Autograd                  -> autograd_dispatch_keyset
//   CompositeImplicitAutograd -> math_dispatch_keyset
//   CompositeExplicitAutograd -> backend_dispatch_keyset
//   anything else             -> a singleton set holding just `t`
// Passing DispatchKey::Undefined trips TORCH_INTERNAL_ASSERT.
DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  if (t == DispatchKey::Autograd) {
    return autograd_dispatch_keyset;
  }
  if (t == DispatchKey::CompositeImplicitAutograd) {
    return math_dispatch_keyset;
  }
  if (t == DispatchKey::CompositeExplicitAutograd) {
    return backend_dispatch_keyset;
  }
  return DispatchKeySet(t);
}
|
|
|
|
// for a given autograd key, return the (guaranteed nonempty) set of associated backend keys.
|
|
// for a non-autograd key, return the empty keyset.
|
|
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
|
|
switch (t) {
|
|
case DispatchKey::AutogradCPU:
|
|
return DispatchKeySet(DispatchKey::CPU);
|
|
case DispatchKey::AutogradCUDA:
|
|
return DispatchKeySet(DispatchKey::CUDA);
|
|
case DispatchKey::AutogradXLA:
|
|
return DispatchKeySet(DispatchKey::XLA);
|
|
case DispatchKey::AutogradMLC:
|
|
return DispatchKeySet(DispatchKey::MLC);
|
|
case DispatchKey::AutogradNestedTensor:
|
|
return DispatchKeySet(DispatchKey::NestedTensor);
|
|
case DispatchKey::AutogradXPU:
|
|
return DispatchKeySet(DispatchKey::XPU);
|
|
case DispatchKey::AutogradPrivateUse1:
|
|
return DispatchKeySet(DispatchKey::PrivateUse1);
|
|
case DispatchKey::AutogradPrivateUse2:
|
|
return DispatchKeySet(DispatchKey::PrivateUse2);
|
|
case DispatchKey::AutogradPrivateUse3:
|
|
return DispatchKeySet(DispatchKey::PrivateUse3);
|
|
case DispatchKey::AutogradOther:
|
|
return autogradother_backends;
|
|
default:
|
|
return DispatchKeySet();
|
|
}
|
|
}
|
|
|
|
// Returns the autocast key(s) associated with backend key `t`.
// Only DispatchKey::CUDA is mapped here (to AutocastCUDA). Every other key,
// CPU included, yields the empty set.
DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
  // Currently disabled mapping:
  //   DispatchKey::CPU -> DispatchKeySet(DispatchKey::AutocastCPU)
  if (t == DispatchKey::CUDA) {
    return DispatchKeySet(DispatchKey::AutocastCUDA);
  }
  return DispatchKeySet();
}
|
|
|
|
// Returns the autograd-related keys for backend key `t`: InplaceOrView plus
// whatever autograd key getAutogradKeyFromBackend reports for `t`.
DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) {
  const DispatchKey autograd_key = getAutogradKeyFromBackend(t);
  return DispatchKeySet({DispatchKey::InplaceOrView, autograd_key});
}
|
|
|
|
// Returns true iff key `k` is among the runtime keys that `alias` expands to
// via getRuntimeDispatchKeySet. An Undefined `k` returns false without
// expanding `alias` at all.
bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
  if (k == DispatchKey::Undefined) {
    return false;
  }
  const DispatchKeySet expanded = getRuntimeDispatchKeySet(alias);
  return expanded.has(k);
}
|
|
|
|
// Returns the textual form of `ts`, as produced by the DispatchKeySet stream
// operator.
std::string toString(DispatchKeySet ts) {
  std::ostringstream out;
  out << ts;
  return out.str();
}
|
|
|
|
std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
|
|
if (ts.empty()) {
|
|
os << "DispatchKeySet()";
|
|
return os;
|
|
}
|
|
os << "DispatchKeySet(";
|
|
DispatchKey tid;
|
|
bool first = true;
|
|
while ((tid = ts.highestPriorityTypeId()) != DispatchKey::Undefined) {
|
|
if (!first) {
|
|
os << ", ";
|
|
}
|
|
os << tid;
|
|
ts = ts.remove(tid);
|
|
first = false;
|
|
}
|
|
os << ")";
|
|
return os;
|
|
}
|
|
|
|
}
|