From cf9a834f39d6c3b769381cc1321bcb5c54132415 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 31 Oct 2025 20:10:59 +0000 Subject: [PATCH] [BE] Move GreenContext implementation details to cpp (#166462) - Remove all complex defines logic from the header - Make GreenContext constructor private, as it should only be created via the static method as singleton - Delete unused `getContext` and `getGreenContext` methods - Rename `CUDA_HAS_GREEN_CONTEXT` to `HAS_CUDA_GREEN_CONTEXT()`, which results in compilation error if one accidentally makes a typo - Suppress `-Wunused-private-field` if GreenContext is not available Pull Request resolved: https://github.com/pytorch/pytorch/pull/166462 Approved by: https://github.com/ngimel, https://github.com/eqy --- aten/src/ATen/cuda/CUDAGreenContext.cpp | 164 ++++++++++++------------ aten/src/ATen/cuda/CUDAGreenContext.h | 41 ++---- 2 files changed, 93 insertions(+), 112 deletions(-) diff --git a/aten/src/ATen/cuda/CUDAGreenContext.cpp b/aten/src/ATen/cuda/CUDAGreenContext.cpp index 6108f6e96a8..7e6e17e3df6 100644 --- a/aten/src/ATen/cuda/CUDAGreenContext.cpp +++ b/aten/src/ATen/cuda/CUDAGreenContext.cpp @@ -1,78 +1,90 @@ #include -namespace at::cuda { - GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { -#if CUDA_HAS_GREEN_CONTEXT - int driver_version; - C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); - TORCH_CHECK( - driver_version >= 12080, "cuda driver too old to use green context!"); - CUcontext pctx = nullptr; - C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx)); - if (C10_UNLIKELY(!pctx)) { - TORCH_WARN( - "Attempted to create a green context but" - " there was no primary context! 
Creating a primary context..."); - - cudaFree(0); - } - - CUdevice device; - device_id_ = device_id; - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); - - // Get device resources - CUdevResource device_resource; - C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( - device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); - - // Split resources - std::vector result(1); - auto result_data = result.data(); - unsigned int nb_groups = 1; - CUdevResource remaining; - - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( - result_data, - &nb_groups, - &device_resource, - &remaining, - 0, // default flags - num_sms)); - - TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); - - // Generate resource descriptor - CUdevResourceDesc desc; - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( - &desc, result_data, 1)); - - // Create green context - // CU_GREEN_CTX_DEFAULT_STREAM is required per docs: - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html - C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_( - &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM)); - - // Convert to regular context - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_)); - TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!"); +#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +#include +#include +#include +#define HAS_CUDA_GREEN_CONTEXT() 1 #else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); +#define HAS_CUDA_GREEN_CONTEXT() 0 +// Suppress unused private field warnings as this class is not supposed to be called +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-private-field") +#endif + +namespace at::cuda { + +GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { 
+#if HAS_CUDA_GREEN_CONTEXT() + int driver_version; + C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); + TORCH_CHECK( + driver_version >= 12080, "cuda driver too old to use green context!"); + CUcontext pctx = nullptr; + C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx)); + if (C10_UNLIKELY(!pctx)) { + TORCH_WARN( + "Attempted to create a green context but" + " there was no primary context! Creating a primary context..."); + + cudaFree(0); + } + + CUdevice device; + device_id_ = device_id; + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); + + // Get device resources + CUdevResource device_resource; + C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( + device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); + + // Split resources + std::vector result(1); + auto result_data = result.data(); + unsigned int nb_groups = 1; + CUdevResource remaining; + + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( + result_data, + &nb_groups, + &device_resource, + &remaining, + 0, // default flags + num_sms)); + + TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); + + // Generate resource descriptor + CUdevResourceDesc desc; + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( + &desc, result_data, 1)); + + // Create green context + // CU_GREEN_CTX_DEFAULT_STREAM is required per docs: + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html + C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_( + &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM)); + + // Convert to regular context + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_)); + TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!"); +#else + TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); #endif } std::unique_ptr 
GreenContext::create( uint32_t num_sms, std::optional device_id) { -#if CUDA_HAS_GREEN_CONTEXT +#if HAS_CUDA_GREEN_CONTEXT() if (!device_id.has_value()) { device_id = at::cuda::current_device(); } - return std::make_unique(device_id.value(), num_sms); + return std::unique_ptr(new GreenContext(device_id.value(), num_sms)); #else TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); #endif @@ -80,7 +92,7 @@ namespace at::cuda { // Implement move operations GreenContext::GreenContext(GreenContext&& other) noexcept{ -#if CUDA_HAS_GREEN_CONTEXT +#if HAS_CUDA_GREEN_CONTEXT() device_id_ = std::exchange(other.device_id_, -1); green_ctx_ = std::exchange(other.green_ctx_, nullptr); context_ = std::exchange(other.context_, nullptr); @@ -91,7 +103,7 @@ namespace at::cuda { } GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{ -#if CUDA_HAS_GREEN_CONTEXT +#if HAS_CUDA_GREEN_CONTEXT() if (this != &other) { // Clean up current resources if (green_ctx_) { @@ -120,7 +132,7 @@ namespace at::cuda { } GreenContext::~GreenContext() noexcept{ -#if CUDA_HAS_GREEN_CONTEXT +#if HAS_CUDA_GREEN_CONTEXT() C10_CUDA_DRIVER_CHECK( c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); #else @@ -128,25 +140,9 @@ namespace at::cuda { #endif } - // Get the underlying CUDA context - CUcontext GreenContext::getContext() const { -#if CUDA_HAS_GREEN_CONTEXT - return context_; -#else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); -#endif - } - - // Get the underlying green context -#if CUDA_HAS_GREEN_CONTEXT - CUgreenCtx GreenContext::getGreenContext() const { - return green_ctx_; - } -#endif - // Make this context current void GreenContext::setContext() { -#if CUDA_HAS_GREEN_CONTEXT +#if HAS_CUDA_GREEN_CONTEXT() auto current_stream = c10::cuda::getCurrentCUDAStream(); parent_stream_ = current_stream.stream(); @@ -175,7 +171,7 @@ namespace at::cuda { } void GreenContext::popContext() { -#if CUDA_HAS_GREEN_CONTEXT +#if 
HAS_CUDA_GREEN_CONTEXT() // see above note about stream being hardcoded to the default stream at::cuda::CUDAEvent ev; ev.record(c10::cuda::getCurrentCUDAStream()); diff --git a/aten/src/ATen/cuda/CUDAGreenContext.h b/aten/src/ATen/cuda/CUDAGreenContext.h index 4f198e2e1c0..f9fa2cd112e 100644 --- a/aten/src/ATen/cuda/CUDAGreenContext.h +++ b/aten/src/ATen/cuda/CUDAGreenContext.h @@ -1,53 +1,38 @@ #pragma once #include - -#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) -#include #include -#include -#include -#include -#define CUDA_HAS_GREEN_CONTEXT 1 -#else -#define CUDA_HAS_GREEN_CONTEXT 0 -#endif + +// Forward declare green context as opaque ptr +typedef struct CUgreenCtx_st* CUgreenCtx; namespace at::cuda { class TORCH_CUDA_CPP_API GreenContext { public: - GreenContext(uint32_t device_id, uint32_t num_sms); - - static std::unique_ptr create(uint32_t num_sms, std::optional device_id); + // Green context creation + static std::unique_ptr create( + uint32_t num_sms, + std::optional device_id); + ~GreenContext() noexcept; // Delete copy constructor and assignment GreenContext(const GreenContext&) = delete; GreenContext& operator=(const GreenContext&) = delete; - // Implement move operations - GreenContext(GreenContext&& other) noexcept; - GreenContext& operator=(GreenContext&& other) noexcept; - ~GreenContext() noexcept; - - // Get the underlying CUDA context - CUcontext getContext() const; - - // Get the underlying green context -#if CUDA_HAS_GREEN_CONTEXT - CUgreenCtx getGreenContext() const; -#endif - // Make this context current void setContext(); void popContext(); private: -#if CUDA_HAS_GREEN_CONTEXT + GreenContext(uint32_t device_id, uint32_t num_sms); + // Implement move operations + GreenContext(GreenContext&& other) noexcept; + GreenContext& operator=(GreenContext&& other) noexcept; + int32_t device_id_ = -1; CUgreenCtx green_ctx_ = nullptr; CUcontext context_ = nullptr; cudaStream_t parent_stream_ = nullptr; 
-#endif }; } // namespace at::cuda