mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45718 Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D24165892 Pulled By: ailzhang fbshipit-source-id: ed28bf62b7c6320d966fd10b7a44b14efffe2f62
86 lines
2.4 KiB
C++
86 lines
2.4 KiB
C++
#include <c10/core/DispatchKeySet.h>
|
|
|
|
namespace c10 {
|
|
|
|
// backend_dispatch_keyset should include all runtime backend keys.
|
|
// Alias key DispatchKey::DefaultBackend maps to backend_dispatch_keyset
|
|
constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends | DispatchKeySet({
|
|
DispatchKey::CPU,
|
|
DispatchKey::CUDA,
|
|
DispatchKey::XLA,
|
|
DispatchKey::PrivateUse1,
|
|
DispatchKey::PrivateUse2,
|
|
DispatchKey::PrivateUse3,
|
|
});
|
|
|
|
// math_dispatch_keyset contains all keys in backend_dispatch_keyset and autograd_dispatch_keyset
|
|
// Alias key DispatchKey::Math maps to math_dispatch_keyset.
|
|
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | autograd_dispatch_keyset;
|
|
|
|
// Expands a dispatch key into the set of runtime keys it stands for.
// The alias keys expand to predefined sets:
//   Autograd       -> autograd_dispatch_keyset
//   Math           -> math_dispatch_keyset
//   DefaultBackend -> backend_dispatch_keyset
// Any other key expands to a singleton set containing just that key.
// Passing DispatchKey::Undefined trips TORCH_INTERNAL_ASSERT.
DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  if (t == DispatchKey::Autograd) {
    return autograd_dispatch_keyset;
  }
  if (t == DispatchKey::Math) {
    return math_dispatch_keyset;
  }
  if (t == DispatchKey::DefaultBackend) {
    return backend_dispatch_keyset;
  }
  // Not an alias: the key only stands for itself.
  return DispatchKeySet(t);
}
|
|
|
|
// Maps an autograd key to the backend key set it corresponds to.
// AutogradCPU/CUDA/XLA/PrivateUse1-3 each map to a singleton set holding
// the matching backend key. AutogradOther maps to the whole
// autogradother_backends set. Any other key yields an empty set.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
  // AutogradOther is shared by many backends, so it maps to a set rather
  // than to a single key.
  if (t == DispatchKey::AutogradOther) {
    return autogradother_backends;
  }

  DispatchKey backend = DispatchKey::Undefined;
  switch (t) {
    case DispatchKey::AutogradCPU:
      backend = DispatchKey::CPU;
      break;
    case DispatchKey::AutogradCUDA:
      backend = DispatchKey::CUDA;
      break;
    case DispatchKey::AutogradXLA:
      backend = DispatchKey::XLA;
      break;
    case DispatchKey::AutogradPrivateUse1:
      backend = DispatchKey::PrivateUse1;
      break;
    case DispatchKey::AutogradPrivateUse2:
      backend = DispatchKey::PrivateUse2;
      break;
    case DispatchKey::AutogradPrivateUse3:
      backend = DispatchKey::PrivateUse3;
      break;
    default:
      // Not a per-backend autograd key: no corresponding backend keys.
      return DispatchKeySet();
  }
  return DispatchKeySet(backend);
}
|
|
|
|
// Returns true when key k belongs to the runtime key set that `alias`
// expands to (see getRuntimeDispatchKeySet). DispatchKey::Undefined is never
// included. A non-alias `alias` expands to itself, so the result is then
// simply k == alias. Note that getRuntimeDispatchKeySet asserts
// alias != Undefined.
bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
  if (k == DispatchKey::Undefined) {
    return false;
  }
  return getRuntimeDispatchKeySet(alias).has(k);
}
|
|
|
|
// Renders `ts` as a string by reusing the DispatchKeySet operator<< defined
// in this file. An empty set renders as "DispatchKeySet()".
std::string toString(DispatchKeySet ts) {
  // The buffer is only written to and then extracted with str(), so the
  // output-only std::ostringstream states the intent better than the
  // bidirectional std::stringstream.
  std::ostringstream ss;
  ss << ts;
  return ss.str();
}
|
|
|
|
std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
|
|
if (ts.empty()) {
|
|
os << "DispatchKeySet()";
|
|
return os;
|
|
}
|
|
os << "DispatchKeySet(";
|
|
DispatchKey tid;
|
|
bool first = true;
|
|
while ((tid = ts.highestPriorityTypeId()) != DispatchKey::Undefined) {
|
|
if (!first) {
|
|
os << ", ";
|
|
}
|
|
os << tid;
|
|
ts = ts.remove(tid);
|
|
first = false;
|
|
}
|
|
os << ")";
|
|
return os;
|
|
}
|
|
|
|
}
|