[POC] Skip autocast kernels on non-CUDA tensors

We use the same trick as AutogradCUDA, but applied to Autocast, and also
introduce a new excluded-by-default global key set to ensure that autocast
is not turned on by default.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: f993baf945
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56644
This commit is contained in:
Edward Z. Yang 2021-04-21 18:48:37 -04:00
parent 28f52649d8
commit d3ca2f21ad
6 changed files with 31 additions and 6 deletions

View File

@ -13,11 +13,11 @@ namespace at {
namespace autocast {
bool is_enabled() {
return c10::impl::tls_is_dispatch_key_included(DispatchKey::Autocast);
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA);
}
void set_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Autocast, new_enabled);
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCUDA, !new_enabled);
}
namespace {

View File

@ -226,7 +226,7 @@ enum class DispatchKey : uint8_t {
// Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
// and inputs are saved for backward in the post-autocast type.
Autocast,
AutocastCUDA,
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
// There are a number of alternative modes which may want to handle before
@ -290,6 +290,7 @@ enum class DispatchKey : uint8_t {
PrivateUse1_PreAutograd = AutogradPrivateUse1,
PrivateUse2_PreAutograd = AutogradPrivateUse2,
PrivateUse3_PreAutograd = AutogradPrivateUse3,
Autocast = AutocastCUDA,
};
// Note [Private use DispatchKey]

View File

@ -70,6 +70,17 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
}
}
// Maps a backend dispatch key to the autocast key set that should accompany
// it on a tensor's key set. Backends without autocast support map to the
// empty set.
DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
  // Only CUDA participates in autocast for now; CPU support is staged
  // behind the commented-out branch below.
  if (t == DispatchKey::CUDA) {
    return DispatchKeySet(DispatchKey::AutocastCUDA);
  }
  // TODO: enable once an AutocastCPU key exists:
  // if (t == DispatchKey::CPU) return DispatchKeySet(DispatchKey::AutocastCPU);
  return DispatchKeySet();
}
DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) {
return DispatchKeySet({
DispatchKey::InplaceOrView, getAutogradKeyFromBackend(t)});

View File

@ -213,6 +213,11 @@ constexpr DispatchKeySet default_included_set = DispatchKeySet({
DispatchKey::InplaceOrView,
});
constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
// DispatchKey::AutocastCPU,
DispatchKey::AutocastCUDA,
});
constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView =
autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView);
@ -266,6 +271,9 @@ C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);
// Returns a DispatchKeySet of autograd related keys mapped to backend.
C10_API DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t);
// Returns a DispatchKeySet of autocast related keys mapped to backend.
C10_API DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t);
// This API exists because we have a use case for checking
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)
// in OperatorEntry.cpp but we disallow it in has() API.

View File

@ -106,6 +106,12 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::
bool inference_mode = c10::InferenceMode::is_enabled();
// TODO: be more explicit about the full key set at call sites so we
// don't have to keep recomputing it here
DispatchKey k = key_set.highestPriorityBackendTypeId();
key_set = key_set | getAutocastRelatedKeySetFromBackend(k);
// Inference tensor doesn't have autograd related keys.
if (inference_mode) {
// See Note [Expected TLS state in InferenceMode] for why we exclude Autograd & InplaceOrView keys.
@ -115,7 +121,6 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::
} else {
// TODO: Ideally we only add AutogradBackend key when the tensor requires grad.
// See Note [Dream: skip VariableType kernel when requires_grad=false]
DispatchKey k = key_set.highestPriorityBackendTypeId();
key_set_ = key_set | getAutogradRelatedKeySetFromBackend(k);
}

View File

@ -39,14 +39,14 @@ struct C10_API PODLocalDispatchKeySet {
return DispatchKeySet(DispatchKeySet::RAW, included_) ^ c10::default_included_set;
}
DispatchKeySet excluded() const {
return DispatchKeySet(DispatchKeySet::RAW, excluded_);
return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^ c10::default_excluded_set;
}
void set_included(DispatchKeySet x) {
included_ = (x ^ c10::default_included_set).raw_repr();
}
void set_excluded(DispatchKeySet x) {
excluded_ = x.raw_repr();
excluded_ = (x ^ c10::default_excluded_set).raw_repr();
}
};
static_assert(std::is_pod<PODLocalDispatchKeySet>::value, "PODLocalDispatchKeySet must be a POD type.");