#pragma once #include #include namespace c10 { struct C10_API GradMode { static bool is_enabled(); static void set_enabled(bool enabled); }; // A RAII, thread local (!) guard that enables or disables grad mode upon // construction, and sets it back to the original value upon destruction. struct C10_API AutoGradMode { AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) { GradMode::set_enabled(enabled); } AutoGradMode(const AutoGradMode&) = delete; AutoGradMode(AutoGradMode&&) = delete; AutoGradMode& operator=(const AutoGradMode&) = delete; AutoGradMode& operator=(AutoGradMode&&) = delete; ~AutoGradMode() { GradMode::set_enabled(prev_mode); } bool prev_mode; }; // A RAII, thread local (!) guard that stops future operations from building // gradients. struct C10_API NoGradGuard : public AutoGradMode { NoGradGuard() : AutoGradMode(/*enabled=*/false) {} }; // A RAII, thread local (!) guard that enables or disables forward grad mode // upon construction, and sets it back to the original value upon destruction. struct C10_API AutoFwGradMode { AutoFwGradMode(bool enabled) : prev_mode(AutogradState::get_tls_state().get_fw_grad_mode()) { AutogradState::get_tls_state().set_fw_grad_mode(enabled); } AutoFwGradMode(const AutoFwGradMode&) = delete; AutoFwGradMode(AutoFwGradMode&&) = delete; AutoFwGradMode& operator=(const AutoFwGradMode&) = delete; AutoFwGradMode& operator=(AutoFwGradMode&&) = delete; ~AutoFwGradMode() { AutogradState::get_tls_state().set_fw_grad_mode(prev_mode); } bool prev_mode; }; } // namespace c10