mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63114 This PR collapses the GradMode and InferenceMode thread-local booleans into a single thread-local uint8. This helps reduce the number of thread-local variable accesses done when we propagate ThreadLocalStates. Note that this will be even more beneficial once we add a forward-mode AD TLS (similar to GradMode) higher in this stack, as this new structure should reduce the perf impact of adding that new TLS. Here is the full benchmark result between master and the top of this stack: https://gist.github.com/albanD/e421101e9ed344e94999bef3a54bf0f3 tl;dr: it gives a benefit in most cases and is never detrimental. Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D30388099 Pulled By: albanD fbshipit-source-id: 8e03f940150ff063c2edd792733663413ae2f486
16 lines
318 B
C++
16 lines
318 B
C++
#include <c10/core/AutogradState.h>
|
|
#include <c10/core/GradMode.h>
|
|
|
|
#include <stdexcept>
|
|
|
|
namespace c10 {
|
|
|
|
bool GradMode::is_enabled() {
|
|
return AutogradState::get_tls_state().get_grad_mode();
|
|
}
|
|
|
|
// Updates the grad-mode flag to `enabled`.
//
// The new value is written into the AutogradState object handed back by
// AutogradState::get_tls_state(); nothing is returned to the caller.
void GradMode::set_enabled(bool enabled) {
  auto&& autograd_state = AutogradState::get_tls_state();
  autograd_state.set_grad_mode(enabled);
}
|
|
} // namespace c10
|