mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: This is a re-attempt to land the iwyu header changes, by taking the diff from [PR 100304](https://github.com/pytorch/pytorch/pull/100304), and adding the bare minimal changes to make the diff build correctly in the internal builds. X-link: https://github.com/facebookresearch/pytorch3d/pull/1541 X-link: https://github.com/fairinternal/pytorch3d/pull/44 - Re-work D45769819 to fix header inclusions in c10 Test Plan: ``` buck2 build --no-remote-cache mode/dev-nosan //caffe2/c10/... buck2 build --no-remote-cache mode/dev-nosan //deeplearning/fbgemm/fbgemm_gpu/... buck2 build mode/dev-nosan //vision/fair/pytorch3d/pytorch3d:_C ``` Reviewed By: malfet Differential Revision: D45920611 Pull Request resolved: https://github.com/pytorch/pytorch/pull/101846 Approved by: https://github.com/malfet, https://github.com/Skylion007
85 lines
3.4 KiB
C++
85 lines
3.4 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/AutogradState.h>
|
|
#include <c10/core/impl/LocalDispatchKeySet.h>
|
|
#include <c10/macros/Export.h>
|
|
|
|
namespace c10 {
|
|
|
|
// A RAII, thread local (!) guard that enables or disables inference mode upon
|
|
// construction, and sets it back to the original value upon destruction.
|
|
struct C10_API InferenceMode {
|
|
// Note [Expected TLS state in InferenceMode]:
|
|
// InferenceMode: ADInplaceOrView not in
|
|
// raw_local_dispatch_key_set.included(),
|
|
// Autograd in raw_local_dispatch_key_set.excluded()
|
|
// GradMode is disabled.
|
|
// NormalMode: ADInplaceOrView in raw_local_dispatch_key_set.included(),
|
|
// Autograd not in raw_local_dispatch_key_set.excluded()
|
|
// GradMode is enabled by default unless toggled manually
|
|
// through other APIs, e.g. NoGradGuard.
|
|
//
|
|
// Invariant:
|
|
// - ADInplaceOrView is never in the excluded set
|
|
// - Autograd is never in the included set
|
|
// - Setting InferenceMode will set GradMode accordingly, but not vice versa.
|
|
//
|
|
// 1. Why do we put ADInplaceOrView in included set outside InferenceMode?
|
|
//
|
|
// Inplace update to inference tensor outside InferenceMode is not
|
|
// allowed. See Note [Inplace update inference tensor] for more details.
|
|
// Without going through ADInplaceOrView kernel, we cannot throw error
|
|
// for `inference_tensor.add_(1)` case.
|
|
//
|
|
// 2. Why not put ADInplaceOrView in the excluded set inside InferenceMode?
|
|
//
|
|
// For example:
|
|
// torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true);
|
|
// torch::Tensor k = a + 2;
|
|
// {
|
|
// c10::InferenceMode guard(true);
|
|
// k.add_(2);
|
|
// }
|
|
// `k.add_(2)` still need to go through ADInplaceOrView kernel so that it's
|
|
// prepared for future autograd.
|
|
//
|
|
// 3. Why does setting InferenceMode also set GradMode?
|
|
//
|
|
// This is required since InferenceMode is a faster and more restrictive
|
|
// version of NoGradGuard. All runtime checks using GradMode::is_enabled()
|
|
// are applicable to InferenceMode as well, e.g.
|
|
// `tensorTypeInCurrentExecutionContext` in interpreter.cpp.
|
|
InferenceMode(bool enabled = true)
|
|
: prev_mode(AutogradState::get_tls_state()),
|
|
prev_keyset(c10::impl::tls_local_dispatch_key_set()) {
|
|
// Enabling inference mode means disabling grad modes
|
|
// And disabling inference mode means enabling grad modes
|
|
AutogradState::set_tls_state(AutogradState(
|
|
/* grad_mode */ !enabled,
|
|
/* inference_mode */ enabled,
|
|
/* fw_grad_mode */ !enabled,
|
|
/* multithreading_enabled*/ !enabled));
|
|
DispatchKeySet included = enabled
|
|
? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
|
|
: prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);
|
|
DispatchKeySet excluded = enabled
|
|
? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset)
|
|
: (prev_keyset.excluded_ - c10::autograd_dispatch_keyset);
|
|
c10::impl::PODLocalDispatchKeySet cur_keyset{};
|
|
cur_keyset.set_included(included);
|
|
cur_keyset.set_excluded(excluded);
|
|
c10::impl::_force_tls_local_dispatch_key_set(cur_keyset);
|
|
}
|
|
|
|
~InferenceMode() {
|
|
AutogradState::set_tls_state(prev_mode);
|
|
c10::impl::_force_tls_local_dispatch_key_set(prev_keyset);
|
|
}
|
|
static bool is_enabled();
|
|
|
|
private:
|
|
AutogradState prev_mode;
|
|
c10::impl::LocalDispatchKeySet prev_keyset;
|
|
};
|
|
} // namespace c10
|