[CUDA] Add experimental green context support for SM carveout (#159104)

Low-level PyTorch APIs should be usable/stable enough at this point but we might move the underlying driver API usage a bit from here...

Built on top of @drisspg 's branch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159104
Approved by: https://github.com/ngimel, https://github.com/malfet, https://github.com/kwen2501

Co-authored-by: drisspg <drisspguessous@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
Eddie Yan 2025-10-22 21:38:52 +00:00 committed by PyTorch MergeBot
parent 0b58d87aec
commit e64a814ae7
14 changed files with 401 additions and 0 deletions

View File

@ -0,0 +1,192 @@
#include <ATen/cuda/CUDAGreenContext.h>
namespace at::cuda {
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if CUDA_HAS_GREEN_CONTEXT
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");
cudaFree(0);
}
CUdevice device;
device_id_ = device_id;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
// Get device resources
CUdevResource device_resource;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
// Split resources
std::vector<CUdevResource> result(1);
auto result_data = result.data();
unsigned int nb_groups = 1;
CUdevResource remaining;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
result_data,
&nb_groups,
&device_resource,
&remaining,
0, // default flags
num_sms));
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
// Generate resource descriptor
CUdevResourceDesc desc;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
&desc, result_data, 1));
// Create green context
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
// Convert to regular context
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
std::unique_ptr<GreenContext> GreenContext::create(
uint32_t num_sms,
std::optional<uint32_t> device_id) {
#if CUDA_HAS_GREEN_CONTEXT
if (!device_id.has_value()) {
device_id = at::cuda::current_device();
}
return std::make_unique<GreenContext>(device_id.value(), num_sms);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Implement move operations
GreenContext::GreenContext(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
if (this != &other) {
// Clean up current resources
if (green_ctx_) {
CUcontext current = nullptr;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
if (current == context_) {
TORCH_CHECK(
false,
"attempting to overwrite current green ctx "
"when it is active!");
}
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
}
// Take ownership of other's resources
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
}
return *this;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
GreenContext::~GreenContext() noexcept{
#if CUDA_HAS_GREEN_CONTEXT
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying CUDA context
CUcontext GreenContext::getContext() const {
#if CUDA_HAS_GREEN_CONTEXT
return context_;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx GreenContext::getGreenContext() const {
return green_ctx_;
}
#endif
// Make this context current
void GreenContext::setContext() {
#if CUDA_HAS_GREEN_CONTEXT
auto current_stream = c10::cuda::getCurrentCUDAStream();
parent_stream_ = current_stream.stream();
at::cuda::CUDAEvent ev;
ev.record(current_stream);
CUcontext current = nullptr;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
if (!current) {
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
} else {
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
}
// currently hardcodes the new green context to use the default stream
// TODO(eqy): consider creating a new stream if e.g., it allows interop
// with CUDA Graph captures etc.
auto default_stream = c10::cuda::getDefaultCUDAStream();
ev.block(default_stream);
c10::cuda::setCurrentCUDAStream(default_stream);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
void GreenContext::popContext() {
#if CUDA_HAS_GREEN_CONTEXT
// see above note about stream being hardcoded to the default stream
at::cuda::CUDAEvent ev;
ev.record(c10::cuda::getCurrentCUDAStream());
CUcontext popped;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
TORCH_INTERNAL_ASSERT(
popped == context_, "expected popped context to be the current ctx");
ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
} // namespace at::cuda

View File

@ -0,0 +1,53 @@
#pragma once
#include <ATen/cuda/CUDAEvent.h>
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <memory>
#include <stdexcept>
#include <vector>
#define CUDA_HAS_GREEN_CONTEXT 1
#else
#define CUDA_HAS_GREEN_CONTEXT 0
#endif
namespace at::cuda {
class TORCH_CUDA_CPP_API GreenContext {
public:
GreenContext(uint32_t device_id, uint32_t num_sms);
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
// Delete copy constructor and assignment
GreenContext(const GreenContext&) = delete;
GreenContext& operator=(const GreenContext&) = delete;
// Implement move operations
GreenContext(GreenContext&& other) noexcept;
GreenContext& operator=(GreenContext&& other) noexcept;
~GreenContext() noexcept;
// Get the underlying CUDA context
CUcontext getContext() const;
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx getGreenContext() const;
#endif
// Make this context current
void setContext();
void popContext();
private:
#if CUDA_HAS_GREEN_CONTEXT
int32_t device_id_ = -1;
CUgreenCtx green_ctx_ = nullptr;
CUcontext context_ = nullptr;
cudaStream_t parent_stream_ = nullptr;
#endif
};
} // namespace at::cuda

View File

@ -855,6 +855,7 @@ libtorch_python_cuda_core_sources = [
"torch/csrc/cuda/Stream.cpp",
"torch/csrc/cuda/Graph.cpp",
"torch/csrc/cuda/MemPool.cpp",
"torch/csrc/cuda/GreenContext.cpp",
"torch/csrc/cuda/shared/cudart.cpp",
"torch/csrc/cuda/shared/nvtx.cpp",
"torch/csrc/cuda/utils.cpp",

View File

@ -51,6 +51,17 @@
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
_(cuCtxFromGreenCtx, 12080) \
_(cuCtxGetCurrent, 12080) \
_(cuCtxPopCurrent, 12080) \
_(cuCtxPushCurrent, 12080) \
_(cuCtxSetCurrent, 12080) \
_(cuGreenCtxCreate, 12080) \
_(cuGreenCtxDestroy, 12080) \
_(cuDevSmResourceSplitByCount, 12080) \
_(cuDeviceGet, 12080) \
_(cuDeviceGetDevResource, 12080) \
_(cuDevResourceGenerateDesc, 12080) \
_(cuMulticastAddDevice, 12030) \
_(cuMulticastBindMem, 12030) \
_(cuMulticastCreate, 12030) \

View File

@ -607,6 +607,12 @@ if(USE_CUDA)
set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
endif()
endif()
if(NOT WIN32)
set_source_files_properties(
${TORCH_ROOT}/aten/src/ATen/cuda/CUDAGreenContext.cpp
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
)
endif()
set_source_files_properties(
${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
PROPERTIES COMPILE_DEFINITIONS "NVRTC_SHORTHASH=${CUDA_NVRTC_SHORTHASH}"

View File

@ -258,6 +258,28 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
```
## Green Contexts (experimental)
`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs
to enable more general carveout of SM resources for CUDA kernels.
These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8.
See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these.
```{eval-rst}
.. currentmodule:: torch.cuda.green_contexts
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
GreenContext
```
% This module needs to be documented. Adding here in the meantime
% for tracking purposes
@ -270,6 +292,10 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
.. py:module:: torch.cuda.gds
```
```{eval-rst}
.. py:module:: torch.cuda.green_contexts
```
```{eval-rst}
.. py:module:: torch.cuda.jiterator
```

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: linear algebra"]
import contextlib
import time
import unittest
from itertools import product
from functools import partial
@ -15,6 +16,7 @@ from torch.quantization._quantized_conversions import (
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_BF16,
PLATFORM_SUPPORTS_GREEN_CONTEXT,
SM53OrLater,
SM80OrLater,
SM90OrLater,
@ -40,6 +42,7 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
runOnRocmArch,
serialTest,
skipIfRocm,
TEST_CUDA,
TEST_WITH_ROCM,
@ -855,6 +858,29 @@ class TestMatmulCuda(InductorTestCase):
op(a, mismatch_batch_dim_b, out_dtype=torch.float32)
@unittest.skipIf(not PLATFORM_SUPPORTS_GREEN_CONTEXT, "Green contexts are not supported")
@serialTest()
def test_greencontext_carveout(self):
a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
ctx = torch.cuda.green_contexts.GreenContext.create(1, 0)
ctx.set_context()
torch.matmul(a, a)
torch.cuda.synchronize()
t0 = time.perf_counter()
partial_res = torch.matmul(a, a)
torch.cuda.synchronize()
t1 = time.perf_counter()
ctx.pop_context()
torch.matmul(a, a)
torch.cuda.synchronize()
t2 = time.perf_counter()
full_res = torch.matmul(a, a)
torch.cuda.synchronize()
t3 = time.perf_counter()
self.assertEqual(partial_res, full_res)
self.assertGreater(t1 - t0, t3 - t2)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")

View File

@ -123,6 +123,7 @@ class TestPublicBindings(TestCase):
"FutureType",
"Generator",
"GeneratorType",
"GreenContext",
"get_autocast_cpu_dtype",
"get_autocast_dtype",
"get_autocast_ipu_dtype",

View File

@ -2792,3 +2792,17 @@ class _StaticCudaLauncher:
args: tuple[Any, ...],
stream: _int,
) -> None: ...
# Defined in torch/csrc/cuda/GreenContext.cpp
class GreenContext:
@staticmethod
def create(
num_sms: _int,
device_id: _int,
) -> GreenContext: ...
def set_context(
self,
) -> None: ...
def pop_context(
self,
) -> None: ...

View File

@ -1957,6 +1957,7 @@ void THCPStream_init(PyObject* module);
void THCPEvent_init(PyObject* module);
void THCPGraph_init(PyObject* module);
void THCPMemPool_init(PyObject* module);
void THCPGreenContext_init(PyObject* module);
PyMethodDef* THCPModule_methods();
namespace torch::cuda {
void initModule(PyObject* module);
@ -2184,6 +2185,7 @@ PyObject* initModule() {
THCPEvent_init(module);
THCPGraph_init(module);
THCPMemPool_init(module);
THCPGreenContext_init(module);
#endif
#ifdef USE_XPU

View File

@ -0,0 +1,13 @@
#include <ATen/cuda/CUDAGreenContext.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
// Cargo culted partially from csrc/cuda/Stream.cpp
void THCPGreenContext_init(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
py::class_<at::cuda::GreenContext>(m, "_CUDAGreenContext")
.def_static("create", &::at::cuda::GreenContext::create)
.def("set_context", &::at::cuda::GreenContext::setContext)
.def("pop_context", &::at::cuda::GreenContext::popContext);
}

View File

@ -35,6 +35,7 @@ from .graphs import (
is_current_stream_capturing,
make_graphed_callables,
)
from .green_contexts import GreenContext
from .streams import Event, ExternalStream, Stream
@ -1844,6 +1845,7 @@ __all__ = [
"ExternalStream",
"Stream",
"StreamContext",
"GreenContext",
"amp",
"caching_allocator_alloc",
"caching_allocator_delete",

View File

@ -0,0 +1,42 @@
import torch
_GreenContext = object
SUPPORTED = False
if hasattr(torch._C, "_CUDAGreenContext"):
_GreenContext = torch._C._CUDAGreenContext # type: ignore[misc]
SUPPORTED = True
# Python shim helps Sphinx process docstrings more reliably.
class GreenContext(_GreenContext):
r"""Wrapper around a CUDA green context.
.. warning::
This API is in beta and may change in future releases.
"""
@staticmethod
def create(num_sms: int, device_id: int = 0) -> _GreenContext:
r"""Create a CUDA green context.
Arguments:
num_sms (int): The number of SMs to use in the green context.
device_id (int, optional): The device index of green context.
"""
if not SUPPORTED:
raise RuntimeError("PyTorch was not built with Green Context support!")
return _GreenContext.create(num_sms, device_id) # type: ignore[attr-defined]
# Note that these functions are bypassed by we define them here
# for Sphinx documentation purposes
def set_context(self) -> None:
r"""Make the green context the current context."""
return super().set_context() # type: ignore[misc]
def pop_context(self) -> None:
r"""Assuming the green context is the current context, pop it from the
context stack and restore the previous context.
"""
return super().pop_context() # type: ignore[misc]

View File

@ -81,6 +81,16 @@ def evaluate_platform_supports_efficient_attention():
def evaluate_platform_supports_cudnn_attention():
return (not TEST_WITH_ROCM) and SM80OrLater and (TEST_CUDNN_VERSION >= 90000)
def evaluate_platform_supports_green_context():
if IS_WINDOWS:
return False
if not _get_torch_cuda_version() >= (12, 8):
return False
driver_version = torch.utils.collect_env.get_nvidia_driver_version(torch.utils.collect_env.run)
if driver_version is None:
return False
return int(driver_version.split('.')[0]) >= 550
PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention())
@ -93,6 +103,8 @@ PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM
PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
PLATFORM_SUPPORTS_GREEN_CONTEXT: bool = LazyVal(lambda: evaluate_platform_supports_green_context())
def evaluate_platform_supports_fp8():
if torch.cuda.is_available():
if torch.version.hip: