[POC] Skip autocast kernels on non-CUDA tensors
We use the same trick as AutogradCUDA, but applied to Autocast, and we also
introduce a new excluded-by-default global set to ensure that autocast
is not turned on by default.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
ghstack-source-id: f993baf945
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56644
parent: 28f52649d8
commit: d3ca2f21ad
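Before the diff, a sketch of the mechanism may help. The model below is illustrative only (the bit position, the tls_excluded variable, and the hits_autocast_kernel helper are made up, not c10 API): the AutocastCUDA key lives in the tensor's key set, the thread-local excluded set masks it off, and because the key sits in the excluded set by default, autocast kernels are skipped until set_enabled(true) removes the exclusion. Note the inverted polarity in the is_enabled/set_enabled hunk that follows: enabling autocast now means removing a key from the excluded set, not adding one to the included set.

// Illustrative model of exclude-based gating; not the real c10 dispatcher.
#include <cstdint>
#include <iostream>

constexpr uint64_t kAutocastCUDA = 1ull << 3;  // hypothetical bit position

// Excluded by default, mirroring c10::default_excluded_set in the diff.
thread_local uint64_t tls_excluded = kAutocastCUDA;

bool is_enabled() { return !(tls_excluded & kAutocastCUDA); }

void set_enabled(bool on) {
  if (on) {
    tls_excluded &= ~kAutocastCUDA;  // enabling = removing the exclusion
  } else {
    tls_excluded |= kAutocastCUDA;
  }
}

// Dispatch only reaches an autocast kernel if the tensor carries the key
// and the thread-local excluded set does not mask it off.
bool hits_autocast_kernel(uint64_t tensor_keys) {
  return ((tensor_keys & ~tls_excluded) & kAutocastCUDA) != 0;
}

int main() {
  uint64_t cuda_tensor = kAutocastCUDA;  // CUDA tensors carry the key
  uint64_t cpu_tensor = 0;               // non-CUDA tensors never do
  std::cout << hits_autocast_kernel(cuda_tensor);  // 0: excluded by default
  set_enabled(true);
  std::cout << hits_autocast_kernel(cuda_tensor);  // 1: autocast runs
  std::cout << hits_autocast_kernel(cpu_tensor);   // 0: skipped, no TLS check needed
  std::cout << std::endl;
  return 0;
}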
@@ -13,11 +13,11 @@ namespace at {
 namespace autocast {
 
 bool is_enabled() {
-  return c10::impl::tls_is_dispatch_key_included(DispatchKey::Autocast);
+  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA);
 }
 
 void set_enabled(bool new_enabled) {
-  c10::impl::tls_set_dispatch_key_included(DispatchKey::Autocast, new_enabled);
+  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCUDA, !new_enabled);
 }
 
 namespace {
@@ -226,7 +226,7 @@ enum class DispatchKey : uint8_t {
 
   // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
   // and inputs are saved for backward in the post-autocast type.
-  Autocast,
+  AutocastCUDA,
 
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
   // There are a number of alternative modes which may want to handle before
@@ -290,6 +290,7 @@ enum class DispatchKey : uint8_t {
   PrivateUse1_PreAutograd = AutogradPrivateUse1,
   PrivateUse2_PreAutograd = AutogradPrivateUse2,
   PrivateUse3_PreAutograd = AutogradPrivateUse3,
+  Autocast = AutocastCUDA,
 };
 
 // Note [Private use DispatchKey]
@@ -70,6 +70,17 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
   }
 }
 
+DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
+  switch (t) {
+    //case DispatchKey::CPU:
+    //  return DispatchKeySet(DispatchKey::AutocastCPU);
+    case DispatchKey::CUDA:
+      return DispatchKeySet(DispatchKey::AutocastCUDA);
+    default:
+      return DispatchKeySet();
+  }
+}
+
 DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) {
   return DispatchKeySet({
     DispatchKey::InplaceOrView, getAutogradKeyFromBackend(t)});
@@ -213,6 +213,11 @@ constexpr DispatchKeySet default_included_set = DispatchKeySet({
   DispatchKey::InplaceOrView,
 });
 
+constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
+  // DispatchKey::AutocastCPU,
+  DispatchKey::AutocastCUDA,
+});
+
 constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView =
   autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView);
 
@@ -266,6 +271,9 @@ C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);
 // Returns a DispatchKeySet of autograd related keys mapped to backend.
 C10_API DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t);
 
+// Returns a DispatchKeySet of autocast related keys mapped to backend.
+C10_API DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t);
+
 // This API exists because we have a use case for checking
 // getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)
 // in OperatorEntry.cpp but we disallow it in has() API.
@@ -106,6 +106,12 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::
 
   bool inference_mode = c10::InferenceMode::is_enabled();
 
+  // TODO: be more explicit about the full key set at call sites so we
+  // don't have to keep recomputing it here
+  DispatchKey k = key_set.highestPriorityBackendTypeId();
+
+  key_set = key_set | getAutocastRelatedKeySetFromBackend(k);
+
   // Inference tensor doesn't have autograd related keys.
   if (inference_mode) {
     // See Note [Expected TLS state in InferenceMode] for why we exclude Autograd & InplaceOrView keys.
@@ -115,7 +121,6 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::
   } else {
     // TODO: Ideally we only add AutogradBackend key when the tensor requires grad.
     // See Note [Dream: skip VariableType kernel when requires_grad=false]
-    DispatchKey k = key_set.highestPriorityBackendTypeId();
     key_set_ = key_set | getAutogradRelatedKeySetFromBackend(k);
   }
 
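With both TensorImpl hunks applied, the backend key chosen at construction decides whether a tensor ever carries AutocastCUDA. A minimal sketch of that composition, with made-up key bit values and an autocast_keys_for_backend helper standing in for getAutocastRelatedKeySetFromBackend:

#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for c10 dispatch key bits.
enum : uint64_t { kCPU = 1ull << 0, kCUDA = 1ull << 1, kAutocastCUDA = 1ull << 2 };

// Mirrors getAutocastRelatedKeySetFromBackend: only CUDA maps to an
// autocast key for now (the CPU case is commented out in the diff).
uint64_t autocast_keys_for_backend(uint64_t backend) {
  return backend == kCUDA ? kAutocastCUDA : 0;
}

int main() {
  // As in the TensorImpl constructor: key_set |= mapping(backend key).
  uint64_t cuda_keys = kCUDA;
  cuda_keys |= autocast_keys_for_backend(kCUDA);
  assert(cuda_keys & kAutocastCUDA);  // CUDA tensor carries the key from birth

  uint64_t cpu_keys = kCPU;
  cpu_keys |= autocast_keys_for_backend(kCPU);
  assert(!(cpu_keys & kAutocastCUDA));  // CPU tensor never dispatches to autocast
  return 0;
}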
@@ -39,14 +39,14 @@ struct C10_API PODLocalDispatchKeySet {
     return DispatchKeySet(DispatchKeySet::RAW, included_) ^ c10::default_included_set;
   }
   DispatchKeySet excluded() const {
-    return DispatchKeySet(DispatchKeySet::RAW, excluded_);
+    return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^ c10::default_excluded_set;
   }
 
   void set_included(DispatchKeySet x) {
     included_ = (x ^ c10::default_included_set).raw_repr();
   }
   void set_excluded(DispatchKeySet x) {
-    excluded_ = x.raw_repr();
+    excluded_ = (x ^ c10::default_excluded_set).raw_repr();
   }
 };
 static_assert(std::is_pod<PODLocalDispatchKeySet>::value, "PODLocalDispatchKeySet must be a POD type.");
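The XOR in this last hunk is the subtle part: thread-local PODs must be zero-initialized, so PODLocalDispatchKeySet stores the difference from the default set rather than the set itself. A fresh thread's raw word is 0, which decodes to "AutocastCUDA excluded" with no TLS initializer needed. A hedged round-trip sketch (the PodTls struct and the bit value are illustrative, not the c10 types):

#include <cassert>
#include <cstdint>

constexpr uint64_t kDefaultExcluded = 1ull << 3;  // stands in for AutocastCUDA

struct PodTls {
  uint64_t excluded_raw;  // zero on a fresh thread (POD thread-local)
  uint64_t excluded() const { return excluded_raw ^ kDefaultExcluded; }
  void set_excluded(uint64_t x) { excluded_raw = x ^ kDefaultExcluded; }
};

int main() {
  PodTls tls{};                                // raw == 0, no initializer run
  assert(tls.excluded() == kDefaultExcluded);  // default exclusion applies
  tls.set_excluded(0);                         // enable autocast: exclude nothing
  assert(tls.excluded() == 0);
  tls.set_excluded(kDefaultExcluded);          // disable again
  assert(tls.excluded_raw == 0);               // storage returns to the zero default
  return 0;
}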