Replace c10::call_once with static initialization (#166381)

This PR replaces c10::call_once calls with static initialization when possible. C++11 guarantees that the dynamic initialization of a block-scope static variable is thread-safe: the initializer runs exactly once, and concurrent callers block until it completes. The compiler-generated guard check for such a static is also cheaper than going through c10::call_once.
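
The transformation, as a minimal standalone sketch (not code from this PR; std::call_once stands in for c10::call_once, and setupOnce() is a hypothetical one-time setup function):

    #include <cstdio>
    #include <mutex>

    // Hypothetical one-time setup work, standing in for e.g. initGlobalStreamState().
    void setupOnce() { std::puts("initialized"); }

    // Before: a once-flag plus call_once.
    void initOld() {
      static std::once_flag flag;
      std::call_once(flag, setupOnce);
    }

    // After: a block-scope static initialized by an immediately invoked lambda.
    // The initializer runs exactly once, even under concurrent first calls;
    // [[maybe_unused]] suppresses the unused-variable warning on the guard.
    void initNew() {
      static bool initialized [[maybe_unused]] = [] {
        setupOnce();
        return true;
      }();
    }

    int main() {
      initOld();
      initNew();  // "initialized" prints exactly twice overall: once per idiom
    }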

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166381
Approved by: https://github.com/malfet
Yuanyuan Chen 2025-11-01 07:09:40 +00:00 committed by PyTorch MergeBot
parent 4316df857c
commit f0745ddb11
6 changed files with 30 additions and 35 deletions

File 1 of 6:

@@ -2,8 +2,6 @@
 #include <ATen/Tensor.h>
 #include <ATen/cuda/Exceptions.h>
 
-#include <mutex>
-
 namespace at {
 namespace cuda {
 namespace detail {
@@ -12,39 +10,36 @@ __device__ __constant__ float cublas_one_device;
 __device__ __constant__ float cublas_zero_device;
 
 float *get_cublas_device_one() {
-  static c10::once_flag init_flag;
-  c10::call_once(init_flag, []() {
+  static float *ptr = nullptr;
+  static auto init_flag = [&]() {
     const float one = 1.f;
     AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
-  });
-  float *ptr;
   AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
+    return true;
+  }();
   return ptr;
 }
 
 float *get_cublas_device_zero() {
-  static c10::once_flag init_flag;
-  c10::call_once(init_flag, []() {
+  static float *ptr = nullptr;
+  static auto init_flag = [&]() {
     const float zero = 0.f;
     AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
-  });
-  float *ptr;
   AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
+    return true;
+  }();
   return ptr;
 }
 
 float *get_user_alpha_ptr() {
   static float *alpha_ptr;
-  static c10::once_flag init_flag;
-  c10::call_once(init_flag, []() {
+  static bool init_flag [[maybe_unused]] = []() {
     AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float)));
-  });
+    return true;
+  }();
   return alpha_ptr;
 }
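
The cuBLAS helpers above fold the side-effecting setup and the pointer lookup into one immediately invoked lambda: the lambda fills in the function-local static ptr and returns a dummy bool for init_flag to hold. A standalone sketch of that shape with hypothetical names (acquire_device_pointer is not a real API):

    #include <cstdio>

    // Hypothetical stand-in for looking up a device-side symbol address.
    static float device_value = 1.f;
    static float *acquire_device_pointer() { return &device_value; }

    float *get_device_ptr() {
      static float *ptr = nullptr;  // zero-initialized before the lambda runs
      // The initializer runs exactly once; [&] mirrors the PR's code, though a
      // local static is reachable from the lambda even without being captured.
      static auto init_flag = [&]() {
        ptr = acquire_device_pointer();  // one-time setup
        return true;                     // dummy result; init_flag deduces to bool
      }();
      return ptr;
    }

    int main() {
      std::printf("*get_device_ptr() == %f\n", *get_device_ptr());
    }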

File 2 of 6:

@@ -1,7 +1,5 @@
 // Copyright © 2022 Apple Inc.
 
-#include <c10/util/CallOnce.h>
-
 #include <ATen/mps/IndexKernels.h>
 #include <ATen/mps/MPSAllocatorInterface.h>
 #include <ATen/mps/MPSDevice.h>
@@ -10,9 +8,6 @@
 namespace at::mps {
 
-static std::unique_ptr<MPSDevice> mps_device;
-static c10::once_flag mpsdev_init;
-
 static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device) {
   // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
   // host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+
@@ -21,8 +16,8 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
 }
 
 MPSDevice* MPSDevice::getInstance() {
-  c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr<MPSDevice>(new MPSDevice()); });
-  return mps_device.get();
+  static MPSDevice mps_device;
+  return &mps_device;
 }
 
 MPSDevice::~MPSDevice() {
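
The MPS change uses the plain Meyers-singleton shape rather than the lambda idiom: the old once_flag plus heap-allocated unique_ptr pair collapses into a single function-local static, keeping lazy, thread-safe construction while dropping the separate flag and the allocation. A minimal sketch with a hypothetical Widget standing in for MPSDevice:

    #include <cstdio>

    class Widget {
     public:
      static Widget *getInstance() {
        // Constructed on the first call; concurrent first callers block until
        // the constructor finishes (C++11 "magic statics").
        static Widget instance;
        return &instance;
      }

     private:
      Widget() { std::puts("Widget constructed"); }
    };

    int main() {
      Widget *a = Widget::getInstance();
      Widget *b = Widget::getInstance();  // constructor does not run again
      std::printf("same instance: %s\n", a == b ? "yes" : "no");
    }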

File 3 of 6:

@@ -15,7 +15,6 @@ namespace c10::cuda {
 namespace {
 
 // Global stream state and constants
-c10::once_flag init_flag;
 DeviceIndex num_gpus = -1;
 constexpr int kStreamsPerPoolBits = 5;
 constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;
@@ -226,7 +225,10 @@ void initDeviceStreamState(DeviceIndex device_index) {
 // Init front-end to ensure initialization only occurs once
 void initCUDAStreamsOnce() {
   // Inits default streams (once, globally)
-  c10::call_once(init_flag, initGlobalStreamState);
+  auto static init_flag [[maybe_unused]] = [] {
+    initGlobalStreamState();
+    return true;
+  }();
   if (current_streams) {
     return;
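
These init-once front-ends rely on the guarantee that a block-scope static's initializer runs exactly once even when many threads race into the function for the first time. A self-contained check of that property (not part of the PR), with a counter standing in for initGlobalStreamState():

    #include <atomic>
    #include <cstdio>
    #include <thread>
    #include <vector>

    std::atomic<int> runs{0};

    void initStreamsOnce() {
      // Same shape as initCUDAStreamsOnce(); the body must execute once.
      static auto init_flag [[maybe_unused]] = [] {
        runs.fetch_add(1);
        return true;
      }();
    }

    int main() {
      std::vector<std::thread> threads;
      threads.reserve(8);
      for (int i = 0; i < 8; ++i) {
        threads.emplace_back(initStreamsOnce);
      }
      for (auto &t : threads) {
        t.join();
      }
      std::printf("initializer ran %d time(s)\n", runs.load());  // always 1
    }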

File 4 of 6:

@@ -1,4 +1,3 @@
-#include <c10/util/CallOnce.h>
 #include <c10/util/Exception.h>
 #include <c10/xpu/XPUFunctions.h>
@@ -33,7 +32,6 @@ namespace {
  * one iGPU and enumerate all iGPUs on that platform.
  * 3. If neither dGPUs nor iGPUs are found, conclude that no GPUs are available.
  */
-c10::once_flag init_flag;
 thread_local DeviceIndex curDeviceIndex = 0;
 
 struct DevicePool {
@@ -149,7 +147,10 @@ inline void initGlobalDevicePoolState() {
 }
 
 inline void initDevicePoolCallOnce() {
-  c10::call_once(init_flag, initGlobalDevicePoolState);
+  auto static init_flag [[maybe_unused]] = [] {
+    initGlobalDevicePoolState();
+    return true;
+  }();
 }
 
 void initDeviceProperties(DeviceProp* device_prop, DeviceIndex device) {

File 5 of 6:

@@ -12,7 +12,6 @@ namespace c10::xpu {
 namespace {
 
 // Global stream state and constants
-c10::once_flag init_flag;
 DeviceIndex num_gpus = -1;
 constexpr int kStreamsPerPoolBits = 5;
 constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;
@@ -163,7 +162,10 @@ void initDeviceStreamState(DeviceIndex device) {
 }
 
 void initXPUStreamsOnce() {
-  c10::call_once(init_flag, initGlobalStreamState);
+  auto static init_flag [[maybe_unused]] = [] {
+    initGlobalStreamState();
+    return true;
+  }();
   if (current_streams) {
     return;

File 6 of 6:

@@ -349,8 +349,7 @@ static void cacheAllocatorDeregisterHook(
 }
 
 static void attachAllocatorHooks() {
-  static c10::once_flag flag;
-  c10::call_once(flag, [] {
+  static auto flag [[maybe_unused]] = [] {
     // Attaching hooks fails if CUDACachingAllocator is not initialized, so
     // Init for CUDA is called (and is a no-op if CUDA is already
     // initialized).
@@ -359,7 +358,8 @@ static void attachAllocatorHooks() {
       &cacheAllocatorRegisterHook);
   c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
       &cacheAllocatorDeregisterHook);
-  });
+    return true;
+  }();
 }
 
 static std::