pytorch/torch/cuda/green_contexts.py
Eddie Yan e64a814ae7 [CUDA] Add experimental green context support for SM carveout (#159104)
Low-level PyTorch APIs should be usable/stable enough at this point but we might move the underlying driver API usage a bit from here...

Built on top of @drisspg 's branch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159104
Approved by: https://github.com/ngimel, https://github.com/malfet, https://github.com/kwen2501

Co-authored-by: drisspg <drisspguessous@gmail.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-10-22 21:38:52 +00:00

43 lines
1.4 KiB
Python

import torch
# Fallback used when this torch build does not expose the green context
# binding: a plain ``object`` base keeps this module importable and lets the
# documentation shim class below still be defined.
_GreenContext = object
# True only when ``torch._C`` provides ``_CUDAGreenContext``.
SUPPORTED = False

if hasattr(torch._C, "_CUDAGreenContext"):
    # Binding is present: use it as the real base class and flag support.
    _GreenContext = torch._C._CUDAGreenContext  # type: ignore[misc]
    SUPPORTED = True
def _ensure_green_context_support() -> None:
    r"""Raise ``RuntimeError`` when ``SUPPORTED`` is false."""
    if not SUPPORTED:
        raise RuntimeError("PyTorch was not built with Green Context support!")


# Python shim helps Sphinx process docstrings more reliably.
class GreenContext(_GreenContext):
    r"""Wrapper around a CUDA green context.

    .. warning::
       This API is in beta and may change in future releases.
    """

    @staticmethod
    def create(num_sms: int, device_id: int = 0) -> _GreenContext:
        r"""Create a CUDA green context.

        Arguments:
            num_sms (int): The number of SMs to use in the green context.
            device_id (int, optional): The device index of green context.

        Returns:
            The object produced by the underlying ``_GreenContext.create``.

        Raises:
            RuntimeError: If PyTorch was not built with green context support.
        """
        _ensure_green_context_support()
        return _GreenContext.create(num_sms, device_id)  # type: ignore[attr-defined]

    # Note that these functions are bypassed, but we define them here
    # for Sphinx documentation purposes.
    def set_context(self) -> None:
        r"""Make the green context the current context.

        Raises:
            RuntimeError: If PyTorch was not built with green context support.
        """
        # Without the binding the base class is plain ``object``; fail with the
        # same clear error as ``create`` rather than an AttributeError raised
        # from the ``super()`` lookup.
        _ensure_green_context_support()
        return super().set_context()  # type: ignore[misc]

    def pop_context(self) -> None:
        r"""Assuming the green context is the current context, pop it from the
        context stack and restore the previous context.

        Raises:
            RuntimeError: If PyTorch was not built with green context support.
        """
        # Same guard as ``set_context``: report missing support explicitly.
        _ensure_green_context_support()
        return super().pop_context()  # type: ignore[misc]