diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 380e7939ff7..8910e581a1a 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -51,17 +51,6 @@ #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \ - _(cuCtxFromGreenCtx, 12080) \ - _(cuCtxGetCurrent, 12080) \ - _(cuCtxPopCurrent, 12080) \ - _(cuCtxPushCurrent, 12080) \ - _(cuCtxSetCurrent, 12080) \ - _(cuGreenCtxCreate, 12080) \ - _(cuGreenCtxDestroy, 12080) \ - _(cuDevSmResourceSplitByCount, 12080) \ - _(cuDeviceGet, 12080) \ - _(cuDeviceGetDevResource, 12080) \ - _(cuDevResourceGenerateDesc, 12080) \ _(cuMulticastAddDevice, 12030) \ _(cuMulticastBindMem, 12030) \ _(cuMulticastCreate, 12030) \ diff --git a/docs/source/cuda.md b/docs/source/cuda.md index 26870c3dcc3..09cf443cf06 100644 --- a/docs/source/cuda.md +++ b/docs/source/cuda.md @@ -262,28 +262,6 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t ``` -## Green Contexts (experimental) - -`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs -to enable more general carveout of SM resources for CUDA kernels. - -These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8. - -See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these. - -```{eval-rst} -.. currentmodule:: torch.cuda.green_contexts -``` - -```{eval-rst} -.. autosummary:: - :toctree: generated - :nosignatures: - - GreenContext -``` - - % This module needs to be documented. Adding here in the meantime % for tracking purposes @@ -296,10 +274,6 @@ See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example .. py:module:: torch.cuda.gds ``` -```{eval-rst} -.. py:module:: torch.cuda.green_contexts -``` - ```{eval-rst} .. py:module:: torch.cuda.jiterator ``` @@ -325,4 +299,4 @@ See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example :hidden: cuda.aliases.md -``` +``` \ No newline at end of file diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index fa7e761c249..b1f7f91a34d 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -1,7 +1,6 @@ # Owner(s): ["module: linear algebra"] import contextlib -import time import unittest from itertools import product from functools import partial @@ -847,28 +846,6 @@ class TestMatmulCuda(InductorTestCase): op(a, mismatch_batch_dim_b, out_dtype=torch.float32) - @unittest.skipIf(not _get_torch_cuda_version() >= (12, 8), "Green Context only tested on 12.8+") - def test_greencontext_carveout(self): - a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16) - ctx = torch.cuda.green_contexts.GreenContext.create(1, 0) - ctx.make_current() - torch.matmul(a, a) - torch.cuda.synchronize() - t0 = time.perf_counter() - partial_res = torch.matmul(a, a) - torch.cuda.synchronize() - t1 = time.perf_counter() - ctx.pop_current() - torch.matmul(a, a) - torch.cuda.synchronize() - t2 = time.perf_counter() - full_res = torch.matmul(a, a) - torch.cuda.synchronize() - t3 = time.perf_counter() - self.assertEqual(partial_res, full_res) - self.assertGreater(t1 - t0, t3 - t2) - - @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") @unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x") diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index dff4e9c014c..fa21705c76e 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -123,7 +123,6 @@ class TestPublicBindings(TestCase): "FutureType", "Generator", "GeneratorType", - "GreenContext", "get_autocast_cpu_dtype", "get_autocast_dtype", "get_autocast_ipu_dtype", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 29ca4af21de..b3c7621aaf8 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2777,17 +2777,3 @@ class _StaticCudaLauncher: args: tuple[Any, ...], stream: _int, ) -> None: ... - -# Defined in torch/csrc/cuda/green_context.h -class GreenContext: - @staticmethod - def create( - num_sms: _int, - device_id: _int, - ) -> GreenContext: ... - def make_current( - self, - ) -> None: ... - def pop_current( - self, - ) -> None: ... diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 10db3924da7..84fd3d0d714 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #ifdef USE_NCCL @@ -1492,13 +1491,6 @@ static void registerCudaPluggableAllocator(PyObject* module) { addStorageDeleterFns(storages_to_add_deleters_to, delta); }); } -static void initGreenContext(PyObject* module) { - auto m = py::handle(module).cast(); - py::class_(m, "GreenContext") - .def_static("create", &GreenContext::create) - .def("make_current", &GreenContext::makeCurrent) - .def("pop_current", &GreenContext::popCurrent); -} static void bindGetDeviceProperties(PyObject* module) { // Add method to torch.cuda @@ -2223,7 +2215,6 @@ void initModule(PyObject* module) { registerCudaDeviceProperties(module); registerCudaPluggableAllocator(module); initCudaMethodBindings(module); - initGreenContext(module); } } // namespace torch::cuda diff --git a/torch/csrc/cuda/green_context.h b/torch/csrc/cuda/green_context.h deleted file mode 100644 index 80a39cd3c21..00000000000 --- a/torch/csrc/cuda/green_context.h +++ /dev/null @@ -1,213 +0,0 @@ -#pragma once -#include -#if defined(CUDA_VERSION) && !defined(USE_ROCM) -#include -#include -#include -#include -#include -#endif - -class GreenContext { - public: - GreenContext(int device_id, unsigned int num_sms) { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) - int driver_version; - C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); - TORCH_CHECK( - driver_version >= 12080, "cuda driver too old to use green context!"); - CUcontext pctx = nullptr; - C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx)); - if (C10_UNLIKELY(!pctx)) { - TORCH_WARN( - "Attempted to create a green context but" - " there was no primary context! Creating a primary context..."); - - cudaFree(0); - } - - CUdevice device; - device_id_ = device_id; - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); - - // Get device resources - CUdevResource device_resource; - C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( - device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); - - // Split resources - std::vector result(1); - auto result_data = result.data(); - unsigned int nb_groups = 1; - CUdevResource remaining; - - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( - result_data, - &nb_groups, - &device_resource, - &remaining, - 0, // default flags - num_sms)); - - TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); - - // Generate resource descriptor - CUdevResourceDesc desc; - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( - &desc, result_data, 1)); - - // Create green context - // CU_GREEN_CTX_DEFAULT_STREAM is required per docs: - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html - C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_( - &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM)); - - // Convert to regular context - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_)); - TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!"); -#else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); -#endif - } - - static std::unique_ptr create( - unsigned int num_sms, - std::optional device_id) { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) - if (!device_id.has_value()) { - device_id = at::cuda::current_device(); - } - return std::make_unique(device_id.value(), num_sms); -#else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); -#endif - } - - // Delete copy constructor and assignment - GreenContext(const GreenContext&) = delete; - GreenContext& operator=(const GreenContext&) = delete; - - // Implement move operations - GreenContext(GreenContext&& other) noexcept { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) - device_id_ = std::exchange(other.device_id_, -1); - green_ctx_ = std::exchange(other.green_ctx_, nullptr); - context_ = std::exchange(other.context_, nullptr); - parent_stream_ = std::exchange(other.parent_stream_, nullptr); -#else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); -#endif - } - - GreenContext& operator=(GreenContext&& other) noexcept { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) - if (this != &other) { - // Clean up current resources - if (green_ctx_) { - CUcontext current = nullptr; - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t)); - if (current == context_) { - TORCH_CHECK( - false, - "attempting to overwrite current green ctx " - "when it is active!"); - } - C10_CUDA_DRIVER_CHECK(cuGreenCtxDestroy(green_ctx_)); - } - - // Take ownership of other's resources - device_id_ = std::exchange(other.device_id_, -1); - green_ctx_ = std::exchange(other.green_ctx_, nullptr); - context_ = std::exchange(other.context_, nullptr); - parent_stream_ = std::exchange(other.parent_stream_, nullptr); - } - return *this; -#else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); -#endif - } - - ~GreenContext() noexcept { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); -#else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); -#endif - } - - // Get the underlying CUDA context - CUcontext getContext() const { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) - return context_; -#else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); -#endif - } - - // Get the underlying green context -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) - CUgreenCtx getGreenContext() const { - return green_ctx_; - } -#endif - - // Make this context current - void makeCurrent() { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) - auto current_stream = c10::cuda::getCurrentCUDAStream(); - parent_stream_ = current_stream.stream(); - - at::cuda::CUDAEvent ev; - ev.record(current_stream); - - CUcontext current = nullptr; - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t)); - if (!current) { - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_)); - } else { - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_)); - } - // currently hardcodes the new green context to use the default stream - // TODO(eqy): consider creating a new stream if e.g., it allows interop - // with CUDA Graph captures etc. - auto default_stream = c10::cuda::getDefaultCUDAStream(); - ev.block(default_stream); - c10::cuda::setCurrentCUDAStream(default_stream); -#else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); -#endif - } - - void popCurrent() { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) - // see above note about stream being hardcoded to the default stream - at::cuda::CUDAEvent ev; - ev.record(c10::cuda::getCurrentCUDAStream()); - CUcontext popped; - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped)); - TORCH_INTERNAL_ASSERT( - popped == context_, "expected popped context to be the current ctx"); - ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_)); -#else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); -#endif - } - - private: -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM) - int device_id_ = -1; - CUgreenCtx green_ctx_ = nullptr; - CUcontext context_ = nullptr; - cudaStream_t parent_stream_ = nullptr; -#endif -}; diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 954d4d9ff58..bf562b68f73 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -35,7 +35,6 @@ from .graphs import ( is_current_stream_capturing, make_graphed_callables, ) -from .green_contexts import GreenContext from .streams import Event, ExternalStream, Stream @@ -1832,7 +1831,6 @@ __all__ = [ "ExternalStream", "Stream", "StreamContext", - "GreenContext", "amp", "caching_allocator_alloc", "caching_allocator_delete", diff --git a/torch/cuda/green_contexts.py b/torch/cuda/green_contexts.py deleted file mode 100644 index 743dd323b5a..00000000000 --- a/torch/cuda/green_contexts.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch - - -_GreenContext = object -SUPPORTED = False - -if hasattr(torch._C, "GreenContext"): - _GreenContext = torch._C.GreenContext # type: ignore[misc] - SUPPORTED = True - - -# Python shim helps Sphinx process docstrings more reliably. -class GreenContext(_GreenContext): - r"""Wrapper around a CUDA green context. - - .. warning:: - This API is in beta and may change in future releases. - """ - - @staticmethod - def create(num_sms: int, device_id: int = 0) -> _GreenContext: - r"""Create a CUDA green context. - - Arguments: - num_sms (int): The number of SMs to use in the green context. - device_id (int, optional): The device index of green context. - """ - if not SUPPORTED: - raise RuntimeError("PyTorch was not built with Green Context support!") - return _GreenContext.create(num_sms, device_id) # type: ignore[attr-defined] - - # Note that these functions are bypassed by we define them here - # for Sphinx documentation purposes - def make_current(self) -> None: - r"""Make the green context the current context.""" - return super().make_current() # type: ignore[misc] - - def pop_current(self) -> None: - r"""Assuming the green context is the current context, pop it from the - context stack and restore the previous context. - """ - return super().pop_current() # type: ignore[misc]