mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43387 Test Plan: Imported from OSS Reviewed By: ailzhang Differential Revision: D23258687 Pulled By: bhosmer fbshipit-source-id: 3718f74fc7324db027f87eda0b90893a960aa56e
44 lines
863 B
C++
44 lines
863 B
C++
#include <c10/core/DispatchKeySet.h>
|
|
|
|
namespace c10 {
|
|
|
|
static DispatchKeySet autograd_dispatch_keys{
|
|
DispatchKey::Autograd,
|
|
DispatchKey::AutogradXLA,
|
|
DispatchKey::PrivateUse1_PreAutograd,
|
|
DispatchKey::PrivateUse2_PreAutograd,
|
|
DispatchKey::PrivateUse3_PreAutograd,
|
|
};
|
|
|
|
DispatchKeySet AutogradDispatchKeys() {
|
|
return autograd_dispatch_keys;
|
|
}
|
|
|
|
std::string toString(DispatchKeySet ts) {
|
|
std::stringstream ss;
|
|
ss << ts;
|
|
return ss.str();
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
|
|
if (ts.empty()) {
|
|
os << "DispatchKeySet()";
|
|
return os;
|
|
}
|
|
os << "DispatchKeySet(";
|
|
DispatchKey tid;
|
|
bool first = true;
|
|
while ((tid = ts.highestPriorityTypeId()) != DispatchKey::Undefined) {
|
|
if (!first) {
|
|
os << ", ";
|
|
}
|
|
os << tid;
|
|
ts = ts.remove(tid);
|
|
first = false;
|
|
}
|
|
os << ")";
|
|
return os;
|
|
}
|
|
|
|
}
|