import ctypes
import sys
from typing import Any, Optional, Union

import torch

# The _get_device_index has been moved to torch._utils._get_device_index
from torch._utils import _get_device_index as _torch_get_device_index


# Load CUDA driver and NVRTC
def _get_cuda_library() -> ctypes.CDLL:
    if sys.platform == "win32":
        return ctypes.CDLL("nvcuda.dll")
    else:  # Unix-based systems
        return ctypes.CDLL("libcuda.so.1")


# Helper: check CUDA errors
def _check_cuda(result: int) -> None:
    if result == 0:
        return
    err_str = ctypes.c_char_p()
    libcuda = _get_cuda_library()  # Get reference to CUDA library
    libcuda.cuGetErrorString(result, ctypes.byref(err_str))
    error_message = (
        err_str.value.decode() if err_str.value is not None else "Unknown CUDA error"
    )
    raise RuntimeError(f"CUDA error: {error_message}")


def _get_nvrtc_library() -> ctypes.CDLL:
    # Since PyTorch already loads NVRTC, we can use the system library,
    # which should be compatible with PyTorch's version
    if sys.platform == "win32":
        return ctypes.CDLL("nvrtc64_120_0.dll")
    else:
        return ctypes.CDLL("libnvrtc.so")


def _nvrtc_compile(
    kernel_source: str,
    kernel_name: str,
    compute_capability: Optional[str] = None,
    header_code: str = "",
    cuda_include_dirs: Optional[list] = None,
    nvcc_options: Optional[list] = None,
) -> bytes:
    """
    Compiles a CUDA kernel using NVRTC and returns the PTX code.

    Args:
        kernel_source (str): The CUDA kernel source code as a string
        kernel_name (str): The name of the kernel function to compile
        compute_capability (str, optional): The compute capability to target
            (e.g., "86"). If None, will detect from the current device.
        header_code (str, optional): Additional header code to prepend to the
            kernel source
        cuda_include_dirs (list, optional): List of directories containing CUDA headers
        nvcc_options (list, optional): Additional options to pass to NVRTC

    Returns:
        bytes: The compiled PTX code
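
    Example (a minimal sketch; assumes a CUDA device and a compatible NVRTC
    library are available)::

        >>> ptx = _nvrtc_compile(
        ...     "__global__ void scale(float* x) { x[threadIdx.x] *= 2.0f; }",
        ...     "scale",
        ... )
        >>> isinstance(ptx, bytes)
        True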
    """
    # Ensure CUDA is initialized
    import torch.cuda

    # Load NVRTC library
    libnvrtc = _get_nvrtc_library()

    # NVRTC constants
    NVRTC_SUCCESS = 0

    # Helper: check NVRTC errors
    def check_nvrtc(result: int) -> None:
        if result != NVRTC_SUCCESS:
            err_str = ctypes.c_char_p()
            libnvrtc.nvrtcGetErrorString(result, ctypes.byref(err_str))
            error_message = (
                err_str.value.decode()
                if err_str.value is not None
                else "Unknown NVRTC error"
            )
            raise RuntimeError(f"NVRTC error: {error_message}")

    # Add 'extern "C"' if not already present to ensure C linkage
    if not kernel_source.strip().startswith('extern "C"'):
        kernel_source = f'extern "C" {kernel_source}'

    # Combine header code and kernel source
    if header_code:
        full_source = header_code + "\n" + kernel_source
    else:
        full_source = kernel_source

    # Convert source to bytes
    source_bytes = full_source.encode("utf-8")

    # Get compute capability if not provided
    if compute_capability is None:
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        compute_capability = f"{props.major}{props.minor}"

    # Prepare compilation options
    options = []
    options.append(f"--gpu-architecture=sm_{compute_capability}".encode())

    # Auto-detect and add CUDA include paths
    from torch.utils.cpp_extension import include_paths

    cuda_include_paths = include_paths("cuda")
    for cuda_path in cuda_include_paths:
        options.append(f"-I{cuda_path}".encode())

    # Add custom include directories
    if cuda_include_dirs:
        for directory in cuda_include_dirs:
            options.append(f"-I{directory}".encode())

    # Add custom NVCC options
    if nvcc_options:
        for option in nvcc_options:
            options.append(option.encode("utf-8"))

    # TODO: Should we refactor flags into a common place?
    from torch.utils.cpp_extension import COMMON_NVCC_FLAGS

    # Filter out flags not supported by NVRTC
    nvrtc_compatible_flags = [
        flag for flag in COMMON_NVCC_FLAGS if flag != "--expt-relaxed-constexpr"
    ]
    options.extend([flag.encode("utf-8") for flag in nvrtc_compatible_flags])

    # Convert options to a C array of char pointers
    num_options = len(options)
    options_array = (ctypes.c_char_p * num_options)(*options)

    # Create program
    prog = ctypes.c_void_p()
    check_nvrtc(
        libnvrtc.nvrtcCreateProgram(
            ctypes.byref(prog),
            source_bytes,
            f"{kernel_name}.cu".encode(),
            0,
            None,
            None,
        )
    )

    # Compile program
    res = libnvrtc.nvrtcCompileProgram(prog, num_options, options_array)

    # Handle compilation errors
    if res != NVRTC_SUCCESS:
        # Get the compilation log
        log_size = ctypes.c_size_t()
        libnvrtc.nvrtcGetProgramLogSize(prog, ctypes.byref(log_size))
        log = ctypes.create_string_buffer(log_size.value)
        libnvrtc.nvrtcGetProgramLog(prog, log)
        raise RuntimeError(f"Kernel compilation failed:\n{log.value.decode()}")

    # Get PTX
    ptx_size = ctypes.c_size_t()
    check_nvrtc(libnvrtc.nvrtcGetPTXSize(prog, ctypes.byref(ptx_size)))
    ptx = ctypes.create_string_buffer(ptx_size.value)
    check_nvrtc(libnvrtc.nvrtcGetPTX(prog, ptx))

    libnvrtc.nvrtcDestroyProgram(ctypes.byref(prog))
    return ptx.value


class _CudaModule:
    def __init__(self, module: ctypes.c_void_p) -> None:
        self._module = module
        self._kernels: dict[str, _CudaKernel] = {}

    def __getattr__(self, name: str) -> "_CudaKernel":
        if name in self._kernels:
            return self._kernels[name]

        # Import the CUDA library inside the method
        from torch.cuda._utils import _get_cuda_library

        libcuda = _get_cuda_library()

        func = ctypes.c_void_p()
        try:
            _check_cuda(
                libcuda.cuModuleGetFunction(
                    ctypes.byref(func), self._module, name.encode("utf-8")
                )
            )
            kernel = _CudaKernel(func, self._module)
            self._kernels[name] = kernel
            return kernel
        except RuntimeError as err:
            raise AttributeError(f"No kernel named '{name}' in this module") from err


class _CudaKernel:
    """
    Represents a compiled CUDA kernel that can be called with PyTorch tensors.
    """

    def __init__(self, func: ctypes.c_void_p, module: ctypes.c_void_p) -> None:
        self.func = func
        self.module = module

    def __call__(
        self,
        grid: tuple[int, int, int] = (1, 1, 1),
        block: tuple[int, int, int] = (1, 1, 1),
        args: Optional[list] = None,
        shared_mem: int = 0,
        stream: Optional[Any] = None,
    ) -> None:
        """
        Call the compiled CUDA kernel.

        Args:
            grid (tuple): Grid dimensions (grid_x, grid_y, grid_z)
            block (tuple): Block dimensions (block_x, block_y, block_z)
            args (list): List of arguments to pass to the kernel.
                PyTorch tensor arguments will be automatically converted to pointers.
            shared_mem (int): Shared memory size in bytes
            stream (torch.cuda.Stream): CUDA stream to use. If None, uses the
                current stream.
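
        Example (an illustrative sketch; assumes ``kernels`` came from
        :func:`_cuda_load_module` with a hypothetical ``add_one(float*, int)``
        kernel, and that a CUDA device is available)::

            >>> x = torch.zeros(32, device="cuda")
            >>> kernels["add_one"](
            ...     grid=(1, 1, 1), block=(32, 1, 1), args=[x, x.numel()]
            ... )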
""" import torch libcuda = torch.cuda._utils._get_cuda_library() if not args: args = [] # Process arguments and convert tensors to pointers processed_args: list[ctypes.c_void_p] = [] c_args = [] for arg in args: if isinstance(arg, torch.Tensor): if not arg.is_cuda and not (arg.is_cpu and arg.is_pinned()): raise ValueError( "All tensor arguments must be CUDA tensors or pinned CPU tensors" ) # Get pointer to tensor data ptr = ctypes.c_void_p(arg.data_ptr()) processed_args.append(ptr) c_args.append(ctypes.byref(ptr)) elif isinstance(arg, int): # Convert integers to C int c_int = ctypes.c_int(arg) # Store the C int for reference keeping, not in processed_args c_args.append(ctypes.byref(c_int)) # TODO: Python floats are actually doubles elif isinstance(arg, float): # Convert floats to C float c_float = ctypes.c_float(arg) # Store the C float for reference keeping, not in processed_args c_args.append(ctypes.byref(c_float)) else: raise TypeError(f"Unsupported argument type: {type(arg)}") # Convert to array of void pointers c_args_array = (ctypes.c_void_p * len(c_args))() for i, arg in enumerate(c_args): c_args_array[i] = ctypes.cast(arg, ctypes.c_void_p) # Get the stream if stream is None: # Defer import to avoid circular imports import torch.cuda stream = torch.cuda.current_stream() _check_cuda( libcuda.cuLaunchKernel( self.func, grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem, stream._as_parameter_, c_args_array, None, ) ) def _cuda_load_module( ptx: Union[str, bytes], kernel_names: Optional[list[str]] = None ) -> Union[_CudaModule, dict[str, "_CudaKernel"]]: """ Loads a CUDA module from PTX code and returns a module object that can access kernels. Args: ptx (bytes or str): The PTX code to load kernel_names (list, optional): List of kernel names to extract from the module. If None, will return a module object with __getattr__. Returns: object: If kernel_names is None, returns a module object with __getattr__ to access kernels. If kernel_names is provided, returns a dict mapping kernel names to _CudaKernel objects. """ # Ensure CUDA is initialized import torch.cuda # Load CUDA driver library libcuda = _get_cuda_library() # Convert PTX to bytes if it's a string if isinstance(ptx, str): ptx = ptx.encode("utf-8") # Load PTX module module = ctypes.c_void_p() # Get the current stream without directly importing torch.cuda at module level stream = torch.cuda.current_stream() with stream: _check_cuda(libcuda.cuModuleLoadData(ctypes.byref(module), ptx)) if not kernel_names: return _CudaModule(module) # Return specific kernels kernels = {} for name in kernel_names: func = ctypes.c_void_p() _check_cuda( libcuda.cuModuleGetFunction( ctypes.byref(func), module, name.encode("utf-8") ) ) kernels[name] = _CudaKernel(func, module) return kernels def _get_device_index( device: Any, optional: bool = False, allow_cpu: bool = False ) -> int: r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. If :attr:`device` is a torch.device object, returns the device index if it is a CUDA device. Note that for a CUDA device without a specified index, i.e., ``torch.device('cuda')``, this will return the current default CUDA device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, CPU devices will be accepted and ``-1`` will be returned in this case. If :attr:`device` is a Python integer, it is returned as is. If :attr:`device` is ``None``, this will return the current default CUDA device if :attr:`optional` is ``True``. 
""" if isinstance(device, int): return device if isinstance(device, str): device = torch.device(device) if isinstance(device, torch.device): if allow_cpu: if device.type not in ["cuda", "cpu"]: raise ValueError(f"Expected a cuda or cpu device, but got: {device}") elif device.type != "cuda": raise ValueError(f"Expected a cuda device, but got: {device}") if not torch.jit.is_scripting(): if isinstance(device, torch.cuda.device): return device.idx return _torch_get_device_index(device, optional, allow_cpu)