pytorch/c10/core/GradMode.cpp
albanD 83d9bad44a Add a common autograd TLS state (#63114)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63114

This PR collapses the GradMode and InferenceMode thread local booleans into a single thread local uint8.
This helps reduce the number of thread local variable accesses done when we propagate ThreadLocalStates.
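
The idea of packing both mode flags into one thread-local byte can be sketched as follows. This is a minimal illustration, not the actual `c10::AutogradState`: the names `PackedModeState`, `tls_state`, `propagate`, and the bit constants are hypothetical, and the default (grad mode on, inference mode off) is an assumption made for the sketch.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical bit layout; the real AutogradState may store its flags differently.
constexpr std::uint8_t kGradModeBit = 1u << 0;
constexpr std::uint8_t kInferenceModeBit = 1u << 1;

// Both flags live in a single byte, so a copy of the state is a one-byte copy.
class PackedModeState {
 public:
  bool grad_mode() const { return (bits_ & kGradModeBit) != 0; }
  bool inference_mode() const { return (bits_ & kInferenceModeBit) != 0; }

  void set_grad_mode(bool on) { set_bit(kGradModeBit, on); }
  void set_inference_mode(bool on) { set_bit(kInferenceModeBit, on); }

 private:
  void set_bit(std::uint8_t mask, bool on) {
    bits_ = on ? static_cast<std::uint8_t>(bits_ | mask)
               : static_cast<std::uint8_t>(bits_ & ~mask);
  }

  // Assumed default for this sketch: grad mode on, inference mode off.
  std::uint8_t bits_ = kGradModeBit;
};

static_assert(sizeof(PackedModeState) == 1, "both flags fit in one byte");

// One thread-local object instead of one thread-local variable per flag.
inline PackedModeState& tls_state() {
  thread_local PackedModeState state;
  return state;
}

// Installing a captured snapshot on the current thread touches a single
// thread-local variable, however many flags the byte holds.
inline void propagate(const PackedModeState& snapshot) {
  tls_state() = snapshot;
}
```

With this layout, adding another flag costs one more bit in the same byte rather than one more thread-local variable to read and write during propagation.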

This becomes even more beneficial because a forward mode AD TLS (similar to GradMode) is added higher in this stack, and the new structure should reduce the perf impact of adding it.

Here is the full benchmark result between master and the top of this stack: https://gist.github.com/albanD/e421101e9ed344e94999bef3a54bf0f3
tl;dr: it gives a benefit in most cases and is never detrimental.

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30388099

Pulled By: albanD

fbshipit-source-id: 8e03f940150ff063c2edd792733663413ae2f486
2021-08-24 06:54:02 -07:00

16 lines
318 B
C++

#include <c10/core/AutogradState.h>
#include <c10/core/GradMode.h>

#include <stdexcept>

namespace c10 {

// Reads the grad-mode flag from the shared thread-local AutogradState.
bool GradMode::is_enabled() {
  return AutogradState::get_tls_state().get_grad_mode();
}

// Sets the grad-mode flag on the current thread's AutogradState.
void GradMode::set_enabled(bool enabled) {
  AutogradState::get_tls_state().set_grad_mode(enabled);
}

} // namespace c10
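
A getter/setter pair like the one above lends itself to scoped toggling: save the current value, set a new one, and restore the old value on scope exit. The sketch below shows that pattern under stated assumptions. `sketch::GradMode` is a stand-in backed by a plain `thread_local bool` so the example compiles without PyTorch headers, and `ScopedGradMode` is a hypothetical name, not one of the guards PyTorch itself provides.

```cpp
#include <cassert>

namespace sketch {

// Stand-in for c10::GradMode with the same two static functions.
// Assumption: backed by a plain thread_local bool, not AutogradState.
struct GradMode {
  static bool is_enabled() { return flag(); }
  static void set_enabled(bool enabled) { flag() = enabled; }

 private:
  static bool& flag() {
    thread_local bool enabled = true;  // assumed default for the sketch
    return enabled;
  }
};

// Hypothetical RAII guard: sets grad mode for the lifetime of the object
// and restores the previous value in the destructor.
class ScopedGradMode {
 public:
  explicit ScopedGradMode(bool enabled) : prev_(GradMode::is_enabled()) {
    GradMode::set_enabled(enabled);
  }
  ~ScopedGradMode() { GradMode::set_enabled(prev_); }

  ScopedGradMode(const ScopedGradMode&) = delete;
  ScopedGradMode& operator=(const ScopedGradMode&) = delete;

 private:
  bool prev_;
};

}  // namespace sketch
```

Because the guard restores the saved value rather than unconditionally re-enabling grad mode, nested guards unwind correctly.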