[Cutlass] Allow offsets to be passed as arguments to kernel (#159761)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159761
Approved by: https://github.com/henrylhtsang
ghstack dependencies: #159760
Michael Lazos 2025-08-05 11:57:58 -07:00 committed by PyTorch MergeBot
parent 8085edc8f9
commit bdb07a2bc5
6 changed files with 56 additions and 27 deletions
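Before this change, each buffer's storage offset was folded into the generated C++ as a literal by the CUDATemplateKernel.offset helper (removed below); now every named buffer gets a matching const int <name>_offset kernel argument and the wrapper computes <name> + <name>_offset itself, so the generated source no longer depends on the concrete offset. A minimal sketch of that difference (simplified stand-ins, not the actual Inductor code):

    def render_ptr_old(arg_name: str, offset: int) -> str:
        # pre-PR: the offset is baked into the generated source at codegen time
        return arg_name if offset == 0 else f"{arg_name} + {offset}"

    def render_ptr_new(arg_name: str) -> str:
        # post-PR: the offset arrives as a `const int {arg_name}_offset` kernel argument
        return f"{arg_name} + {arg_name}_offset"

    print(render_ptr_old("X", 128))  # X + 128      -> source changes with the offset value
    print(render_ptr_new("X"))       # X + X_offset -> source is offset-independent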

View File

@@ -1793,6 +1793,26 @@ class TestCutlassBackend(TestCase):
torch.testing.assert_close(A @ A.t(), compiled(A, A.t()))
+@unittest.skipIf(not SM90OrLater, "need sm_90")
+@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+def test_cutlass_backend_matmul_nonzero_offset(self):
+max_autotune_gemm_backends = "CUTLASS"
+M = 129
+A = torch.randn(M, M - 1).cuda().half()
+with config.patch(
+{
+"max_autotune": True,
+"max_autotune_gemm_backends": max_autotune_gemm_backends,
+"cuda.cutlass_max_profiling_configs": 2,
+}
+):
+compiled = torch.compile(torch.mm)
+torch.testing.assert_close(
+A[1:, :] @ A[1:, :].t(), compiled(A[1:, :], A[1:, :].t())
+)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_flexible_layout(self):
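The new test (test_cutlass_backend_matmul_nonzero_offset, added above) drives torch.compile(torch.mm) through the CUTLASS backend with sliced operands. The slice is what exercises the new code path: A is 129 x 128 fp16, so A[1:, :] shares A's storage but begins 128 elements (256 bytes) into it. A quick standalone check of that arithmetic (CPU tensors are enough to see the offsets):

    import torch

    A = torch.randn(129, 128).half()
    B = A[1:, :]
    print(B.storage_offset())            # 128 elements into A's storage
    print(B.data_ptr() - A.data_ptr())   # 256 bytes (128 fp16 elements * 2 bytes)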

View File

@@ -392,12 +392,12 @@ return tmp_1, D""",
{}, /* C */
{}, /* compute_0 */
},
-{/* ptr_aux */ (float*) ptr_0, /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */
+{/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */
{}, /* compute_1 */
},
-{/* ptr_aux */ (float*) ptr_1, /* dAux */ {2048, _1{}, _0{}}}, /* F */
+{/* ptr_aux */ (float*) (ptr_1 + ptr_1_offset), /* dAux */ {2048, _1{}, _0{}}}, /* F */
},
-{/* ptr_col */ (float*) ptr_2, /* null_default */ float(0), /* dCol */ {}}, /* bias */
+{/* ptr_col */ (float*) (ptr_2 + ptr_2_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */
{}, /* compute_2 */
{}, /* compute_3 */
{}, /* compute_4 */
@@ -444,9 +444,9 @@ def fn(accum, bias):
{ /* thread */
{ /* E */
{}, /* accum */
-{/* ptr_aux */ (float*) ptr_0, /* dAux */ {2048, _1{}, _0{}}}, /* E */
+{/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* dAux */ {2048, _1{}, _0{}}}, /* E */
},
-{/* ptr_col */ (float*) ptr_1, /* null_default */ float(0), /* dCol */ {}}, /* bias */
+{/* ptr_col */ (float*) (ptr_1 + ptr_1_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */
{}, /* compute_0 */
}
""",

View File

@@ -177,6 +177,9 @@ class CUDAKernel(Kernel):
def get_dynamic_shape_args(self) -> list[Union[Expr, int]]:
return [*self.get_layout_args(), *self.size_args]
+def get_offset_args(self) -> list[Expr]:
+return [node.get_layout().offset for node in self.named_nodes.values()]
@staticmethod
def find_ld_idx(node: IRNode) -> int:
strides = node.get_stride()
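get_offset_args (added above) is typed as list[Expr] because a layout offset may be a plain integer or, under dynamic shapes, a symbolic expression; further down, CUDATemplate runs the result through V.graph.sizevars.size_hints to obtain concrete ints for the call. A toy illustration of that distinction (made-up symbol and hint value, not Inductor's sizevars machinery):

    import sympy

    s0 = sympy.Symbol("s0")
    offsets = [128, 2048 * s0]                                   # a static offset and a symbolic one
    hints = [o if isinstance(o, int) else int(o.subs({s0: 3})) for o in offsets]
    print(hints)                                                 # [128, 6144]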
@@ -264,6 +267,7 @@ class CUDATemplateKernel(CUDAKernel):
In this case, the `input_reorder` would be [2, 0, 1].
additional_size_args: Additional size arguments for epilogue inputs
"""
+# NB: name order matters here, it's used to match up offsets
names = [x.strip() for x in names_str.strip().split(",")]
if len(inputs) + len(outputs) != len(names):
raise RuntimeError(
@@ -285,6 +289,7 @@ class CUDATemplateKernel(CUDAKernel):
free_symbols: OrderedSet[Expr] = OrderedSet()
for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
if node is not None:
+# NB: named nodes must be populated in the order of names
self.named_nodes[name] = node
self.args.output_buffers[node.get_name()] = name
@@ -306,14 +311,17 @@ class CUDATemplateKernel(CUDAKernel):
size_vars.extend(str(s) for s in free_symbols)
self.size_args.extend(free_symbols)
size_args = [f"const int {s}" for s in size_vars]
offset_args = [f"const int {name}_offset" for name in self.named_nodes.keys()]
runtime_arg_decls = ",".join(
[f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info]
)
if runtime_arg_decls:
runtime_arg_decls += ", "
signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {runtime_arg_decls}{self._EXTRA_CPP_ARGS})"
signature = (
f"int {self.kernel_name}({', '.join(arg_defs + size_args + offset_args)},\
{runtime_arg_decls}{self._EXTRA_CPP_ARGS})"
)
self.signature = signature
return signature
@@ -346,10 +354,13 @@ class CUDATemplateKernel(CUDAKernel):
_, call_args, _, arg_types = self.args.python_argdefs()
dynamic_shape_args = self.get_dynamic_shape_args()
+offset_args = self.get_offset_args()
call_args.extend(dynamic_shape_args) # type: ignore[arg-type]
+call_args.extend(offset_args) # type: ignore[arg-type]
for arg in self.runtime_arg_values:
-call_args.append(arg)
-arg_types.extend("int" for _ in dynamic_shape_args)
+call_args.append(str(arg))
+arg_types.extend("const int" for _ in dynamic_shape_args)
+arg_types.extend("const int" for _ in offset_args)
for arg in self.runtime_arg_info:
arg_types.append(arg.ty)
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
@@ -425,15 +436,6 @@ class CUDATemplateKernel(CUDAKernel):
max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i]
return max_valid_offset
-def offset(self, node: IRNode) -> str:
-"""
-Generates code which represents offset of a given node.
-"""
-if node is None:
-return "0"
-return str(node.get_layout().offset) # type: ignore[union-attr]
def ptr(self, node: IRNode) -> str:
"""
Generates code which represents pointer of a given node.
@@ -444,8 +446,7 @@ class CUDATemplateKernel(CUDAKernel):
arg_name = self.arg_name(node)
if arg_name is None:
return "nullptr"
-offset = self.offset(node)
-return arg_name if offset == "0" else f"{arg_name} + {offset}"
+return f"{arg_name} + {arg_name}_offset"
def size(
self,
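The two "name order matters" comments above exist because offset declarations and offset values are matched purely by position: the signature construction builds const int {name}_offset parameters from named_nodes in insertion order, and get_offset_args reads the values back from the same dict in the same order. A self-contained sketch of that invariant (stand-in classes and names, not Inductor IR nodes):

    from dataclasses import dataclass

    @dataclass
    class FakeLayout:
        offset: int

    @dataclass
    class FakeNode:
        layout: FakeLayout

        def get_layout(self) -> FakeLayout:
            return self.layout

    # insertion order of named_nodes must match the argument order in the signature
    named_nodes = {
        "X": FakeNode(FakeLayout(128)),
        "W": FakeNode(FakeLayout(0)),
        "Y": FakeNode(FakeLayout(0)),
    }
    offset_params = [f"const int {name}_offset" for name in named_nodes]
    offset_values = [node.get_layout().offset for node in named_nodes.values()]
    print(offset_params)   # ['const int X_offset', 'const int W_offset', 'const int Y_offset']
    print(offset_values)   # [128, 0, 0]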

View File

@@ -43,7 +43,7 @@ class ArgInfo:
class CUDATemplate(KernelTemplate):
index_counter = itertools.count()
# dict of cache key to (code, size_args)
-code_cache: dict[str, tuple[str, tuple[int, ...]]] = {}
+code_cache: dict[str, tuple[str, tuple[int, ...], tuple[int, ...]]] = {}
cache_clear = staticmethod(code_cache.clear)
def __init__(
@@ -113,8 +113,12 @@ class CUDATemplate(KernelTemplate):
key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr)
if key is not None and key in self.code_cache:
-code, size_args = self.code_cache[key]
-extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs))
+code, size_args, offset_args = self.code_cache[key]
+extra_args = tuple(
+list(size_args)
++ list(offset_args)
++ list(self.get_runtime_arg_values(**kwargs))
+)
return code, extra_args
kernel_name = str(Placeholder.KERNEL_NAME)
@@ -148,12 +152,15 @@ class CUDATemplate(KernelTemplate):
)
V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :]))
size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args())
+offset_args = V.graph.sizevars.size_hints(kernel.get_offset_args())
if key is not None:
-self.code_cache[key] = code, size_args
+self.code_cache[key] = code, size_args, offset_args
# extra args has runtime params, which shouldn't be cached
-extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs))
+extra_args = tuple(
+list(size_args) + list(offset_args) + self.get_runtime_arg_values(**kwargs)
+)
return code, extra_args
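The cache entry grows from (code, size_args) to (code, size_args, offset_args): the rendered source is stored together with its size and offset hints, while runtime parameters (such as the swizzle value visible in the GEMM call at the end of this diff) are re-collected on every call, which is what the "shouldn't be cached" comment refers to. A minimal, runnable sketch of that pattern (made-up key and values, not the real rendering path):

    code_cache: dict[str, tuple[str, tuple[int, ...], tuple[int, ...]]] = {}

    def generate(key: str, runtime_args: list[int]) -> tuple[str, tuple[int, ...]]:
        if key in code_cache:
            code, size_args, offset_args = code_cache[key]       # cache hit: skip re-rendering
        else:
            code, size_args, offset_args = "<rendered source>", (128, 128, 64), (128, 0, 0)
            code_cache[key] = (code, size_args, offset_args)
        # runtime params are appended last and deliberately never cached
        return code, tuple(list(size_args) + list(offset_args) + list(runtime_args))

    print(generate("gemm_128x128x64", [2]))
    print(generate("gemm_128x128x64", [4]))   # same cached code, fresh runtime args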

View File

@@ -255,7 +255,8 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
return f"{{{', '.join([render_stride(x) for x in stride])}}}"
elif issubclass(arg_ty, ctypes.c_void_p):
return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) {arg_renames.new_name(node.get_name())}"
name = arg_renames.new_name(node.get_name())
return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) ({name} + {name}_offset)"
elif (
arg_ty in _CUTLASS_C_DTYPES
): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently
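The renderer above now always emits the pointer plus its offset before the cast. At the PyTorch level a layout offset is an element count, not a byte count, so the address of a view is the base address advanced by offset * element_size; a quick standalone check of that relationship (CPU float32 tensor, mirroring the float* casts in the expected EVT output earlier in this diff):

    import torch

    base = torch.randn(4, 2048)                  # float32
    view = base[2:, :]                           # storage offset of 2 * 2048 elements
    assert view.storage_offset() == 4096
    assert base.data_ptr() + view.storage_offset() * base.element_size() == view.data_ptr()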

View File

@@ -1317,7 +1317,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
f"(({arg_type}){arg_name}_data.get())"
for arg_type, arg_name in zip(arg_types, arg_names)
]
return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950
return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, 0, 0, 0, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950
def _render_evt(
self,
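In the standalone-runner call above, the three 0s inserted between ldd and swizzle line up with the new offset parameters; since that path allocates fresh buffers via *_data.get(), their offsets are presumably always zero. An illustrative decomposition of the argument list into the groups established earlier in this commit (group labels are editorial, not from the source):

    call_groups = {
        "buffer pointers": "rendered from `arguments` above",
        "size args":       "M, N, K, B, lda, ldb, ldc, ldd",
        "offset args":     "0, 0, 0",
        "runtime args":    "swizzle",
        "extra C++ args":  "workspace_size_ptr, (uint8_t*)workspace_data.get(), 0",
    }
    for group, args in call_groups.items():
        print(f"{group:15} -> {args}")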