diff --git a/torch/csrc/jit/passes/autocast.cpp b/torch/csrc/jit/passes/autocast.cpp index 1d5cb636e45..4699cceec5b 100644 --- a/torch/csrc/jit/passes/autocast.cpp +++ b/torch/csrc/jit/passes/autocast.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -17,7 +18,7 @@ namespace torch::jit { namespace { -bool autocast_enabled = true; +std::atomic autocast_enabled = true; struct AutocastContext { bool gpu_enabled = false; @@ -509,9 +510,7 @@ void handleBlock(Block* block, AutocastContext initial_state) { } // namespace bool setAutocastMode(bool value) { - auto old_value = autocast_enabled; - autocast_enabled = value; - return old_value; + return autocast_enabled.exchange(value); } bool autocastEnabled() {