import ctypes import sys from typing import Any, Optional, Union import torch # The _get_device_index has been moved to torch.utils._get_device_index from torch._utils import _get_device_index as _torch_get_device_index def _get_hip_runtime_library() -> ctypes.CDLL: if sys.platform == "win32": lib = ctypes.CDLL(f"amdhip64_{torch.version.hip[0]}.dll") else: # Unix-based systems lib = ctypes.CDLL("libamdhip64.so") lib.cuGetErrorString = lib.hipGetErrorString # type: ignore[attr-defined] lib.cuModuleLoadData = lib.hipModuleLoadData # type: ignore[attr-defined] lib.cuModuleGetFunction = lib.hipModuleGetFunction # type: ignore[attr-defined] lib.cuLaunchKernel = lib.hipModuleLaunchKernel # type: ignore[attr-defined] lib.cuFuncSetAttribute = lib.hipFuncSetAttribute # type: ignore[attr-defined] return lib def _get_cuda_runtime_library() -> ctypes.CDLL: if sys.platform == "win32": return ctypes.CDLL("nvcuda.dll") else: # Unix-based systems return ctypes.CDLL("libcuda.so.1") # Load GPU driver runtime def _get_gpu_runtime_library() -> ctypes.CDLL: if torch.version.hip: return _get_hip_runtime_library() else: return _get_cuda_runtime_library() # Helper: check CUDA errors def _check_cuda(result: int) -> None: if result == 0: return err_str = ctypes.c_char_p() libcuda = _get_gpu_runtime_library() # Get reference to CUDA library libcuda.cuGetErrorString(result, ctypes.byref(err_str)) error_message = ( err_str.value.decode() if err_str.value is not None else "Unknown CUDA error" ) raise RuntimeError(f"CUDA error: {error_message}") def _get_hiprtc_library() -> ctypes.CDLL: if sys.platform == "win32": version_str = "".join(["0", torch.version.hip[0], "0", torch.version.hip[2]]) lib = ctypes.CDLL(f"hiprtc{version_str}.dll") else: lib = ctypes.CDLL("libhiprtc.so") # Provide aliases for HIP RTC functions to match NVRTC API lib.nvrtcGetErrorString = lib.hiprtcGetErrorString # type: ignore[attr-defined] lib.nvrtcCreateProgram = lib.hiprtcCreateProgram # type: ignore[attr-defined] lib.nvrtcDestroyProgram = lib.hiprtcDestroyProgram # type: ignore[attr-defined] lib.nvrtcCompileProgram = lib.hiprtcCompileProgram # type: ignore[attr-defined] lib.nvrtcGetPTXSize = lib.hiprtcGetCodeSize # type: ignore[attr-defined] lib.nvrtcGetPTX = lib.hiprtcGetCode # type: ignore[attr-defined] lib.nvrtcGetProgramLogSize = lib.hiprtcGetProgramLogSize # type: ignore[attr-defined] lib.nvrtcGetProgramLog = lib.hiprtcGetProgramLog # type: ignore[attr-defined] lib.nvrtcAddNameExpression = lib.hiprtcAddNameExpression # type: ignore[attr-defined] lib.nvrtcGetLoweredName = lib.hiprtcGetLoweredName # type: ignore[attr-defined] return lib def _get_nvrtc_library() -> ctypes.CDLL: major_version = int(torch.version.cuda.split(".")[0]) # type: ignore[union-attr] if sys.platform == "win32": nvrtc_libs = [ f"nvrtc64_{major_version}0_0.dll", ] else: nvrtc_libs = [ f"libnvrtc.so.{major_version}", "libnvrtc.so", # Fallback to unversioned ] for lib_name in nvrtc_libs: try: return ctypes.CDLL(lib_name) except OSError: continue raise OSError("Could not find any NVRTC library") def _get_gpu_rtc_library() -> ctypes.CDLL: # Since PyTorch already loads the GPU RTC library, we can use the system library # which should be compatible with PyTorch's version if torch.version.hip: return _get_hiprtc_library() else: return _get_nvrtc_library() def _get_gpu_rtc_compatible_flags() -> list[str]: """ Get HIPCC/NVCC flags that are compatible with NVRTC compilation. Returns: List of HIPCC/NVCC flags that can be safely used with NVRTC. """ from torch.utils.cpp_extension import COMMON_HIPCC_FLAGS, COMMON_NVCC_FLAGS nvrtc_unsupported_flags = { "--expt-relaxed-constexpr", } # Filter out unsupported flags compatible_flags = [ flag for flag in COMMON_NVCC_FLAGS if flag not in nvrtc_unsupported_flags ] if torch.version.hip: compatible_flags.extend(COMMON_HIPCC_FLAGS) return compatible_flags def _nvrtc_compile( kernel_source: str, kernel_name: str, compute_capability: Optional[str] = None, cuda_include_dirs: Optional[list] = None, nvcc_options: Optional[list] = None, auto_pch: bool = False, ) -> tuple[bytes, str]: """ Compiles a CUDA kernel using NVRTC and returns the PTX code. Args: kernel_source (str): The CUDA kernel source code as a string kernel_name (str): The name of the kernel function to compile compute_capability (str, None): The compute capability to target (e.g., "86"). If None, will detect from current device. cuda_include_dirs (list, None): List of directories containing CUDA headers nvcc_options (list, None): Additional options to pass to NVRTC auto_pch (bool): Enable automatic precompiled headers (CUDA 12.8+) Returns: Tuple[bytes, str]: The compiled PTX code and mangled kernel name """ # Ensure CUDA is initialized import torch.cuda # Load NVRTC library libnvrtc = _get_gpu_rtc_library() # NVRTC constants NVRTC_SUCCESS = 0 # Helper: check NVRTC errors def check_nvrtc(result: int) -> None: if result != NVRTC_SUCCESS: err_str = ctypes.c_char_p() libnvrtc.nvrtcGetErrorString(result, ctypes.byref(err_str)) error_message = ( err_str.value.decode() if err_str.value is not None else "Unknown CUDA error" ) raise RuntimeError(f"CUDA error: {error_message}") # Convert source to bytes source_bytes = kernel_source.encode("utf-8") # Get compute capability if not provided if compute_capability is None: props = torch.cuda.get_device_properties(torch.cuda.current_device()) if torch.version.hip: compute_capability = f"{props.gcnArchName}" else: compute_capability = f"{props.major}{props.minor}" # Prepare compilation options options = [] if torch.version.hip: options.append(f"--offload-arch={compute_capability}".encode()) else: options.append(f"--gpu-architecture=sm_{compute_capability}".encode()) # Auto-detect and add CUDA include paths from torch.utils.cpp_extension import include_paths cuda_include_paths = include_paths("cuda") for cuda_path in cuda_include_paths: options.append(f"-I{cuda_path}".encode()) # Add custom include directories if cuda_include_dirs: for directory in cuda_include_dirs: options.append(f"-I{directory}".encode()) # Enable automatic precompiled headers (CUDA 12.8+) if auto_pch: assert str(torch.version.cuda) >= "12.8", "PCH requires CUDA 12.8+" if nvcc_options is None: nvcc_options = [] nvcc_options.append("--pch") # Add custom NVCC options if nvcc_options: for option in nvcc_options: options.append(option.encode("utf-8")) nvrtc_compatible_flags = _get_gpu_rtc_compatible_flags() options.extend([flag.encode("utf-8") for flag in nvrtc_compatible_flags]) # Convert options to C array num_options = len(options) options_array = (ctypes.c_char_p * num_options)(*options) # Create program prog = ctypes.c_void_p() check_nvrtc( libnvrtc.nvrtcCreateProgram( ctypes.byref(prog), source_bytes, f"{kernel_name}.cu".encode(), 0, None, None, ) ) # Add kernel name, which can be a template expression c_kernel_name = kernel_name.encode("utf-8") check_nvrtc(libnvrtc.nvrtcAddNameExpression(prog, c_kernel_name)) # Compile program res = libnvrtc.nvrtcCompileProgram(prog, num_options, options_array) # Handle compilation errors if res != NVRTC_SUCCESS: # Get log log_size = ctypes.c_size_t() libnvrtc.nvrtcGetProgramLogSize(prog, ctypes.byref(log_size)) log = ctypes.create_string_buffer(log_size.value) libnvrtc.nvrtcGetProgramLog(prog, log) raise RuntimeError(f"Kernel compilation failed:\n{log.value.decode()}") # Get PTX ptx_size = ctypes.c_size_t() check_nvrtc(libnvrtc.nvrtcGetPTXSize(prog, ctypes.byref(ptx_size))) ptx = ctypes.create_string_buffer(ptx_size.value) check_nvrtc(libnvrtc.nvrtcGetPTX(prog, ptx)) # Get mangled name c_mangled_name = ctypes.c_char_p() check_nvrtc( libnvrtc.nvrtcGetLoweredName(prog, c_kernel_name, ctypes.byref(c_mangled_name)) ) if c_mangled_name.value is not None: mangled_name = c_mangled_name.value.decode() # make a copy else: mangled_name = "" libnvrtc.nvrtcDestroyProgram(ctypes.byref(prog)) # For HIP, hipRTC generates raw CO binaries instead of PTX, # and for some reason, ".value" causes the string to be truncated, # likely due to the presence of '\0' in the string. So we use .raw instead. ptx_bytes = ptx.raw if torch.version.hip else ptx.value return ptx_bytes, mangled_name class _CudaModule: def __init__(self, module: ctypes.c_void_p) -> None: self._module = module self._kernels: dict[str, _CudaKernel] = {} def __getattr__(self, name: str) -> "_CudaKernel": if name in self._kernels: return self._kernels[name] # Import the CUDA library inside the method from torch.cuda._utils import _get_gpu_runtime_library libcuda = _get_gpu_runtime_library() func = ctypes.c_void_p() try: _check_cuda( libcuda.cuModuleGetFunction( ctypes.byref(func), self._module, name.encode("utf-8") ) ) kernel = _CudaKernel(func, self._module) self._kernels[name] = kernel return kernel except RuntimeError as err: raise AttributeError(f"No kernel named '{name}' in this module") from err class _CudaKernel: """ Represents a compiled CUDA kernel that can be called with PyTorch tensors. """ def __init__(self, func: ctypes.c_void_p, module: ctypes.c_void_p) -> None: self.func = func self.module = module self._max_shared_mem_bytes = 0 def __call__( self, grid: tuple[int, int, int] = (1, 1, 1), block: tuple[int, int, int] = (1, 1, 1), args: Optional[list] = None, shared_mem: int = 0, stream: Optional[Any] = None, ) -> None: """ Call the compiled CUDA kernel Args: grid (tuple): Grid dimensions (grid_x, grid_y, grid_z) block (tuple): Block dimensions (block_x, block_y, block_z) args (list): List of arguments to pass to the kernel. PyTorch tensor arguments will be automatically converted to pointers. shared_mem (int): Shared memory size in bytes stream (torch.cuda.Stream): CUDA stream to use. If None, uses current stream. """ import torch libcuda = torch.cuda._utils._get_gpu_runtime_library() if not args: args = [] # Process arguments and convert tensors to pointers processed_args: list[ctypes.c_void_p] = [] c_args = [] for arg in args: if isinstance(arg, torch.Tensor): if not arg.is_cuda and not (arg.is_cpu and arg.is_pinned()): raise ValueError( "All tensor arguments must be CUDA tensors or pinned CPU tensors" ) # Get pointer to tensor data ptr = ctypes.c_void_p(arg.data_ptr()) processed_args.append(ptr) c_args.append(ctypes.byref(ptr)) elif isinstance(arg, int): # Convert integers to C int c_int = ctypes.c_int(arg) # Store the C int for reference keeping, not in processed_args c_args.append(ctypes.byref(c_int)) elif isinstance(arg, float): # Python floats are doubles - use double by default c_double = ctypes.c_double(arg) # Store the C double for reference keeping, not in processed_args c_args.append(ctypes.byref(c_double)) else: raise TypeError(f"Unsupported argument type: {type(arg)}") # Convert to array of void pointers c_args_array = (ctypes.c_void_p * len(c_args))() for i, arg in enumerate(c_args): c_args_array[i] = ctypes.cast(arg, ctypes.c_void_p) # Get the stream if stream is None: # Defer import to avoid circular imports import torch.cuda stream = torch.cuda.current_stream() # Check if kernel requires large shared memory but hasn't been configured if shared_mem >= 48 * 1024 and ( self._max_shared_mem_bytes == 0 or shared_mem > self._max_shared_mem_bytes ): configured_msg = ( "not configured" if self._max_shared_mem_bytes == 0 else f"only {self._max_shared_mem_bytes} bytes configured" ) raise RuntimeError( f"Kernel requires {shared_mem} bytes of shared memory (>= 48KB), " f"but {configured_msg}. " "Call kernel.set_shared_memory_config(shared_mem) after compilation " "and before launching the kernel." ) _check_cuda( libcuda.cuLaunchKernel( self.func, grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem, stream._as_parameter_, c_args_array, None, ) ) def set_shared_memory_config(self, shared_mem_bytes: int) -> None: if shared_mem_bytes < 48 * 1024: # No configuration needed for <= 48KB, just update the value self._max_shared_mem_bytes = shared_mem_bytes return libcuda = _get_gpu_runtime_library() # Get device properties to validate against limits device_props = torch.cuda.get_device_properties() # HIP doesn't have shared_memory_per_block_optin in device properties, so we hard-code it here if torch.version.hip: # navi, CDNA1-CDNA3 allows a max of 64KB shared memory # CDNA4 allows a max of 160KB shared memory max_shared_mem = ( 65536 if device_props.gcnArchName not in ["gfx950"] else 160 * 1024 ) else: max_shared_mem = getattr( device_props, "shared_memory_per_block_optin", 49152 ) if shared_mem_bytes > max_shared_mem: raise RuntimeError( f"Requested shared memory ({shared_mem_bytes} bytes) exceeds " f"device limit ({max_shared_mem} bytes). " "Consider reducing block size or shared memory usage." ) # Set the function attribute once # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html cudaFuncAttributeMaxDynamicSharedMemorySize = 8 _check_cuda( libcuda.cuFuncSetAttribute( self.func, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_bytes, ) ) self._max_shared_mem_bytes = shared_mem_bytes def _cuda_load_module( ptx: Union[str, bytes], kernel_names: Optional[list[str]] = None ) -> Union[_CudaModule, dict[str, "_CudaKernel"]]: """ Loads a CUDA module from PTX code and returns a module object that can access kernels. Args: ptx (bytes or str): The PTX code to load kernel_names (list, optional): List of kernel names to extract from the module. If None, will return a module object with __getattr__. Returns: object: If kernel_names is None, returns a module object with __getattr__ to access kernels. If kernel_names is provided, returns a dict mapping kernel names to _CudaKernel objects. """ # Ensure CUDA is initialized import torch.cuda # Load CUDA driver library libcuda = _get_gpu_runtime_library() # Convert PTX to bytes if it's a string if isinstance(ptx, str): ptx = ptx.encode("utf-8") # Load PTX module module = ctypes.c_void_p() # Get the current stream without directly importing torch.cuda at module level stream = torch.cuda.current_stream() with stream: _check_cuda(libcuda.cuModuleLoadData(ctypes.byref(module), ptx)) if not kernel_names: return _CudaModule(module) # Return specific kernels kernels = {} for name in kernel_names: func = ctypes.c_void_p() _check_cuda( libcuda.cuModuleGetFunction( ctypes.byref(func), module, name.encode("utf-8") ) ) kernels[name] = _CudaKernel(func, module) return kernels def _get_device_index( device: Any, optional: bool = False, allow_cpu: bool = False ) -> int: r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. If :attr:`device` is a torch.device object, returns the device index if it is a CUDA device. Note that for a CUDA device without a specified index, i.e., ``torch.device('cuda')``, this will return the current default CUDA device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, CPU devices will be accepted and ``-1`` will be returned in this case. If :attr:`device` is a Python integer, it is returned as is. If :attr:`device` is ``None``, this will return the current default CUDA device if :attr:`optional` is ``True``. """ if isinstance(device, int): return device if isinstance(device, str): device = torch.device(device) if isinstance(device, torch.device): if allow_cpu: if device.type not in ["cuda", "cpu"]: raise ValueError(f"Expected a cuda or cpu device, but got: {device}") elif device.type != "cuda": raise ValueError(f"Expected a cuda device, but got: {device}") if not torch.jit.is_scripting(): if isinstance(device, torch.cuda.device): return device.idx return _torch_get_device_index(device, optional, allow_cpu)