#pragma once #include #include #include #include #include namespace c10 { // A RAII, thread local (!) guard that enables or disables inference mode upon // construction, and sets it back to the original value upon destruction. struct C10_API InferenceMode { // Note [Expected TLS state in InferenceMode]: // InferenceMode: ADInplaceOrView not in // raw_local_dispatch_key_set.included(), // Autograd in raw_local_dispatch_key_set.excluded() // GradMode is disabled. // NormalMode: ADInplaceOrView in raw_local_dispatch_key_set.included(), // Autograd not in raw_local_dispatch_key_set.excluded() // GradMode is enabled by default unless toggled manually // through other APIs, e.g. NoGradGuard. // // Invariant: // - ADInplaceOrView is never in the excluded set // - Autograd is never in the included set // - Setting InferenceMode will set GradMode accordingly, but not vice versa. // // 1. Why do we put ADInplaceOrView in included set outside InferenceMode? // // Inplace update to inference tensor outside InferenceMode is not // allowed. See Note [Inplace update inference tensor] for more details. // Without going through ADInplaceOrView kernel, we cannot throw error // for `inference_tensor.add_(1)` case. // // 2. Why not put ADInplaceOrView in the excluded set inside InferenceMode? // // For example: // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true); // torch::Tensor k = a + 2; // { // c10::InferenceMode guard(true); // k.add_(2); // } // `k.add_(2)` still need to go through ADInplaceOrView kernel so that it's // prepared for future autograd. // // 3. Why does setting InferenceMode also set GradMode? // // This is required since InferenceMode is a faster and more restrictive // version of NoGradGuard. All runtime checks using GradMode::is_enabled() // are applicable to InferenceMode as well, e.g. // `tensorTypeInCurrentExecutionContext` in interpreter.cpp. InferenceMode(bool enabled = true) : prev_mode(AutogradState::get_tls_state()), prev_keyset(c10::impl::tls_local_dispatch_key_set()) { // Enabling inference mode means disabling grad modes // And disabling inference mode means enabling grad modes AutogradState::set_tls_state(AutogradState( /* grad_mode */ !enabled, /* inference_mode */ enabled, /* fw_grad_mode */ !enabled, /* multithreading_enabled*/ !enabled)); DispatchKeySet included = enabled ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView) : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView); DispatchKeySet excluded = enabled ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset) : (prev_keyset.excluded_ - c10::autograd_dispatch_keyset); c10::impl::PODLocalDispatchKeySet cur_keyset{}; cur_keyset.set_included(included); cur_keyset.set_excluded(excluded); c10::impl::_force_tls_local_dispatch_key_set(cur_keyset); } InferenceMode(const InferenceMode&) = delete; InferenceMode(InferenceMode&&) = delete; InferenceMode& operator=(const InferenceMode&) = delete; InferenceMode& operator=(InferenceMode&&) = delete; ~InferenceMode() { AutogradState::set_tls_state(prev_mode); c10::impl::_force_tls_local_dispatch_key_set(prev_keyset); } static bool is_enabled(); private: AutogradState prev_mode; c10::impl::LocalDispatchKeySet prev_keyset; }; } // namespace c10