From d3ca2f21ad56d08a2528ddf5747b4680678657d3 Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang"
Date: Wed, 21 Apr 2021 18:48:37 -0400
Subject: [PATCH] [POC] Skip autocast kernels on non-CUDA tensors

We use the same trick as AutogradCUDA, but applied to Autocast: the
Autocast key becomes AutocastCUDA and is added to a tensor's dispatch
key set only when its backend is CUDA, so autocast kernels are skipped
entirely for non-CUDA tensors. We also introduce a new excluded-by-default
global set to ensure that autocast is not turned on by default. (A
standalone sketch of the excluded-by-default trick follows the diff.)

Signed-off-by: Edward Z. Yang

ghstack-source-id: f993baf945b866d2cd3c1ab857f15973a4c21696
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56644
---
 aten/src/ATen/autocast_mode.cpp     |  4 ++--
 c10/core/DispatchKey.h              |  3 ++-
 c10/core/DispatchKeySet.cpp         | 11 +++++++++++
 c10/core/DispatchKeySet.h           |  8 ++++++++
 c10/core/TensorImpl.cpp             |  7 ++++++-
 c10/core/impl/LocalDispatchKeySet.h |  4 ++--
 6 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index 1a366595a4a..0bf6f11a027 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -13,11 +13,11 @@ namespace at {
 namespace autocast {
 
 bool is_enabled() {
-  return c10::impl::tls_is_dispatch_key_included(DispatchKey::Autocast);
+  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA);
 }
 
 void set_enabled(bool new_enabled) {
-  c10::impl::tls_set_dispatch_key_included(DispatchKey::Autocast, new_enabled);
+  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCUDA, !new_enabled);
 }
 
 namespace {
diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h
index 2dcf989932d..1456ce71ee9 100644
--- a/c10/core/DispatchKey.h
+++ b/c10/core/DispatchKey.h
@@ -226,7 +226,7 @@ enum class DispatchKey : uint8_t {
 
   // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
   // and inputs are saved for backward in the post-autocast type.
-  Autocast,
+  AutocastCUDA,
 
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
   // There are a number of alternative modes which may want to handle before
@@ -290,6 +290,7 @@ enum class DispatchKey : uint8_t {
   PrivateUse1_PreAutograd = AutogradPrivateUse1,
   PrivateUse2_PreAutograd = AutogradPrivateUse2,
   PrivateUse3_PreAutograd = AutogradPrivateUse3,
+  Autocast = AutocastCUDA,
 };
 
 // Note [Private use DispatchKey]
diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp
index e24de613ee5..0ea9c1cfdd0 100644
--- a/c10/core/DispatchKeySet.cpp
+++ b/c10/core/DispatchKeySet.cpp
@@ -70,6 +70,17 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
   }
 }
 
+DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
+  switch (t) {
+    //case DispatchKey::CPU:
+    //  return DispatchKeySet(DispatchKey::AutocastCPU);
+    case DispatchKey::CUDA:
+      return DispatchKeySet(DispatchKey::AutocastCUDA);
+    default:
+      return DispatchKeySet();
+  }
+}
+
 DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t) {
   return DispatchKeySet({
       DispatchKey::InplaceOrView, getAutogradKeyFromBackend(t)});
diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h
index b75caed7d78..abee4e984f7 100644
--- a/c10/core/DispatchKeySet.h
+++ b/c10/core/DispatchKeySet.h
@@ -213,6 +213,11 @@ constexpr DispatchKeySet default_included_set = DispatchKeySet({
   DispatchKey::InplaceOrView,
 });
 
+constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
+  // DispatchKey::AutocastCPU,
+  DispatchKey::AutocastCUDA,
+});
+
 constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView =
     autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView);
 
@@ -266,6 +271,9 @@ C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);
 // Returns a DispatchKeySet of autograd related keys mapped to backend.
 C10_API DispatchKeySet getAutogradRelatedKeySetFromBackend(DispatchKey t);
 
+// Returns a DispatchKeySet of autocast related keys mapped to backend.
+C10_API DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t);
+
 // This API exists because we have a use case for checking
 // getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)
 // in OperatorEntry.cpp but we disallow it in has() API.
diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp
index 8612200ee29..d3d75c5cc35 100644
--- a/c10/core/TensorImpl.cpp
+++ b/c10/core/TensorImpl.cpp
@@ -106,6 +106,12 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::
 
   bool inference_mode = c10::InferenceMode::is_enabled();
 
+  // TODO: be more explicit about the full key set at call sites so we
+  // don't have to keep recomputing it here
+  DispatchKey k = key_set.highestPriorityBackendTypeId();
+
+  key_set = key_set | getAutocastRelatedKeySetFromBackend(k);
+
   // Inference tensor doesn't have autograd related keys.
   if (inference_mode) {
     // See Note [Expected TLS state in InferenceMode] for why we exclude Autograd & InplaceOrView keys.
@@ -115,7 +121,6 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::
   } else {
     // TODO: Ideally we only add AutogradBackend key when the tensor requires grad.
     // See Note [Dream: skip VariableType kernel when requires_grad=false]
-    DispatchKey k = key_set.highestPriorityBackendTypeId();
     key_set_ = key_set | getAutogradRelatedKeySetFromBackend(k);
   }
 
diff --git a/c10/core/impl/LocalDispatchKeySet.h b/c10/core/impl/LocalDispatchKeySet.h
index 064d391055d..b18b4d4de1d 100644
--- a/c10/core/impl/LocalDispatchKeySet.h
+++ b/c10/core/impl/LocalDispatchKeySet.h
@@ -39,14 +39,14 @@ struct C10_API PODLocalDispatchKeySet {
     return DispatchKeySet(DispatchKeySet::RAW, included_) ^ c10::default_included_set;
   }
   DispatchKeySet excluded() const {
-    return DispatchKeySet(DispatchKeySet::RAW, excluded_);
+    return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^ c10::default_excluded_set;
   }
   void set_included(DispatchKeySet x) {
     included_ = (x ^ c10::default_included_set).raw_repr();
  }
   void set_excluded(DispatchKeySet x) {
-    excluded_ = x.raw_repr();
+    excluded_ = (x ^ c10::default_excluded_set).raw_repr();
   }
 };
 
 static_assert(std::is_pod<PODLocalDispatchKeySet>::value, "PODLocalDispatchKeySet must be a POD type.");
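
Reviewer note (not part of the patch): the thread-local PODLocalDispatchKeySet must stay a zero-initializable POD, so keys that should be treated as excluded "by default" cannot simply be stored in the raw word. Instead, reads and writes XOR the raw word against a constant default set, which makes a zero-initialized TLS behave as if default_excluded_set (here, AutocastCUDA) were already excluded. Below is a minimal, self-contained C++ sketch of that XOR trick under those assumptions; ToyKey, ToyLocalSet, bit() and default_excluded are made-up illustration names, not PyTorch APIs (the real types are c10::DispatchKey, c10::DispatchKeySet and c10::impl::PODLocalDispatchKeySet).

  // Minimal standalone sketch of the excluded-by-default XOR trick.
  #include <cstdint>
  #include <iostream>

  enum class ToyKey : std::uint8_t { AutocastCUDA = 0, AutogradCUDA = 1 };

  constexpr std::uint64_t bit(ToyKey k) {
    return std::uint64_t(1) << static_cast<std::uint8_t>(k);
  }

  // Keys that should read back as excluded even when the raw TLS word is zero.
  constexpr std::uint64_t default_excluded = bit(ToyKey::AutocastCUDA);

  struct ToyLocalSet {
    std::uint64_t excluded_ = 0;  // zero-initialized, as required of a POD TLS struct

    // XOR on the way out: with excluded_ == 0, AutocastCUDA reads as excluded,
    // i.e. autocast kernels are skipped by default.
    std::uint64_t excluded() const { return excluded_ ^ default_excluded; }

    // XOR on the way in keeps the raw representation consistent with reads.
    void set_excluded(std::uint64_t x) { excluded_ = x ^ default_excluded; }
  };

  int main() {
    ToyLocalSet tls;  // fresh thread state
    std::cout << std::boolalpha;

    // Default state: AutocastCUDA is excluded, so autocast is off.
    std::cout << bool(tls.excluded() & bit(ToyKey::AutocastCUDA)) << "\n";  // true

    // Enabling autocast (set_enabled(true) in the patch) removes the key
    // from the excluded set.
    tls.set_excluded(tls.excluded() & ~bit(ToyKey::AutocastCUDA));
    std::cout << bool(tls.excluded() & bit(ToyKey::AutocastCUDA)) << "\n";  // false
  }

Together with the TensorImpl change, which adds AutocastCUDA to a tensor's key set only when its highest-priority backend is CUDA, a CPU tensor never carries the key at all, so it never dispatches into the autocast kernels regardless of the TLS state.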