mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[CUDA] Add experimental green context support for SM carveout (#159104)
Low-level PyTorch APIs should be usable/stable enough at this point but we might move the underlying driver API usage a bit from here... Built on top of @drisspg 's branch Pull Request resolved: https://github.com/pytorch/pytorch/pull/159104 Approved by: https://github.com/ngimel, https://github.com/malfet, https://github.com/kwen2501 Co-authored-by: drisspg <drisspguessous@gmail.com> Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
parent
0b58d87aec
commit
e64a814ae7
192
aten/src/ATen/cuda/CUDAGreenContext.cpp
Normal file
192
aten/src/ATen/cuda/CUDAGreenContext.cpp
Normal file
|
|
@ -0,0 +1,192 @@
|
||||||
|
#include <ATen/cuda/CUDAGreenContext.h>
|
||||||
|
|
||||||
|
namespace at::cuda {
|
||||||
|
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||||
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
|
int driver_version;
|
||||||
|
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
|
||||||
|
TORCH_CHECK(
|
||||||
|
driver_version >= 12080, "cuda driver too old to use green context!");
|
||||||
|
CUcontext pctx = nullptr;
|
||||||
|
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
|
||||||
|
if (C10_UNLIKELY(!pctx)) {
|
||||||
|
TORCH_WARN(
|
||||||
|
"Attempted to create a green context but"
|
||||||
|
" there was no primary context! Creating a primary context...");
|
||||||
|
|
||||||
|
cudaFree(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
CUdevice device;
|
||||||
|
device_id_ = device_id;
|
||||||
|
C10_CUDA_DRIVER_CHECK(
|
||||||
|
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
|
||||||
|
|
||||||
|
// Get device resources
|
||||||
|
CUdevResource device_resource;
|
||||||
|
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
|
||||||
|
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
|
||||||
|
|
||||||
|
// Split resources
|
||||||
|
std::vector<CUdevResource> result(1);
|
||||||
|
auto result_data = result.data();
|
||||||
|
unsigned int nb_groups = 1;
|
||||||
|
CUdevResource remaining;
|
||||||
|
|
||||||
|
C10_CUDA_DRIVER_CHECK(
|
||||||
|
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
|
||||||
|
result_data,
|
||||||
|
&nb_groups,
|
||||||
|
&device_resource,
|
||||||
|
&remaining,
|
||||||
|
0, // default flags
|
||||||
|
num_sms));
|
||||||
|
|
||||||
|
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
|
||||||
|
|
||||||
|
// Generate resource descriptor
|
||||||
|
CUdevResourceDesc desc;
|
||||||
|
C10_CUDA_DRIVER_CHECK(
|
||||||
|
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
|
||||||
|
&desc, result_data, 1));
|
||||||
|
|
||||||
|
// Create green context
|
||||||
|
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
|
||||||
|
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
||||||
|
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
|
||||||
|
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||||
|
|
||||||
|
// Convert to regular context
|
||||||
|
C10_CUDA_DRIVER_CHECK(
|
||||||
|
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
|
||||||
|
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
|
||||||
|
#else
|
||||||
|
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<GreenContext> GreenContext::create(
|
||||||
|
uint32_t num_sms,
|
||||||
|
std::optional<uint32_t> device_id) {
|
||||||
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
|
if (!device_id.has_value()) {
|
||||||
|
device_id = at::cuda::current_device();
|
||||||
|
}
|
||||||
|
return std::make_unique<GreenContext>(device_id.value(), num_sms);
|
||||||
|
#else
|
||||||
|
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement move operations
|
||||||
|
GreenContext::GreenContext(GreenContext&& other) noexcept{
|
||||||
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
|
device_id_ = std::exchange(other.device_id_, -1);
|
||||||
|
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
||||||
|
context_ = std::exchange(other.context_, nullptr);
|
||||||
|
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
||||||
|
#else
|
||||||
|
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
|
||||||
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
|
if (this != &other) {
|
||||||
|
// Clean up current resources
|
||||||
|
if (green_ctx_) {
|
||||||
|
CUcontext current = nullptr;
|
||||||
|
C10_CUDA_DRIVER_CHECK(
|
||||||
|
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
||||||
|
if (current == context_) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
false,
|
||||||
|
"attempting to overwrite current green ctx "
|
||||||
|
"when it is active!");
|
||||||
|
}
|
||||||
|
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take ownership of other's resources
|
||||||
|
device_id_ = std::exchange(other.device_id_, -1);
|
||||||
|
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
||||||
|
context_ = std::exchange(other.context_, nullptr);
|
||||||
|
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
#else
|
||||||
|
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
GreenContext::~GreenContext() noexcept{
|
||||||
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
|
C10_CUDA_DRIVER_CHECK(
|
||||||
|
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
||||||
|
#else
|
||||||
|
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the underlying CUDA context
|
||||||
|
CUcontext GreenContext::getContext() const {
|
||||||
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
|
return context_;
|
||||||
|
#else
|
||||||
|
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the underlying green context
|
||||||
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
|
CUgreenCtx GreenContext::getGreenContext() const {
|
||||||
|
return green_ctx_;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Make this context current
|
||||||
|
void GreenContext::setContext() {
|
||||||
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
|
auto current_stream = c10::cuda::getCurrentCUDAStream();
|
||||||
|
parent_stream_ = current_stream.stream();
|
||||||
|
|
||||||
|
at::cuda::CUDAEvent ev;
|
||||||
|
ev.record(current_stream);
|
||||||
|
|
||||||
|
CUcontext current = nullptr;
|
||||||
|
C10_CUDA_DRIVER_CHECK(
|
||||||
|
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
||||||
|
if (!current) {
|
||||||
|
C10_CUDA_DRIVER_CHECK(
|
||||||
|
c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
|
||||||
|
} else {
|
||||||
|
C10_CUDA_DRIVER_CHECK(
|
||||||
|
c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
|
||||||
|
}
|
||||||
|
// currently hardcodes the new green context to use the default stream
|
||||||
|
// TODO(eqy): consider creating a new stream if e.g., it allows interop
|
||||||
|
// with CUDA Graph captures etc.
|
||||||
|
auto default_stream = c10::cuda::getDefaultCUDAStream();
|
||||||
|
ev.block(default_stream);
|
||||||
|
c10::cuda::setCurrentCUDAStream(default_stream);
|
||||||
|
#else
|
||||||
|
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void GreenContext::popContext() {
|
||||||
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
|
// see above note about stream being hardcoded to the default stream
|
||||||
|
at::cuda::CUDAEvent ev;
|
||||||
|
ev.record(c10::cuda::getCurrentCUDAStream());
|
||||||
|
CUcontext popped;
|
||||||
|
C10_CUDA_DRIVER_CHECK(
|
||||||
|
c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
|
||||||
|
TORCH_INTERNAL_ASSERT(
|
||||||
|
popped == context_, "expected popped context to be the current ctx");
|
||||||
|
ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
|
||||||
|
#else
|
||||||
|
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
} // namespace at::cuda
|
||||||
53
aten/src/ATen/cuda/CUDAGreenContext.h
Normal file
53
aten/src/ATen/cuda/CUDAGreenContext.h
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
#pragma once
|
||||||
|
#include <ATen/cuda/CUDAEvent.h>
|
||||||
|
|
||||||
|
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||||
|
#include <c10/cuda/driver_api.h>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <memory>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <vector>
|
||||||
|
#define CUDA_HAS_GREEN_CONTEXT 1
|
||||||
|
#else
|
||||||
|
#define CUDA_HAS_GREEN_CONTEXT 0
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace at::cuda {
|
||||||
|
|
||||||
|
class TORCH_CUDA_CPP_API GreenContext {
|
||||||
|
public:
|
||||||
|
GreenContext(uint32_t device_id, uint32_t num_sms);
|
||||||
|
|
||||||
|
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
|
||||||
|
|
||||||
|
// Delete copy constructor and assignment
|
||||||
|
GreenContext(const GreenContext&) = delete;
|
||||||
|
GreenContext& operator=(const GreenContext&) = delete;
|
||||||
|
|
||||||
|
// Implement move operations
|
||||||
|
GreenContext(GreenContext&& other) noexcept;
|
||||||
|
GreenContext& operator=(GreenContext&& other) noexcept;
|
||||||
|
~GreenContext() noexcept;
|
||||||
|
|
||||||
|
// Get the underlying CUDA context
|
||||||
|
CUcontext getContext() const;
|
||||||
|
|
||||||
|
// Get the underlying green context
|
||||||
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
|
CUgreenCtx getGreenContext() const;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Make this context current
|
||||||
|
void setContext();
|
||||||
|
|
||||||
|
void popContext();
|
||||||
|
|
||||||
|
private:
|
||||||
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
|
int32_t device_id_ = -1;
|
||||||
|
CUgreenCtx green_ctx_ = nullptr;
|
||||||
|
CUcontext context_ = nullptr;
|
||||||
|
cudaStream_t parent_stream_ = nullptr;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
} // namespace at::cuda
|
||||||
|
|
@ -855,6 +855,7 @@ libtorch_python_cuda_core_sources = [
|
||||||
"torch/csrc/cuda/Stream.cpp",
|
"torch/csrc/cuda/Stream.cpp",
|
||||||
"torch/csrc/cuda/Graph.cpp",
|
"torch/csrc/cuda/Graph.cpp",
|
||||||
"torch/csrc/cuda/MemPool.cpp",
|
"torch/csrc/cuda/MemPool.cpp",
|
||||||
|
"torch/csrc/cuda/GreenContext.cpp",
|
||||||
"torch/csrc/cuda/shared/cudart.cpp",
|
"torch/csrc/cuda/shared/cudart.cpp",
|
||||||
"torch/csrc/cuda/shared/nvtx.cpp",
|
"torch/csrc/cuda/shared/nvtx.cpp",
|
||||||
"torch/csrc/cuda/utils.cpp",
|
"torch/csrc/cuda/utils.cpp",
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,17 @@
|
||||||
|
|
||||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
|
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
|
||||||
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
|
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
|
||||||
|
_(cuCtxFromGreenCtx, 12080) \
|
||||||
|
_(cuCtxGetCurrent, 12080) \
|
||||||
|
_(cuCtxPopCurrent, 12080) \
|
||||||
|
_(cuCtxPushCurrent, 12080) \
|
||||||
|
_(cuCtxSetCurrent, 12080) \
|
||||||
|
_(cuGreenCtxCreate, 12080) \
|
||||||
|
_(cuGreenCtxDestroy, 12080) \
|
||||||
|
_(cuDevSmResourceSplitByCount, 12080) \
|
||||||
|
_(cuDeviceGet, 12080) \
|
||||||
|
_(cuDeviceGetDevResource, 12080) \
|
||||||
|
_(cuDevResourceGenerateDesc, 12080) \
|
||||||
_(cuMulticastAddDevice, 12030) \
|
_(cuMulticastAddDevice, 12030) \
|
||||||
_(cuMulticastBindMem, 12030) \
|
_(cuMulticastBindMem, 12030) \
|
||||||
_(cuMulticastCreate, 12030) \
|
_(cuMulticastCreate, 12030) \
|
||||||
|
|
|
||||||
|
|
@ -607,6 +607,12 @@ if(USE_CUDA)
|
||||||
set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
|
set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
if(NOT WIN32)
|
||||||
|
set_source_files_properties(
|
||||||
|
${TORCH_ROOT}/aten/src/ATen/cuda/CUDAGreenContext.cpp
|
||||||
|
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
|
||||||
|
)
|
||||||
|
endif()
|
||||||
set_source_files_properties(
|
set_source_files_properties(
|
||||||
${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
|
${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
|
||||||
PROPERTIES COMPILE_DEFINITIONS "NVRTC_SHORTHASH=${CUDA_NVRTC_SHORTHASH}"
|
PROPERTIES COMPILE_DEFINITIONS "NVRTC_SHORTHASH=${CUDA_NVRTC_SHORTHASH}"
|
||||||
|
|
|
||||||
|
|
@ -258,6 +258,28 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Green Contexts (experimental)
|
||||||
|
|
||||||
|
`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs
|
||||||
|
to enable more general carveout of SM resources for CUDA kernels.
|
||||||
|
|
||||||
|
These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8.
|
||||||
|
|
||||||
|
See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these.
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. currentmodule:: torch.cuda.green_contexts
|
||||||
|
```
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
GreenContext
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
% This module needs to be documented. Adding here in the meantime
|
% This module needs to be documented. Adding here in the meantime
|
||||||
|
|
||||||
% for tracking purposes
|
% for tracking purposes
|
||||||
|
|
@ -270,6 +292,10 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
|
||||||
.. py:module:: torch.cuda.gds
|
.. py:module:: torch.cuda.gds
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. py:module:: torch.cuda.green_contexts
|
||||||
|
```
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. py:module:: torch.cuda.jiterator
|
.. py:module:: torch.cuda.jiterator
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
# Owner(s): ["module: linear algebra"]
|
# Owner(s): ["module: linear algebra"]
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
@ -15,6 +16,7 @@ from torch.quantization._quantized_conversions import (
|
||||||
from torch.testing import make_tensor
|
from torch.testing import make_tensor
|
||||||
from torch.testing._internal.common_cuda import (
|
from torch.testing._internal.common_cuda import (
|
||||||
PLATFORM_SUPPORTS_BF16,
|
PLATFORM_SUPPORTS_BF16,
|
||||||
|
PLATFORM_SUPPORTS_GREEN_CONTEXT,
|
||||||
SM53OrLater,
|
SM53OrLater,
|
||||||
SM80OrLater,
|
SM80OrLater,
|
||||||
SM90OrLater,
|
SM90OrLater,
|
||||||
|
|
@ -40,6 +42,7 @@ from torch.testing._internal.common_utils import (
|
||||||
parametrize,
|
parametrize,
|
||||||
run_tests,
|
run_tests,
|
||||||
runOnRocmArch,
|
runOnRocmArch,
|
||||||
|
serialTest,
|
||||||
skipIfRocm,
|
skipIfRocm,
|
||||||
TEST_CUDA,
|
TEST_CUDA,
|
||||||
TEST_WITH_ROCM,
|
TEST_WITH_ROCM,
|
||||||
|
|
@ -855,6 +858,29 @@ class TestMatmulCuda(InductorTestCase):
|
||||||
op(a, mismatch_batch_dim_b, out_dtype=torch.float32)
|
op(a, mismatch_batch_dim_b, out_dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(not PLATFORM_SUPPORTS_GREEN_CONTEXT, "Green contexts are not supported")
|
||||||
|
@serialTest()
|
||||||
|
def test_greencontext_carveout(self):
|
||||||
|
a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
|
||||||
|
ctx = torch.cuda.green_contexts.GreenContext.create(1, 0)
|
||||||
|
ctx.set_context()
|
||||||
|
torch.matmul(a, a)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
partial_res = torch.matmul(a, a)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
ctx.pop_context()
|
||||||
|
torch.matmul(a, a)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
t2 = time.perf_counter()
|
||||||
|
full_res = torch.matmul(a, a)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
t3 = time.perf_counter()
|
||||||
|
self.assertEqual(partial_res, full_res)
|
||||||
|
self.assertGreater(t1 - t0, t3 - t2)
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
||||||
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
|
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
|
||||||
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
|
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
|
||||||
|
|
|
||||||
|
|
@ -123,6 +123,7 @@ class TestPublicBindings(TestCase):
|
||||||
"FutureType",
|
"FutureType",
|
||||||
"Generator",
|
"Generator",
|
||||||
"GeneratorType",
|
"GeneratorType",
|
||||||
|
"GreenContext",
|
||||||
"get_autocast_cpu_dtype",
|
"get_autocast_cpu_dtype",
|
||||||
"get_autocast_dtype",
|
"get_autocast_dtype",
|
||||||
"get_autocast_ipu_dtype",
|
"get_autocast_ipu_dtype",
|
||||||
|
|
|
||||||
|
|
@ -2792,3 +2792,17 @@ class _StaticCudaLauncher:
|
||||||
args: tuple[Any, ...],
|
args: tuple[Any, ...],
|
||||||
stream: _int,
|
stream: _int,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
|
# Defined in torch/csrc/cuda/GreenContext.cpp
|
||||||
|
class GreenContext:
|
||||||
|
@staticmethod
|
||||||
|
def create(
|
||||||
|
num_sms: _int,
|
||||||
|
device_id: _int,
|
||||||
|
) -> GreenContext: ...
|
||||||
|
def set_context(
|
||||||
|
self,
|
||||||
|
) -> None: ...
|
||||||
|
def pop_context(
|
||||||
|
self,
|
||||||
|
) -> None: ...
|
||||||
|
|
|
||||||
|
|
@ -1957,6 +1957,7 @@ void THCPStream_init(PyObject* module);
|
||||||
void THCPEvent_init(PyObject* module);
|
void THCPEvent_init(PyObject* module);
|
||||||
void THCPGraph_init(PyObject* module);
|
void THCPGraph_init(PyObject* module);
|
||||||
void THCPMemPool_init(PyObject* module);
|
void THCPMemPool_init(PyObject* module);
|
||||||
|
void THCPGreenContext_init(PyObject* module);
|
||||||
PyMethodDef* THCPModule_methods();
|
PyMethodDef* THCPModule_methods();
|
||||||
namespace torch::cuda {
|
namespace torch::cuda {
|
||||||
void initModule(PyObject* module);
|
void initModule(PyObject* module);
|
||||||
|
|
@ -2184,6 +2185,7 @@ PyObject* initModule() {
|
||||||
THCPEvent_init(module);
|
THCPEvent_init(module);
|
||||||
THCPGraph_init(module);
|
THCPGraph_init(module);
|
||||||
THCPMemPool_init(module);
|
THCPMemPool_init(module);
|
||||||
|
THCPGreenContext_init(module);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef USE_XPU
|
#ifdef USE_XPU
|
||||||
|
|
|
||||||
13
torch/csrc/cuda/GreenContext.cpp
Normal file
13
torch/csrc/cuda/GreenContext.cpp
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
#include <ATen/cuda/CUDAGreenContext.h>
|
||||||
|
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||||
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
|
// Cargo culted partially from csrc/cuda/Stream.cpp
|
||||||
|
|
||||||
|
void THCPGreenContext_init(PyObject* module) {
|
||||||
|
auto m = py::handle(module).cast<py::module>();
|
||||||
|
py::class_<at::cuda::GreenContext>(m, "_CUDAGreenContext")
|
||||||
|
.def_static("create", &::at::cuda::GreenContext::create)
|
||||||
|
.def("set_context", &::at::cuda::GreenContext::setContext)
|
||||||
|
.def("pop_context", &::at::cuda::GreenContext::popContext);
|
||||||
|
}
|
||||||
|
|
@ -35,6 +35,7 @@ from .graphs import (
|
||||||
is_current_stream_capturing,
|
is_current_stream_capturing,
|
||||||
make_graphed_callables,
|
make_graphed_callables,
|
||||||
)
|
)
|
||||||
|
from .green_contexts import GreenContext
|
||||||
from .streams import Event, ExternalStream, Stream
|
from .streams import Event, ExternalStream, Stream
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1844,6 +1845,7 @@ __all__ = [
|
||||||
"ExternalStream",
|
"ExternalStream",
|
||||||
"Stream",
|
"Stream",
|
||||||
"StreamContext",
|
"StreamContext",
|
||||||
|
"GreenContext",
|
||||||
"amp",
|
"amp",
|
||||||
"caching_allocator_alloc",
|
"caching_allocator_alloc",
|
||||||
"caching_allocator_delete",
|
"caching_allocator_delete",
|
||||||
|
|
|
||||||
42
torch/cuda/green_contexts.py
Normal file
42
torch/cuda/green_contexts.py
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
_GreenContext = object
|
||||||
|
SUPPORTED = False
|
||||||
|
|
||||||
|
if hasattr(torch._C, "_CUDAGreenContext"):
|
||||||
|
_GreenContext = torch._C._CUDAGreenContext # type: ignore[misc]
|
||||||
|
SUPPORTED = True
|
||||||
|
|
||||||
|
|
||||||
|
# Python shim helps Sphinx process docstrings more reliably.
|
||||||
|
class GreenContext(_GreenContext):
|
||||||
|
r"""Wrapper around a CUDA green context.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
This API is in beta and may change in future releases.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(num_sms: int, device_id: int = 0) -> _GreenContext:
|
||||||
|
r"""Create a CUDA green context.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
num_sms (int): The number of SMs to use in the green context.
|
||||||
|
device_id (int, optional): The device index of green context.
|
||||||
|
"""
|
||||||
|
if not SUPPORTED:
|
||||||
|
raise RuntimeError("PyTorch was not built with Green Context support!")
|
||||||
|
return _GreenContext.create(num_sms, device_id) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# Note that these functions are bypassed by we define them here
|
||||||
|
# for Sphinx documentation purposes
|
||||||
|
def set_context(self) -> None:
|
||||||
|
r"""Make the green context the current context."""
|
||||||
|
return super().set_context() # type: ignore[misc]
|
||||||
|
|
||||||
|
def pop_context(self) -> None:
|
||||||
|
r"""Assuming the green context is the current context, pop it from the
|
||||||
|
context stack and restore the previous context.
|
||||||
|
"""
|
||||||
|
return super().pop_context() # type: ignore[misc]
|
||||||
|
|
@ -81,6 +81,16 @@ def evaluate_platform_supports_efficient_attention():
|
||||||
def evaluate_platform_supports_cudnn_attention():
|
def evaluate_platform_supports_cudnn_attention():
|
||||||
return (not TEST_WITH_ROCM) and SM80OrLater and (TEST_CUDNN_VERSION >= 90000)
|
return (not TEST_WITH_ROCM) and SM80OrLater and (TEST_CUDNN_VERSION >= 90000)
|
||||||
|
|
||||||
|
def evaluate_platform_supports_green_context():
|
||||||
|
if IS_WINDOWS:
|
||||||
|
return False
|
||||||
|
if not _get_torch_cuda_version() >= (12, 8):
|
||||||
|
return False
|
||||||
|
driver_version = torch.utils.collect_env.get_nvidia_driver_version(torch.utils.collect_env.run)
|
||||||
|
if driver_version is None:
|
||||||
|
return False
|
||||||
|
return int(driver_version.split('.')[0]) >= 550
|
||||||
|
|
||||||
PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
|
PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
|
||||||
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
|
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
|
||||||
PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention())
|
PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention())
|
||||||
|
|
@ -93,6 +103,8 @@ PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM
|
||||||
|
|
||||||
PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
|
PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
|
||||||
|
|
||||||
|
PLATFORM_SUPPORTS_GREEN_CONTEXT: bool = LazyVal(lambda: evaluate_platform_supports_green_context())
|
||||||
|
|
||||||
def evaluate_platform_supports_fp8():
|
def evaluate_platform_supports_fp8():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
if torch.version.hip:
|
if torch.version.hip:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user