mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[CUDA] Add experimental green context support for SM carveout (#159104)
Low-level PyTorch APIs should be usable/stable enough at this point but we might move the underlying driver API usage a bit from here... Built on top of @drisspg 's branch Pull Request resolved: https://github.com/pytorch/pytorch/pull/159104 Approved by: https://github.com/ngimel, https://github.com/malfet, https://github.com/kwen2501 Co-authored-by: drisspg <drisspguessous@gmail.com> Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
parent
0b58d87aec
commit
e64a814ae7
192
aten/src/ATen/cuda/CUDAGreenContext.cpp
Normal file
192
aten/src/ATen/cuda/CUDAGreenContext.cpp
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
#include <ATen/cuda/CUDAGreenContext.h>
|
||||
|
||||
namespace at::cuda {
|
||||
// Creates a green context on device `device_id` that is limited to a group of
// `num_sms` SMs split off from the device's SM resources, then derives the
// regular CUcontext handle that setContext() makes current.
//
// Fails (TORCH_CHECK / C10_CUDA_CHECK / C10_CUDA_DRIVER_CHECK) when the driver
// reports a version below 12080, when any runtime/driver call returns an
// error, when the split does not yield exactly one group, or when PyTorch was
// built without green context support.
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if CUDA_HAS_GREEN_CONTEXT
  int driver_version;
  C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
  TORCH_CHECK(
      driver_version >= 12080, "cuda driver too old to use green context!");

  CUcontext pctx = nullptr;
  C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
  if (C10_UNLIKELY(!pctx)) {
    TORCH_WARN(
        "Attempted to create a green context but"
        " there was no primary context! Creating a primary context...");

    // The result used to be dropped; a failure here would otherwise only
    // surface later as an unrelated-looking error from a driver call below.
    // NOTE(review): this initializes the runtime on the *current* device,
    // which is presumably expected to match `device_id` -- confirm.
    C10_CUDA_CHECK(cudaFree(0));
  }

  CUdevice device;
  device_id_ = static_cast<int32_t>(device_id);
  C10_CUDA_DRIVER_CHECK(
      c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));

  // Get device resources
  CUdevResource device_resource;
  C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
      device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));

  // Split resources: request a single group of `num_sms` SMs; whatever is
  // left over is written to `remaining` and not used further.
  CUdevResource sm_group;
  unsigned int nb_groups = 1;
  CUdevResource remaining;

  C10_CUDA_DRIVER_CHECK(
      c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
          &sm_group,
          &nb_groups,
          &device_resource,
          &remaining,
          0, // default flags
          num_sms));

  TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");

  // Generate resource descriptor
  CUdevResourceDesc desc;
  C10_CUDA_DRIVER_CHECK(
      c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
          &desc, &sm_group, 1));

  // Create green context
  // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
  C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
      &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));

  // Convert to regular context
  C10_CUDA_DRIVER_CHECK(
      c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
  TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else
  TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
|
||||
|
||||
std::unique_ptr<GreenContext> GreenContext::create(
|
||||
uint32_t num_sms,
|
||||
std::optional<uint32_t> device_id) {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
if (!device_id.has_value()) {
|
||||
device_id = at::cuda::current_device();
|
||||
}
|
||||
return std::make_unique<GreenContext>(device_id.value(), num_sms);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Implement move operations
|
||||
GreenContext::GreenContext(GreenContext&& other) noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
device_id_ = std::exchange(other.device_id_, -1);
|
||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
||||
context_ = std::exchange(other.context_, nullptr);
|
||||
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Move assignment: destroys the green context currently owned by `*this` (if
// any), then takes over the handles of `other` and leaves `other` empty.
// Self-assignment is a no-op.
//
// NOTE(review): the function is declared noexcept but the checks below throw
// on failure (including the "currently active" check), which would terminate
// the process rather than propagate -- confirm this is intended.
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
  if (this != &other) {
    // Clean up current resources
    if (green_ctx_) {
      CUcontext current = nullptr;
      C10_CUDA_DRIVER_CHECK(
          c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
      // Refuse to destroy a green context that is the thread's current one.
      TORCH_CHECK(
          current != context_,
          "attempting to overwrite current green ctx "
          "when it is active!");
      C10_CUDA_DRIVER_CHECK(
          c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
    }

    // Take ownership of other's resources
    device_id_ = std::exchange(other.device_id_, -1);
    green_ctx_ = std::exchange(other.green_ctx_, nullptr);
    context_ = std::exchange(other.context_, nullptr);
    parent_stream_ = std::exchange(other.parent_stream_, nullptr);
  }
  return *this;
#else
  TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
|
||||
|
||||
// Destroys the owned green context, if any.
GreenContext::~GreenContext() noexcept{
#if CUDA_HAS_GREEN_CONTEXT
  // A moved-from object has a null handle (see the move operations), so there
  // is nothing to destroy. Passing the null handle to the driver would
  // presumably be reported as an error, and a failed check in this noexcept
  // destructor terminates the process. The move assignment applies the same
  // guard before destroying.
  if (green_ctx_) {
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
  }
#else
  TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext GreenContext::getContext() const {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
return context_;
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Accessor for the raw CUgreenCtx handle owned by this object. Only defined
// when green contexts are compiled in.
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx GreenContext::getGreenContext() const {
  return this->green_ctx_;
}
#endif
|
||||
|
||||
// Make this context current
|
||||
void GreenContext::setContext() {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
auto current_stream = c10::cuda::getCurrentCUDAStream();
|
||||
parent_stream_ = current_stream.stream();
|
||||
|
||||
at::cuda::CUDAEvent ev;
|
||||
ev.record(current_stream);
|
||||
|
||||
CUcontext current = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
||||
if (!current) {
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
|
||||
} else {
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
|
||||
}
|
||||
// currently hardcodes the new green context to use the default stream
|
||||
// TODO(eqy): consider creating a new stream if e.g., it allows interop
|
||||
// with CUDA Graph captures etc.
|
||||
auto default_stream = c10::cuda::getDefaultCUDAStream();
|
||||
ev.block(default_stream);
|
||||
c10::cuda::setCurrentCUDAStream(default_stream);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Pops this green context off the calling thread's context stack and makes
// the stream remembered by setContext() wait for the work issued while the
// green context was current.
// NOTE(review): the current stream is not switched back here -- no
// setCurrentCUDAStream call is made; confirm that is intended.
void GreenContext::popContext() {
#if CUDA_HAS_GREEN_CONTEXT
  // setContext() hardcodes the green context to the default stream, so the
  // event is recorded on whatever stream is current at this point.
  at::cuda::CUDAEvent green_work_done;
  green_work_done.record(c10::cuda::getCurrentCUDAStream());

  CUcontext popped;
  C10_CUDA_DRIVER_CHECK(
      c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
  TORCH_INTERNAL_ASSERT(
      popped == context_, "expected popped context to be the current ctx");

  // Order the parent stream after the recorded event.
  auto parent = c10::cuda::getStreamFromExternal(parent_stream_, device_id_);
  green_work_done.block(parent);
#else
  TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
|
||||
} // namespace at::cuda
|
||||
53
aten/src/ATen/cuda/CUDAGreenContext.h
Normal file
53
aten/src/ATen/cuda/CUDAGreenContext.h
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
#pragma once
#include <ATen/cuda/CUDAEvent.h>

// Needed by the GreenContext declaration whether or not green contexts are
// compiled in (uint32_t, std::unique_ptr, std::optional).
#include <cstdint>
#include <memory>
#include <optional>

// CUDA_HAS_GREEN_CONTEXT is 1 only for CUDA (non-ROCm) builds in which
// PYTORCH_C10_DRIVER_API_SUPPORTED is defined, and 0 otherwise.
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <stdexcept>
#include <vector>
#define CUDA_HAS_GREEN_CONTEXT 1
#else
#define CUDA_HAS_GREEN_CONTEXT 0
#endif
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
class TORCH_CUDA_CPP_API GreenContext {
|
||||
public:
|
||||
GreenContext(uint32_t device_id, uint32_t num_sms);
|
||||
|
||||
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
|
||||
|
||||
// Delete copy constructor and assignment
|
||||
GreenContext(const GreenContext&) = delete;
|
||||
GreenContext& operator=(const GreenContext&) = delete;
|
||||
|
||||
// Implement move operations
|
||||
GreenContext(GreenContext&& other) noexcept;
|
||||
GreenContext& operator=(GreenContext&& other) noexcept;
|
||||
~GreenContext() noexcept;
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext getContext() const;
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx getGreenContext() const;
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void setContext();
|
||||
|
||||
void popContext();
|
||||
|
||||
private:
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
int32_t device_id_ = -1;
|
||||
CUgreenCtx green_ctx_ = nullptr;
|
||||
CUcontext context_ = nullptr;
|
||||
cudaStream_t parent_stream_ = nullptr;
|
||||
#endif
|
||||
};
|
||||
} // namespace at::cuda
|
||||
|
|
@ -855,6 +855,7 @@ libtorch_python_cuda_core_sources = [
|
|||
"torch/csrc/cuda/Stream.cpp",
|
||||
"torch/csrc/cuda/Graph.cpp",
|
||||
"torch/csrc/cuda/MemPool.cpp",
|
||||
"torch/csrc/cuda/GreenContext.cpp",
|
||||
"torch/csrc/cuda/shared/cudart.cpp",
|
||||
"torch/csrc/cuda/shared/nvtx.cpp",
|
||||
"torch/csrc/cuda/utils.cpp",
|
||||
|
|
|
|||
|
|
@ -51,6 +51,17 @@
|
|||
|
||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
|
||||
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
|
||||
_(cuCtxFromGreenCtx, 12080) \
|
||||
_(cuCtxGetCurrent, 12080) \
|
||||
_(cuCtxPopCurrent, 12080) \
|
||||
_(cuCtxPushCurrent, 12080) \
|
||||
_(cuCtxSetCurrent, 12080) \
|
||||
_(cuGreenCtxCreate, 12080) \
|
||||
_(cuGreenCtxDestroy, 12080) \
|
||||
_(cuDevSmResourceSplitByCount, 12080) \
|
||||
_(cuDeviceGet, 12080) \
|
||||
_(cuDeviceGetDevResource, 12080) \
|
||||
_(cuDevResourceGenerateDesc, 12080) \
|
||||
_(cuMulticastAddDevice, 12030) \
|
||||
_(cuMulticastBindMem, 12030) \
|
||||
_(cuMulticastCreate, 12030) \
|
||||
|
|
|
|||
|
|
@ -607,6 +607,12 @@ if(USE_CUDA)
|
|||
set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
|
||||
endif()
|
||||
endif()
|
||||
# Define PYTORCH_C10_DRIVER_API_SUPPORTED for the green context source on
# non-Windows builds. COMPILE_DEFINITIONS is the dedicated property for
# preprocessor definitions; unlike COMPILE_FLAGS it does not replace other
# raw flags set on the same file.
if(NOT WIN32)
  set_source_files_properties(
    ${TORCH_ROOT}/aten/src/ATen/cuda/CUDAGreenContext.cpp
    PROPERTIES COMPILE_DEFINITIONS "PYTORCH_C10_DRIVER_API_SUPPORTED=1"
  )
endif()
|
||||
set_source_files_properties(
|
||||
${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
|
||||
PROPERTIES COMPILE_DEFINITIONS "NVRTC_SHORTHASH=${CUDA_NVRTC_SHORTHASH}"
|
||||
|
|
|
|||
|
|
@ -258,6 +258,28 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
|
|||
|
||||
```
|
||||
|
||||
## Green Contexts (experimental)
|
||||
|
||||
`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs
|
||||
to enable more general carveout of SM resources for CUDA kernels.
|
||||
|
||||
These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8.
|
||||
|
||||
See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these.
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.cuda.green_contexts
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
GreenContext
|
||||
```
|
||||
|
||||
|
||||
% This module needs to be documented. Adding here in the meantime
|
||||
|
||||
% for tracking purposes
|
||||
|
|
@ -270,6 +292,10 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
|
|||
.. py:module:: torch.cuda.gds
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. py:module:: torch.cuda.green_contexts
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. py:module:: torch.cuda.jiterator
|
||||
```
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# Owner(s): ["module: linear algebra"]
|
||||
|
||||
import contextlib
|
||||
import time
|
||||
import unittest
|
||||
from itertools import product
|
||||
from functools import partial
|
||||
|
|
@ -15,6 +16,7 @@ from torch.quantization._quantized_conversions import (
|
|||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_cuda import (
|
||||
PLATFORM_SUPPORTS_BF16,
|
||||
PLATFORM_SUPPORTS_GREEN_CONTEXT,
|
||||
SM53OrLater,
|
||||
SM80OrLater,
|
||||
SM90OrLater,
|
||||
|
|
@ -40,6 +42,7 @@ from torch.testing._internal.common_utils import (
|
|||
parametrize,
|
||||
run_tests,
|
||||
runOnRocmArch,
|
||||
serialTest,
|
||||
skipIfRocm,
|
||||
TEST_CUDA,
|
||||
TEST_WITH_ROCM,
|
||||
|
|
@ -855,6 +858,29 @@ class TestMatmulCuda(InductorTestCase):
|
|||
op(a, mismatch_batch_dim_b, out_dtype=torch.float32)
|
||||
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_GREEN_CONTEXT, "Green contexts are not supported")
@serialTest()
def test_greencontext_carveout(self):
    """A matmul under a 1-SM green context must give the same result as, and
    take longer than, the same matmul run after the context is popped."""
    a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
    ctx = torch.cuda.green_contexts.GreenContext.create(1, 0)
    ctx.set_context()
    try:
        # First matmul is a warm-up; only the second one is timed.
        torch.matmul(a, a)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        partial_res = torch.matmul(a, a)
        torch.cuda.synchronize()
        t1 = time.perf_counter()
    finally:
        # Always restore the previous context so a failure above does not
        # leave the green context current for subsequent tests.
        ctx.pop_context()
    # Same warm-up + timed pattern outside the green context.
    torch.matmul(a, a)
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    full_res = torch.matmul(a, a)
    torch.cuda.synchronize()
    t3 = time.perf_counter()
    self.assertEqual(partial_res, full_res)
    self.assertGreater(t1 - t0, t3 - t2)
|
||||
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
||||
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
|
||||
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
|
||||
|
|
|
|||
|
|
@ -123,6 +123,7 @@ class TestPublicBindings(TestCase):
|
|||
"FutureType",
|
||||
"Generator",
|
||||
"GeneratorType",
|
||||
"GreenContext",
|
||||
"get_autocast_cpu_dtype",
|
||||
"get_autocast_dtype",
|
||||
"get_autocast_ipu_dtype",
|
||||
|
|
|
|||
|
|
@ -2792,3 +2792,17 @@ class _StaticCudaLauncher:
|
|||
args: tuple[Any, ...],
|
||||
stream: _int,
|
||||
) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/cuda/GreenContext.cpp
class GreenContext:
    # Factory: builds a green context with `num_sms` SMs on `device_id`.
    @staticmethod
    def create(num_sms: _int, device_id: _int) -> GreenContext: ...
    # Makes the green context current.
    def set_context(self) -> None: ...
    # Pops the green context again.
    def pop_context(self) -> None: ...
|
||||
|
|
|
|||
|
|
@ -1957,6 +1957,7 @@ void THCPStream_init(PyObject* module);
|
|||
void THCPEvent_init(PyObject* module);
|
||||
void THCPGraph_init(PyObject* module);
|
||||
void THCPMemPool_init(PyObject* module);
|
||||
void THCPGreenContext_init(PyObject* module);
|
||||
PyMethodDef* THCPModule_methods();
|
||||
namespace torch::cuda {
|
||||
void initModule(PyObject* module);
|
||||
|
|
@ -2184,6 +2185,7 @@ PyObject* initModule() {
|
|||
THCPEvent_init(module);
|
||||
THCPGraph_init(module);
|
||||
THCPMemPool_init(module);
|
||||
THCPGreenContext_init(module);
|
||||
#endif
|
||||
|
||||
#ifdef USE_XPU
|
||||
|
|
|
|||
13
torch/csrc/cuda/GreenContext.cpp
Normal file
13
torch/csrc/cuda/GreenContext.cpp
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
#include <ATen/cuda/CUDAGreenContext.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
// Cargo culted partially from csrc/cuda/Stream.cpp
|
||||
|
||||
// Registers at::cuda::GreenContext on `module` under the Python name
// "_CUDAGreenContext", exposing create / set_context / pop_context.
void THCPGreenContext_init(PyObject* module) {
  using at::cuda::GreenContext;

  py::module torch_c = py::handle(module).cast<py::module>();
  py::class_<GreenContext> binding(torch_c, "_CUDAGreenContext");
  binding.def_static("create", &GreenContext::create);
  binding.def("set_context", &GreenContext::setContext);
  binding.def("pop_context", &GreenContext::popContext);
}
|
||||
|
|
@ -35,6 +35,7 @@ from .graphs import (
|
|||
is_current_stream_capturing,
|
||||
make_graphed_callables,
|
||||
)
|
||||
from .green_contexts import GreenContext
|
||||
from .streams import Event, ExternalStream, Stream
|
||||
|
||||
|
||||
|
|
@ -1844,6 +1845,7 @@ __all__ = [
|
|||
"ExternalStream",
|
||||
"Stream",
|
||||
"StreamContext",
|
||||
"GreenContext",
|
||||
"amp",
|
||||
"caching_allocator_alloc",
|
||||
"caching_allocator_delete",
|
||||
|
|
|
|||
42
torch/cuda/green_contexts.py
Normal file
42
torch/cuda/green_contexts.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
import torch
|
||||
|
||||
|
||||
# Use the C++ binding as the base class when this build exposes it; otherwise
# fall back to plain `object` so the module still imports.
_GreenContext = getattr(torch._C, "_CUDAGreenContext", object)  # type: ignore[misc]
# True exactly when the binding was found.
SUPPORTED = _GreenContext is not object
|
||||
|
||||
|
||||
# Python shim helps Sphinx process docstrings more reliably.
class GreenContext(_GreenContext):
    r"""Wrapper around a CUDA green context.

    .. warning::
       This API is in beta and may change in future releases.
    """

    @staticmethod
    def create(num_sms: int, device_id: int = 0) -> _GreenContext:
        r"""Create a CUDA green context.

        Arguments:
            num_sms (int): The number of SMs to use in the green context.
            device_id (int, optional): The device index of green context.

        Raises:
            RuntimeError: if PyTorch was built without green context support.
        """
        if SUPPORTED:
            return _GreenContext.create(num_sms, device_id)  # type: ignore[attr-defined]
        raise RuntimeError("PyTorch was not built with Green Context support!")

    # Note that these functions are bypassed (``create`` returns an instance
    # of the underlying binding class), but we define them here for Sphinx
    # documentation purposes.
    def set_context(self) -> None:
        r"""Make the green context the current context."""
        return super().set_context()  # type: ignore[misc]

    def pop_context(self) -> None:
        r"""Assuming the green context is the current context, pop it from the
        context stack and restore the previous context.
        """
        return super().pop_context()  # type: ignore[misc]
|
||||
|
|
@ -81,6 +81,16 @@ def evaluate_platform_supports_efficient_attention():
|
|||
def evaluate_platform_supports_cudnn_attention():
|
||||
return (not TEST_WITH_ROCM) and SM80OrLater and (TEST_CUDNN_VERSION >= 90000)
|
||||
|
||||
def evaluate_platform_supports_green_context():
    """Return True when green context tests can run on this machine.

    Requires a non-Windows platform, a PyTorch build against CUDA >= 12.8 and
    an NVIDIA driver from the R570 branch or newer. The driver bound matches
    the runtime check in the GreenContext constructor, which rejects drivers
    reporting a CUDA version below 12080; R570 is the driver branch that
    ships with CUDA 12.8.
    """
    if IS_WINDOWS:
        return False
    if not _get_torch_cuda_version() >= (12, 8):
        return False
    driver_version = torch.utils.collect_env.get_nvidia_driver_version(torch.utils.collect_env.run)
    if driver_version is None:
        return False
    try:
        driver_major = int(driver_version.split('.')[0])
    except ValueError:
        # Unparseable version string: treat as unsupported rather than
        # failing the evaluation of the skip condition.
        return False
    return driver_major >= 570
|
||||
|
||||
PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
|
||||
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
|
||||
PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention())
|
||||
|
|
@ -93,6 +103,8 @@ PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM
|
|||
|
||||
PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
|
||||
|
||||
PLATFORM_SUPPORTS_GREEN_CONTEXT: bool = LazyVal(lambda: evaluate_platform_supports_green_context())
|
||||
|
||||
def evaluate_platform_supports_fp8():
|
||||
if torch.cuda.is_available():
|
||||
if torch.version.hip:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user