[torch] Fix unsafe concurrent access to autocast_enabled (#148281)

Summary: Making autocast_enabled atomic, as it can be accessed from multiple threads

Differential Revision: D70456813

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148281
Approved by: https://github.com/davidberard98
This commit is contained in:
Ivan Grigorev 2025-03-25 14:46:09 +00:00 committed by PyTorch MergeBot
parent a2bba53f87
commit d90d83c484

View File

@ -7,6 +7,7 @@
#include <torch/csrc/jit/ir/ir.h> #include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h> #include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/quantization/helper.h> #include <torch/csrc/jit/passes/quantization/helper.h>
#include <atomic>
#include <optional> #include <optional>
#include <stack> #include <stack>
@ -17,7 +18,7 @@ namespace torch::jit {
namespace { namespace {
bool autocast_enabled = true; std::atomic<bool> autocast_enabled = true;
struct AutocastContext { struct AutocastContext {
bool gpu_enabled = false; bool gpu_enabled = false;
@ -509,9 +510,7 @@ void handleBlock(Block* block, AutocastContext initial_state) {
} // namespace } // namespace
bool setAutocastMode(bool value) { bool setAutocastMode(bool value) {
auto old_value = autocast_enabled; return autocast_enabled.exchange(value);
autocast_enabled = value;
return old_value;
} }
bool autocastEnabled() { bool autocastEnabled() {