From e64a814ae75c67f35b7dfd78af373cd91c8b7b29 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Wed, 22 Oct 2025 21:38:52 +0000 Subject: [PATCH] [CUDA] Add experimental green context support for SM carveout (#159104) Low-level PyTorch APIs should be usable/stable enough at this point but we might move the underlying driver API usage a bit from here... Built on top of @drisspg 's branch Pull Request resolved: https://github.com/pytorch/pytorch/pull/159104 Approved by: https://github.com/ngimel, https://github.com/malfet, https://github.com/kwen2501 Co-authored-by: drisspg Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- aten/src/ATen/cuda/CUDAGreenContext.cpp | 192 ++++++++++++++++++++++++ aten/src/ATen/cuda/CUDAGreenContext.h | 53 +++++++ build_variables.bzl | 1 + c10/cuda/driver_api.h | 11 ++ caffe2/CMakeLists.txt | 6 + docs/source/cuda.md | 26 ++++ test/test_matmul_cuda.py | 26 ++++ test/test_public_bindings.py | 1 + torch/_C/__init__.pyi.in | 14 ++ torch/csrc/Module.cpp | 2 + torch/csrc/cuda/GreenContext.cpp | 13 ++ torch/cuda/__init__.py | 2 + torch/cuda/green_contexts.py | 42 ++++++ torch/testing/_internal/common_cuda.py | 12 ++ 14 files changed, 401 insertions(+) create mode 100644 aten/src/ATen/cuda/CUDAGreenContext.cpp create mode 100644 aten/src/ATen/cuda/CUDAGreenContext.h create mode 100644 torch/csrc/cuda/GreenContext.cpp create mode 100644 torch/cuda/green_contexts.py diff --git a/aten/src/ATen/cuda/CUDAGreenContext.cpp b/aten/src/ATen/cuda/CUDAGreenContext.cpp new file mode 100644 index 00000000000..6108f6e96a8 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAGreenContext.cpp @@ -0,0 +1,192 @@ +#include + +namespace at::cuda { + GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { +#if CUDA_HAS_GREEN_CONTEXT + int driver_version; + C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); + TORCH_CHECK( + driver_version >= 12080, "cuda driver too old to use green context!"); + CUcontext pctx = nullptr; + 
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx)); + if (C10_UNLIKELY(!pctx)) { + TORCH_WARN( + "Attempted to create a green context but" + " there was no primary context! Creating a primary context..."); + + cudaFree(0); + } + + CUdevice device; + device_id_ = device_id; + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); + + // Get device resources + CUdevResource device_resource; + C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( + device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); + + // Split resources + std::vector<CUdevResource> result(1); + auto result_data = result.data(); + unsigned int nb_groups = 1; + CUdevResource remaining; + + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( + result_data, + &nb_groups, + &device_resource, + &remaining, + 0, // default flags + num_sms)); + + TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); + + // Generate resource descriptor + CUdevResourceDesc desc; + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( + &desc, result_data, 1)); + + // Create green context + // CU_GREEN_CTX_DEFAULT_STREAM is required per docs: + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html + C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_( + &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM)); + + // Convert to regular context + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_)); + TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!"); +#else + TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); +#endif + } + + std::unique_ptr<GreenContext> GreenContext::create( + uint32_t num_sms, + std::optional<uint32_t> device_id) { +#if CUDA_HAS_GREEN_CONTEXT + if (!device_id.has_value()) { + device_id = at::cuda::current_device(); + } + return std::make_unique<GreenContext>(device_id.value(), 
num_sms); +#else + TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); +#endif + } + + // Implement move operations + GreenContext::GreenContext(GreenContext&& other) noexcept{ +#if CUDA_HAS_GREEN_CONTEXT + device_id_ = std::exchange(other.device_id_, -1); + green_ctx_ = std::exchange(other.green_ctx_, nullptr); + context_ = std::exchange(other.context_, nullptr); + parent_stream_ = std::exchange(other.parent_stream_, nullptr); +#else + TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); +#endif + } + + GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{ +#if CUDA_HAS_GREEN_CONTEXT + if (this != &other) { + // Clean up current resources + if (green_ctx_) { + CUcontext current = nullptr; + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current)); + if (current == context_) { + TORCH_CHECK( + false, + "attempting to overwrite current green ctx " + "when it is active!"); + } + C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); + } + + // Take ownership of other's resources + device_id_ = std::exchange(other.device_id_, -1); + green_ctx_ = std::exchange(other.green_ctx_, nullptr); + context_ = std::exchange(other.context_, nullptr); + parent_stream_ = std::exchange(other.parent_stream_, nullptr); + } + return *this; +#else + TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); +#endif + } + + GreenContext::~GreenContext() noexcept{ +#if CUDA_HAS_GREEN_CONTEXT + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); +#else + TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); +#endif + } + + // Get the underlying CUDA context + CUcontext GreenContext::getContext() const { +#if CUDA_HAS_GREEN_CONTEXT + return context_; +#else + TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); +#endif + } + + // Get the underlying green context +#if CUDA_HAS_GREEN_CONTEXT + CUgreenCtx 
GreenContext::getGreenContext() const { + return green_ctx_; + } +#endif + + // Make this context current + void GreenContext::setContext() { +#if CUDA_HAS_GREEN_CONTEXT + auto current_stream = c10::cuda::getCurrentCUDAStream(); + parent_stream_ = current_stream.stream(); + + at::cuda::CUDAEvent ev; + ev.record(current_stream); + + CUcontext current = nullptr; + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current)); + if (!current) { + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_)); + } else { + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_)); + } + // currently hardcodes the new green context to use the default stream + // TODO(eqy): consider creating a new stream if e.g., it allows interop + // with CUDA Graph captures etc. + auto default_stream = c10::cuda::getDefaultCUDAStream(); + ev.block(default_stream); + c10::cuda::setCurrentCUDAStream(default_stream); +#else + TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); +#endif + } + + void GreenContext::popContext() { +#if CUDA_HAS_GREEN_CONTEXT + // see above note about stream being hardcoded to the default stream + at::cuda::CUDAEvent ev; + ev.record(c10::cuda::getCurrentCUDAStream()); + CUcontext popped; + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped)); + TORCH_INTERNAL_ASSERT( + popped == context_, "expected popped context to be the current ctx"); + ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_)); +#else + TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); +#endif + } +} // namespace at::cuda diff --git a/aten/src/ATen/cuda/CUDAGreenContext.h b/aten/src/ATen/cuda/CUDAGreenContext.h new file mode 100644 index 00000000000..4f198e2e1c0 --- /dev/null +++ b/aten/src/ATen/cuda/CUDAGreenContext.h @@ -0,0 +1,53 @@ +#pragma once +#include + +#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) 
+#include +#include +#include +#include +#include +#define CUDA_HAS_GREEN_CONTEXT 1 +#else +#define CUDA_HAS_GREEN_CONTEXT 0 +#endif + +namespace at::cuda { + +class TORCH_CUDA_CPP_API GreenContext { + public: + GreenContext(uint32_t device_id, uint32_t num_sms); + + static std::unique_ptr create(uint32_t num_sms, std::optional device_id); + + // Delete copy constructor and assignment + GreenContext(const GreenContext&) = delete; + GreenContext& operator=(const GreenContext&) = delete; + + // Implement move operations + GreenContext(GreenContext&& other) noexcept; + GreenContext& operator=(GreenContext&& other) noexcept; + ~GreenContext() noexcept; + + // Get the underlying CUDA context + CUcontext getContext() const; + + // Get the underlying green context +#if CUDA_HAS_GREEN_CONTEXT + CUgreenCtx getGreenContext() const; +#endif + + // Make this context current + void setContext(); + + void popContext(); + + private: +#if CUDA_HAS_GREEN_CONTEXT + int32_t device_id_ = -1; + CUgreenCtx green_ctx_ = nullptr; + CUcontext context_ = nullptr; + cudaStream_t parent_stream_ = nullptr; +#endif +}; +} // namespace at::cuda diff --git a/build_variables.bzl b/build_variables.bzl index ce1c5f1c97b..338e49777bf 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -855,6 +855,7 @@ libtorch_python_cuda_core_sources = [ "torch/csrc/cuda/Stream.cpp", "torch/csrc/cuda/Graph.cpp", "torch/csrc/cuda/MemPool.cpp", + "torch/csrc/cuda/GreenContext.cpp", "torch/csrc/cuda/shared/cudart.cpp", "torch/csrc/cuda/shared/nvtx.cpp", "torch/csrc/cuda/utils.cpp", diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 8910e581a1a..380e7939ff7 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -51,6 +51,17 @@ #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \ + _(cuCtxFromGreenCtx, 12080) \ + _(cuCtxGetCurrent, 12080) \ + _(cuCtxPopCurrent, 12080) \ + _(cuCtxPushCurrent, 12080) \ + _(cuCtxSetCurrent, 12080) \ + 
_(cuGreenCtxCreate, 12080) \ + _(cuGreenCtxDestroy, 12080) \ + _(cuDevSmResourceSplitByCount, 12080) \ + _(cuDeviceGet, 12080) \ + _(cuDeviceGetDevResource, 12080) \ + _(cuDevResourceGenerateDesc, 12080) \ _(cuMulticastAddDevice, 12030) \ _(cuMulticastBindMem, 12030) \ _(cuMulticastCreate, 12030) \ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 7ab54dfa86a..8eef838bb2a 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -607,6 +607,12 @@ if(USE_CUDA) set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a") endif() endif() + if(NOT WIN32) + set_source_files_properties( + ${TORCH_ROOT}/aten/src/ATen/cuda/CUDAGreenContext.cpp + PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" + ) + endif() set_source_files_properties( ${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp PROPERTIES COMPILE_DEFINITIONS "NVRTC_SHORTHASH=${CUDA_NVRTC_SHORTHASH}" diff --git a/docs/source/cuda.md b/docs/source/cuda.md index bd752ad684b..94894942b74 100644 --- a/docs/source/cuda.md +++ b/docs/source/cuda.md @@ -258,6 +258,28 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t ``` +## Green Contexts (experimental) + +`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs +to enable more general carveout of SM resources for CUDA kernels. + +These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8. + +See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these. + +```{eval-rst} +.. currentmodule:: torch.cuda.green_contexts +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + GreenContext +``` + + % This module needs to be documented. Adding here in the meantime % for tracking purposes @@ -270,6 +292,10 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t .. py:module:: torch.cuda.gds ``` +```{eval-rst} +.. 
py:module:: torch.cuda.green_contexts +``` + ```{eval-rst} .. py:module:: torch.cuda.jiterator ``` diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index bf46ee0709f..df3052b2475 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -1,6 +1,7 @@ # Owner(s): ["module: linear algebra"] import contextlib +import time import unittest from itertools import product from functools import partial @@ -15,6 +16,7 @@ from torch.quantization._quantized_conversions import ( from torch.testing import make_tensor from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_BF16, + PLATFORM_SUPPORTS_GREEN_CONTEXT, SM53OrLater, SM80OrLater, SM90OrLater, @@ -40,6 +42,7 @@ from torch.testing._internal.common_utils import ( parametrize, run_tests, runOnRocmArch, + serialTest, skipIfRocm, TEST_CUDA, TEST_WITH_ROCM, @@ -855,6 +858,29 @@ class TestMatmulCuda(InductorTestCase): op(a, mismatch_batch_dim_b, out_dtype=torch.float32) + @unittest.skipIf(not PLATFORM_SUPPORTS_GREEN_CONTEXT, "Green contexts are not supported") + @serialTest() + def test_greencontext_carveout(self): + a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16) + ctx = torch.cuda.green_contexts.GreenContext.create(1, 0) + ctx.set_context() + torch.matmul(a, a) + torch.cuda.synchronize() + t0 = time.perf_counter() + partial_res = torch.matmul(a, a) + torch.cuda.synchronize() + t1 = time.perf_counter() + ctx.pop_context() + torch.matmul(a, a) + torch.cuda.synchronize() + t2 = time.perf_counter() + full_res = torch.matmul(a, a) + torch.cuda.synchronize() + t3 = time.perf_counter() + self.assertEqual(partial_res, full_res) + self.assertGreater(t1 - t0, t3 - t2) + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") @unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x") diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 
fa21705c76e..dff4e9c014c 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -123,6 +123,7 @@ class TestPublicBindings(TestCase): "FutureType", "Generator", "GeneratorType", + "GreenContext", "get_autocast_cpu_dtype", "get_autocast_dtype", "get_autocast_ipu_dtype", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 7194f4bccb3..8eb1bed936e 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2792,3 +2792,17 @@ class _StaticCudaLauncher: args: tuple[Any, ...], stream: _int, ) -> None: ... + +# Defined in torch/csrc/cuda/GreenContext.cpp +class GreenContext: + @staticmethod + def create( + num_sms: _int, + device_id: _int, + ) -> GreenContext: ... + def set_context( + self, + ) -> None: ... + def pop_context( + self, + ) -> None: ... diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 5e5dc2fd6b7..1aaa3cf8bac 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1957,6 +1957,7 @@ void THCPStream_init(PyObject* module); void THCPEvent_init(PyObject* module); void THCPGraph_init(PyObject* module); void THCPMemPool_init(PyObject* module); +void THCPGreenContext_init(PyObject* module); PyMethodDef* THCPModule_methods(); namespace torch::cuda { void initModule(PyObject* module); @@ -2184,6 +2185,7 @@ PyObject* initModule() { THCPEvent_init(module); THCPGraph_init(module); THCPMemPool_init(module); + THCPGreenContext_init(module); #endif #ifdef USE_XPU diff --git a/torch/csrc/cuda/GreenContext.cpp b/torch/csrc/cuda/GreenContext.cpp new file mode 100644 index 00000000000..9fae9843779 --- /dev/null +++ b/torch/csrc/cuda/GreenContext.cpp @@ -0,0 +1,13 @@ +#include +#include +#include + +// Cargo culted partially from csrc/cuda/Stream.cpp + +void THCPGreenContext_init(PyObject* module) { + auto m = py::handle(module).cast(); + py::class_(m, "_CUDAGreenContext") + .def_static("create", &::at::cuda::GreenContext::create) + .def("set_context", &::at::cuda::GreenContext::setContext) + 
.def("pop_context", &::at::cuda::GreenContext::popContext); +} diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 21e5a19056c..bb4a1e29dae 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -35,6 +35,7 @@ from .graphs import ( is_current_stream_capturing, make_graphed_callables, ) +from .green_contexts import GreenContext from .streams import Event, ExternalStream, Stream @@ -1844,6 +1845,7 @@ __all__ = [ "ExternalStream", "Stream", "StreamContext", + "GreenContext", "amp", "caching_allocator_alloc", "caching_allocator_delete", diff --git a/torch/cuda/green_contexts.py b/torch/cuda/green_contexts.py new file mode 100644 index 00000000000..078cd06e19c --- /dev/null +++ b/torch/cuda/green_contexts.py @@ -0,0 +1,42 @@ +import torch + + +_GreenContext = object +SUPPORTED = False + +if hasattr(torch._C, "_CUDAGreenContext"): + _GreenContext = torch._C._CUDAGreenContext # type: ignore[misc] + SUPPORTED = True + + +# Python shim helps Sphinx process docstrings more reliably. +class GreenContext(_GreenContext): + r"""Wrapper around a CUDA green context. + + .. warning:: + This API is in beta and may change in future releases. + """ + + @staticmethod + def create(num_sms: int, device_id: int = 0) -> _GreenContext: + r"""Create a CUDA green context. + + Arguments: + num_sms (int): The number of SMs to use in the green context. + device_id (int, optional): The device index of green context. 
+ """ + if not SUPPORTED: + raise RuntimeError("PyTorch was not built with Green Context support!") + return _GreenContext.create(num_sms, device_id) # type: ignore[attr-defined] + + # Note that these functions are bypassed but we define them here + # for Sphinx documentation purposes + def set_context(self) -> None: + r"""Make the green context the current context.""" + return super().set_context() # type: ignore[misc] + + def pop_context(self) -> None: + r"""Assuming the green context is the current context, pop it from the + context stack and restore the previous context. + """ + return super().pop_context() # type: ignore[misc] diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 8202a32ae8a..74dfe0c56c2 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -81,6 +81,16 @@ def evaluate_platform_supports_efficient_attention(): def evaluate_platform_supports_cudnn_attention(): return (not TEST_WITH_ROCM) and SM80OrLater and (TEST_CUDNN_VERSION >= 90000) +def evaluate_platform_supports_green_context(): + if IS_WINDOWS: + return False + if not _get_torch_cuda_version() >= (12, 8): + return False + driver_version = torch.utils.collect_env.get_nvidia_driver_version(torch.utils.collect_env.run) + if driver_version is None: + return False + return int(driver_version.split('.')[0]) >= 550 + PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention()) PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention()) PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention()) @@ -93,6 +103,8 @@ PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater) +PLATFORM_SUPPORTS_GREEN_CONTEXT: bool = LazyVal(lambda: evaluate_platform_supports_green_context()) + def 
evaluate_platform_supports_fp8(): if torch.cuda.is_available(): if torch.version.hip: