pytorch/c10/core/DispatchKeySet.cpp
leslie-fang-intel 0ede83db7a enable torch.cpu.amp.autocast (#57386)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57386

This PR implements what was discussed in the RFC https://github.com/pytorch/pytorch/issues/55374 to enable autocast for the CPU device. Currently, this PR only enables BF16 as the lower-precision datatype.

Changes:
1.  Enable the new API `torch.cpu.amp.autocast` for autocast on the CPU device, including the Python API, the C++ API, the new DispatchKey, etc.
2.  Consolidate the implementation of each cast policy so that it is shared between CPU and GPU devices.
3.  Add the operation lists to the corresponding cast policies for CPU autocast.

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D28572219

Pulled By: ezyang

fbshipit-source-id: db3db509973b16a5728ee510b5e1ee716b03a152
2021-05-20 17:48:36 -07:00

126 lines
3.9 KiB
C++

#include <c10/core/DispatchKeySet.h>
namespace c10 {
// Every runtime backend key belongs in backend_dispatch_keyset. It is the set
// that the alias key DispatchKey::CompositeExplicitAutograd maps to.
//
// NestedTensor is deliberately not listed here: it has been explicitly removed
// due to incompatibility with some kernels, such as structured kernels, that
// use the DefaultBackend key.
constexpr DispatchKeySet backend_dispatch_keyset =
    DispatchKeySet({
        DispatchKey::CPU,
        DispatchKey::CUDA,
        DispatchKey::XLA,
        DispatchKey::XPU,
        DispatchKey::PrivateUse1,
        DispatchKey::PrivateUse2,
        DispatchKey::PrivateUse3,
        DispatchKey::MLC,
        DispatchKey::HPU,
        DispatchKey::Meta,
    }) |
    autogradother_backends;
// Returns true iff `t` is a member of backend_dispatch_keyset.
// DispatchKey::Undefined is rejected up front, before the set is consulted.
bool isBackendDispatchKey(DispatchKey t) {
  if (t == DispatchKey::Undefined) {
    return false;
  }
  return backend_dispatch_keyset.has(t);
}
// math_dispatch_keyset is the union of autograd_dispatch_keyset and
// backend_dispatch_keyset. The alias key
// DispatchKey::CompositeImplicitAutograd maps to this set.
constexpr DispatchKeySet math_dispatch_keyset =
    autograd_dispatch_keyset | backend_dispatch_keyset;
// Expands a dispatch key into the set of keys it stands for.
//
// The three alias keys handled here map to their precomputed sets:
//   Autograd                  -> autograd_dispatch_keyset
//   CompositeImplicitAutograd -> math_dispatch_keyset
//   CompositeExplicitAutograd -> backend_dispatch_keyset
// Any other key maps to the singleton set containing just that key.
// Passing DispatchKey::Undefined trips an internal assert.
DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  if (t == DispatchKey::Autograd) {
    return autograd_dispatch_keyset;
  }
  if (t == DispatchKey::CompositeImplicitAutograd) {
    return math_dispatch_keyset;
  }
  if (t == DispatchKey::CompositeExplicitAutograd) {
    return backend_dispatch_keyset;
  }
  return DispatchKeySet(t);
}
// Maps an autograd key to the (guaranteed nonempty) set of backend keys
// associated with it. A key that is not one of the autograd keys handled
// below yields the empty keyset.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
  // Most autograd keys correspond to exactly one backend key; pick it here
  // and wrap it in a singleton set at the end.
  DispatchKey backend = DispatchKey::Undefined;
  switch (t) {
    case DispatchKey::AutogradCPU:
      backend = DispatchKey::CPU;
      break;
    case DispatchKey::AutogradCUDA:
      backend = DispatchKey::CUDA;
      break;
    case DispatchKey::AutogradXLA:
      backend = DispatchKey::XLA;
      break;
    case DispatchKey::AutogradMLC:
      backend = DispatchKey::MLC;
      break;
    case DispatchKey::AutogradHPU:
      backend = DispatchKey::HPU;
      break;
    case DispatchKey::AutogradNestedTensor:
      backend = DispatchKey::NestedTensor;
      break;
    case DispatchKey::AutogradXPU:
      backend = DispatchKey::XPU;
      break;
    case DispatchKey::AutogradPrivateUse1:
      backend = DispatchKey::PrivateUse1;
      break;
    case DispatchKey::AutogradPrivateUse2:
      backend = DispatchKey::PrivateUse2;
      break;
    case DispatchKey::AutogradPrivateUse3:
      backend = DispatchKey::PrivateUse3;
      break;
    case DispatchKey::AutogradOther:
      // AutogradOther maps to a whole set of backends, not a single key.
      return autogradother_backends;
    default:
      // Not an autograd key handled here: no associated backends.
      return DispatchKeySet();
  }
  return DispatchKeySet(backend);
}
// Maps a backend key to the singleton set holding its autocast key:
// CPU -> AutocastCPU, CUDA -> AutocastCUDA. Every other key yields the
// empty keyset.
DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
  if (t == DispatchKey::CPU) {
    return DispatchKeySet(DispatchKey::AutocastCPU);
  }
  if (t == DispatchKey::CUDA) {
    return DispatchKeySet(DispatchKey::AutocastCUDA);
  }
  return DispatchKeySet();
}
// Builds the keyset made of DispatchKey::ADInplaceOrView together with the
// key that getAutogradKeyFromBackend returns for backend key `t`.
DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) {
  const DispatchKey autograd_key = getAutogradKeyFromBackend(t);
  return DispatchKeySet({DispatchKey::ADInplaceOrView, autograd_key});
}
// Returns true iff key `k` is a member of the runtime keyset that `alias`
// expands to (see getRuntimeDispatchKeySet). DispatchKey::Undefined is never
// considered included; in that case `alias` is not expanded at all.
bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
  if (k == DispatchKey::Undefined) {
    return false;
  }
  return getRuntimeDispatchKeySet(alias).has(k);
}
// Renders a DispatchKeySet as a string by streaming it through
// operator<<(std::ostream&, DispatchKeySet).
std::string toString(DispatchKeySet ts) {
  std::ostringstream rendered;
  rendered << ts;
  return rendered.str();
}
std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
if (ts.empty()) {
os << "DispatchKeySet()";
return os;
}
os << "DispatchKeySet(";
DispatchKey tid;
bool first = true;
while ((tid = ts.highestPriorityTypeId()) != DispatchKey::Undefined) {
if (!first) {
os << ", ";
}
os << tid;
ts = ts.remove(tid);
first = false;
}
os << ")";
return os;
}
} // namespace c10