[StaticCudaLauncher] Support sharedMemBytes > 48KB (#149657)

Triton does some special handling when a kernel requests more than 48 KB of shared memory: specifically, it queries the device for the maximum shared memory a block can opt in to, then sets the kernel's maximum dynamic shared memory to the difference between that device limit and the kernel's static shared memory.

See the corresponding implementation in Triton here:
https://github.com/triton-lang/triton/blob/main/third_party/nvidia/backend/driver.c#L128-L143
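
For context, the opt-in boils down to a handful of CUDA driver-API calls. Below is a minimal, standalone sketch of that sequence (the helper name optInToLargeSharedMem and the CHECK macro are illustrative, not part of this PR, and it assumes func was obtained via cuModuleGetFunction on an initialized CUDA context); the actual implementation lands in loadKernel in the diff below.

#include <cuda.h>
#include <cstdio>
#include <cstdlib>

// Sketch-only error handling; real code should surface the CUresult to the caller.
#define CHECK(call)                                                              \
  do {                                                                           \
    CUresult err_ = (call);                                                      \
    if (err_ != CUDA_SUCCESS) {                                                  \
      std::fprintf(stderr, "CUDA driver error %d\n", static_cast<int>(err_));    \
      std::abort();                                                              \
    }                                                                            \
  } while (0)

constexpr int kStaticSharedMemMax = 49152; // 48 KB per-block static limit

void optInToLargeSharedMem(CUfunction func, CUdevice device, int sharedMemBytes) {
  int shared_optin = 0;
  // Largest shared memory (in bytes) a block may opt in to on this device.
  CHECK(cuDeviceGetAttribute(
      &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
  if (sharedMemBytes > kStaticSharedMemMax && shared_optin > kStaticSharedMemMax) {
    // Prefer shared memory over L1 cache for this kernel.
    CHECK(cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED));
    // Shared memory the compiled kernel already reserves statically.
    int shared_static = 0;
    CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
    // Allow dynamic shared memory up to whatever the opt-in limit leaves after
    // the static allocation, so the total never exceeds the device maximum.
    CHECK(cuFuncSetAttribute(
        func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
        shared_optin - shared_static));
  }
}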

Test plan:
- New unit test requesting more than 48 KB of shared memory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149657
Approved by: https://github.com/jansel
James Wu (2025-03-27 05:09:28 -07:00), committed by PyTorch MergeBot
parent 85e4e51a7d
commit 14f0cd7630
7 changed files with 101 additions and 20 deletions


@@ -158,6 +158,8 @@ NVRTC_STUB3(nvrtcGetLoweredName, nvrtcProgram, const char *, const char **)
CUDA_STUB2(cuModuleLoad, CUmodule*, const char*)
CUDA_STUB2(cuModuleLoadData, CUmodule *, const void *)
CUDA_STUB2(cuFuncSetCacheConfig, CUfunction, CUfunc_cache_enum)
CUDA_STUB3(cuDeviceGetAttribute, int*, CUdevice_attribute_enum, CUdevice)
CUDA_STUB3(cuModuleGetFunction, CUfunction *, CUmodule, const char *)
CUDA_STUB4(cuOccupancyMaxActiveBlocksPerMultiprocessor, int *, CUfunction, int, size_t)
CUDA_STUB2(cuGetErrorString, CUresult, const char **)


@@ -62,6 +62,9 @@ namespace at::cuda {
_(cuFuncSetAttribute) \
_(cuFuncGetAttribute) \
_(cuPointerGetAttribute) \
_(cuFuncSetCacheConfig) \
_(cuDeviceGetAttribute) \
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
#define AT_FORALL_NVRTC_EXTENDED(_) \


@@ -54,7 +54,8 @@ class TestStaticCudaLauncher(TestCase):
cubin_file = self.write_cubin_to_tmp(compiled_kernel)
compiled_kernel._cubin_path = cubin_file
result = StaticallyLaunchedCudaKernel(compiled_kernel)
result.load_kernel()
device_interface = get_interface_for_device("cuda")
result.load_kernel(device_interface.current_device())
return result
@skipIfRocm
@@ -279,6 +280,50 @@ class TestStaticCudaLauncher(TestCase):
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.run(1, 1, 1, stream)
@skipIfRocm
def test_high_shared_mem(self):
@triton.jit
def simple_kernel(arg0, arg1):
x = tl.load(arg0)
y = arg1
tl.store(arg0, x + y)
arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
arg1 = 5
args = (arg0, arg1)
compiled_kernel = simple_kernel[(1,)](*args)
# Request 50,000 bytes of shared memory, just over the 48 KB (49152-byte) static limit
compiled_kernel.shared = 50000
launcher = self._make_launcher(compiled_kernel)
self.assertEqual(arg0, torch.tensor([5], dtype=torch.int32, device="cuda"))
self.assertEqual(launcher.arg_tys, "Oi")
new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
device_interface = get_interface_for_device("cuda")
stream = device_interface.get_raw_stream(device_interface.current_device())
launcher.slow_launch_kernel = True
launcher.run(1, 1, 1, stream, new_arg0, arg1)
self.assertEqual(new_arg0, arg0)
@skipIfRocm
def test_too_high_shared_mem(self):
@triton.jit
def simple_kernel(arg0, arg1):
x = tl.load(arg0)
y = arg1
tl.store(arg0, x + y)
arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
arg1 = 5
args = (arg0, arg1)
compiled_kernel = simple_kernel[(1,)](*args)
# Allocate too much shared memory
compiled_kernel.shared = 99999999
self.assertRaisesRegex(
RuntimeError,
"out of resource: simple_kernel",
lambda: self._make_launcher(compiled_kernel),
)
@skipIfRocm
def test_kernel_empty_tensor(self):
# Triton kernel generated by torch.compile of the following:


@@ -2556,6 +2556,7 @@ class _StaticCudaLauncher:
cubin_file: str,
func_name: str,
shared_mem_bytes: _int,
device: _int,
) -> Tuple[_int, _int, _int]:
...


@@ -5,10 +5,6 @@ from typing_extensions import Unpack
from .triton_compat import ASTSource, CompiledKernel
MAX_SHARED_MEMORY = 49152
MAX_ARGS = 120
class StaticallyLaunchedCudaKernel:
"""
Parses the metadata of a CompiledKernel from Triton into a structure that can
@@ -59,13 +55,6 @@ class StaticallyLaunchedCudaKernel:
self.shared = (
kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared
)
# When shared memory > 48 KB, triton allocates CUDA memory via both static and dynamic
# memory allocation, which gets really complicated. We'll handle it later.
# See triton/third-party/nvidia/driver.c in loadBinary
if self.shared > MAX_SHARED_MEMORY:
raise NotImplementedError(
"Shared memory size > 48KB requires special triton handling"
)
# Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel.
# Inductor never uses this field or enables it, but we still have to pass
@@ -93,14 +82,14 @@ class StaticallyLaunchedCudaKernel:
"Static cuda launcher only supports num_ctas == 1"
)
def load_kernel(self) -> None:
def load_kernel(self, device: int) -> None:
from torch._C import _StaticCudaLauncher
assert hasattr(self, "cubin_path")
if self.function is not None:
return
(self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel(
self.cubin_path, self.name, self.shared
self.cubin_path, self.name, self.shared, device
)
@staticmethod


@@ -1200,7 +1200,7 @@ class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]):
def make_launcher(self) -> LauncherType:
# Load the binary on the parent
self.kernel.load_kernel()
self.kernel.load_kernel(self.compile_meta.get("device", 0))
scope = {
"runner": self.kernel.run,
}


@@ -83,25 +83,59 @@ CUdeviceptr getPointer(PyObject* obj) {
return dev_ptr;
}
#define SHARED_MEM_STATIC_MAX 49152 // 48 KB
CUfunction loadKernel(
std::string filePath,
const std::string& funcName,
uint32_t sharedMemBytes,
CUdevice device,
const std::optional<std::string>& cubinDir = std::nullopt) {
if (cubinDir) {
std::filesystem::path p1{*cubinDir};
std::filesystem::path p2{filePath};
filePath = (p1 / p2.filename()).string();
}
CUmodule mod = nullptr;
CUfunction func = nullptr;
AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoad(&mod, filePath.c_str()));
AT_CUDA_DRIVER_CHECK(
nvrtc().cuModuleGetFunction(&func, mod, funcName.c_str()));
if (sharedMemBytes > 0) {
int shared_optin = 0;
AT_CUDA_DRIVER_CHECK(nvrtc().cuDeviceGetAttribute(
&shared_optin,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device));
// Shared memory logic from triton/third_party/nvidia/backend/driver.c:
// if the kernel uses more than 48 KB of shared memory and the device
// allows opting in to more than 48 KB per block, set the kernel's maximum
// dynamic shared memory to the device's opt-in limit minus the kernel's
// static shared memory, so the total never exceeds the device maximum.
TORCH_CHECK(
sharedMemBytes < static_cast<uint32_t>(shared_optin),
"out of resource: ",
funcName,
" Required: ",
sharedMemBytes,
" Hardware limit:",
shared_optin,
" Reducing block sizes or `num_stages` may help.");
if (sharedMemBytes > SHARED_MEM_STATIC_MAX &&
shared_optin > SHARED_MEM_STATIC_MAX) {
AT_CUDA_DRIVER_CHECK(
nvrtc().cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total = 0, shared_static = 0;
AT_CUDA_DRIVER_CHECK(nvrtc().cuDeviceGetAttribute(
&shared_total,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
device));
AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncGetAttribute(
&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func));
AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncSetAttribute(
func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, sharedMemBytes));
func,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
}
return func;
}
@@ -220,16 +254,20 @@ void parseKernelArgs(
sharedMemBytes)
*/
PyObject* load_kernel(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
const char* filePath = nullptr;
const char* funcName = nullptr;
int sharedMemBytes = 0;
int n_regs = 0;
int n_spills = 0;
if (!PyArg_ParseTuple(args, "ssi", &filePath, &funcName, &sharedMemBytes)) {
int device_ptr = 0;
if (!PyArg_ParseTuple(
args, "ssii", &filePath, &funcName, &sharedMemBytes, &device_ptr)) {
return nullptr;
}
CUdevice device = static_cast<CUdevice>(device_ptr); // NOLINT
CUfunction func = nullptr;
func = loadKernel(filePath, funcName, sharedMemBytes);
func = loadKernel(filePath, funcName, sharedMemBytes, device);
// Taken from triton/nvidia/backend/driver.c
AT_CUDA_DRIVER_CHECK(
nvrtc().cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, func));
@@ -239,6 +277,7 @@ PyObject* load_kernel(PyObject* self, PyObject* args) {
// Return a tuple of CUFunction, n_regs, n_spills
return Py_BuildValue(
"(Kii)", reinterpret_cast<uint64_t>(func), n_regs, n_spills);
END_HANDLE_TH_ERRORS
}
PyObject* launch_kernel_inner(
@@ -317,6 +356,7 @@ PyObject* launch_kernel_slow(
*
*/
PyObject* launch_kernel(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
// Pointer to CUfunction generated by load_kernel()
uint64_t func_ptr = 0;
int gridX = 0, gridY = 0, gridZ = 0, numWarps = 0, sharedMemBytes = 0;
@@ -382,6 +422,7 @@ PyObject* launch_kernel(PyObject* self, PyObject* args) {
varArgs,
cudaStream);
}
END_HANDLE_TH_ERRORS
}
std::array<PyMethodDef, 2> StaticCudaLauncherMethods = {