[Cutlass] Allow offsets to be passed as arguments to kernel (#159761)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159761
Approved by: https://github.com/henrylhtsang
ghstack dependencies: #159760
This commit is contained in:
parent 8085edc8f9
commit bdb07a2bc5
@@ -1793,6 +1793,26 @@ class TestCutlassBackend(TestCase):
             torch.testing.assert_close(A @ A.t(), compiled(A, A.t()))
 
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    def test_cutlass_backend_matmul_nonzero_offset(self):
+        max_autotune_gemm_backends = "CUTLASS"
+
+        M = 129
+        A = torch.randn(M, M - 1).cuda().half()
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "max_autotune_gemm_backends": max_autotune_gemm_backends,
+                "cuda.cutlass_max_profiling_configs": 2,
+            }
+        ):
+            compiled = torch.compile(torch.mm)
+            torch.testing.assert_close(
+                A[1:, :] @ A[1:, :].t(), compiled(A[1:, :], A[1:, :].t())
+            )
+
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_flexible_layout(self):
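Note on what the new test exercises: slicing off the first row gives both matmul operands a view with a nonzero storage offset, so the generated CUTLASS kernel has to receive the base pointer plus that offset at run time. A minimal sketch in plain PyTorch (not part of the diff; CPU float is used here purely for illustration):

import torch

M = 129
A = torch.randn(M, M - 1)   # stand-in for the test's CUDA half tensor
B = A[1:, :]                # view into A's storage, skipping the first row

# The view shares A's storage but starts M - 1 = 128 elements into it,
# and the transpose B.t() keeps the same storage offset.
assert B.storage_offset() == M - 1
assert B.data_ptr() == A.data_ptr() + B.storage_offset() * A.element_size()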
@@ -392,12 +392,12 @@ return tmp_1, D""",
 {}, /* C */
 {}, /* compute_0 */
 },
-{/* ptr_aux */ (float*) ptr_0, /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */
+{/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */
 {}, /* compute_1 */
 },
-{/* ptr_aux */ (float*) ptr_1, /* dAux */ {2048, _1{}, _0{}}}, /* F */
+{/* ptr_aux */ (float*) (ptr_1 + ptr_1_offset), /* dAux */ {2048, _1{}, _0{}}}, /* F */
 },
-{/* ptr_col */ (float*) ptr_2, /* null_default */ float(0), /* dCol */ {}}, /* bias */
+{/* ptr_col */ (float*) (ptr_2 + ptr_2_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */
 {}, /* compute_2 */
 {}, /* compute_3 */
 {}, /* compute_4 */
@@ -444,9 +444,9 @@ def fn(accum, bias):
 { /* thread */
 { /* E */
 {}, /* accum */
-{/* ptr_aux */ (float*) ptr_0, /* dAux */ {2048, _1{}, _0{}}}, /* E */
+{/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* dAux */ {2048, _1{}, _0{}}}, /* E */
 },
-{/* ptr_col */ (float*) ptr_1, /* null_default */ float(0), /* dCol */ {}}, /* bias */
+{/* ptr_col */ (float*) (ptr_1 + ptr_1_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */
 {}, /* compute_0 */
 }
 """,
@@ -177,6 +177,9 @@ class CUDAKernel(Kernel):
     def get_dynamic_shape_args(self) -> list[Union[Expr, int]]:
         return [*self.get_layout_args(), *self.size_args]
 
+    def get_offset_args(self) -> list[Expr]:
+        return [node.get_layout().offset for node in self.named_nodes.values()]
+
     @staticmethod
     def find_ld_idx(node: IRNode) -> int:
         strides = node.get_stride()
@@ -264,6 +267,7 @@ class CUDATemplateKernel(CUDAKernel):
             In this case, the `input_reorder` would be [2, 0, 1].
             additional_size_args: Additional size arguments for epilogue inputs
         """
+        # NB: name order matters here, it's used to match up offsets
         names = [x.strip() for x in names_str.strip().split(",")]
         if len(inputs) + len(outputs) != len(names):
             raise RuntimeError(
@@ -285,6 +289,7 @@ class CUDATemplateKernel(CUDAKernel):
         free_symbols: OrderedSet[Expr] = OrderedSet()
         for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
             if node is not None:
+                # NB: named nodes must be populated in the order of names
                 self.named_nodes[name] = node
                 self.args.output_buffers[node.get_name()] = name
 
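The two "NB" comments added above matter because the "{name}_offset" parameters in the kernel signature and the values returned by get_offset_args() are both derived from self.named_nodes, so they only line up if that dict is populated in name order. A small illustrative sketch (node names and offsets are invented), relying on Python dicts preserving insertion order:

named_offsets = {"X": 0, "W": 0, "Y": 128}  # hypothetical node name -> layout offset

# Signature side: one extra parameter per named node, in insertion order.
offset_params = [f"const int {name}_offset" for name in named_offsets]
# Call side: the matching values, produced in the same order.
offset_values = list(named_offsets.values())

print(offset_params)  # ['const int X_offset', 'const int W_offset', 'const int Y_offset']
print(offset_values)  # [0, 0, 128]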
@@ -306,14 +311,17 @@ class CUDATemplateKernel(CUDAKernel):
         size_vars.extend(str(s) for s in free_symbols)
         self.size_args.extend(free_symbols)
         size_args = [f"const int {s}" for s in size_vars]
 
+        offset_args = [f"const int {name}_offset" for name in self.named_nodes.keys()]
         runtime_arg_decls = ",".join(
             [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info]
         )
         if runtime_arg_decls:
             runtime_arg_decls += ", "
 
-        signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {runtime_arg_decls}{self._EXTRA_CPP_ARGS})"
+        signature = (
+            f"int {self.kernel_name}({', '.join(arg_defs + size_args + offset_args)},\
+{runtime_arg_decls}{self._EXTRA_CPP_ARGS})"
+        )
         self.signature = signature
         return signature
 
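To make the signature change concrete, a hedged illustration of the rendered C++ prototype (buffer names, dtypes, and the kernel name are invented, and the runtime-argument/_EXTRA_CPP_ARGS tail is elided as "..."):

kernel_name = "cutlass_mm_kernel"                         # hypothetical
arg_defs = ["const half* X", "const half* W", "half* Y"]  # hypothetical pointer params
size_args = ["const int B"]                               # hypothetical dynamic size
offset_args = [f"const int {name}_offset" for name in ("X", "W", "Y")]

print(f"int {kernel_name}({', '.join(arg_defs + size_args + offset_args)}, ...)")
# int cutlass_mm_kernel(const half* X, const half* W, half* Y, const int B,
#                       const int X_offset, const int W_offset, const int Y_offset, ...)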
@@ -346,10 +354,13 @@ class CUDATemplateKernel(CUDAKernel):
         _, call_args, _, arg_types = self.args.python_argdefs()
 
         dynamic_shape_args = self.get_dynamic_shape_args()
+        offset_args = self.get_offset_args()
         call_args.extend(dynamic_shape_args)  # type: ignore[arg-type]
+        call_args.extend(offset_args)  # type: ignore[arg-type]
         for arg in self.runtime_arg_values:
-            call_args.append(arg)
-        arg_types.extend("int" for _ in dynamic_shape_args)
+            call_args.append(str(arg))
+        arg_types.extend("const int" for _ in dynamic_shape_args)
+        arg_types.extend("const int" for _ in offset_args)
         for arg in self.runtime_arg_info:
             arg_types.append(arg.ty)
         # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
@@ -425,15 +436,6 @@ class CUDATemplateKernel(CUDAKernel):
             max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i]
         return max_valid_offset
 
-    def offset(self, node: IRNode) -> str:
-        """
-        Generates code which represents offset of a given node.
-        """
-
-        if node is None:
-            return "0"
-        return str(node.get_layout().offset)  # type: ignore[union-attr]
-
     def ptr(self, node: IRNode) -> str:
         """
         Generates code which represents pointer of a given node.
@@ -444,8 +446,7 @@ class CUDATemplateKernel(CUDAKernel):
         arg_name = self.arg_name(node)
         if arg_name is None:
             return "nullptr"
-        offset = self.offset(node)
-        return arg_name if offset == "0" else f"{arg_name} + {offset}"
+        return f"{arg_name} + {arg_name}_offset"
 
     def size(
         self,
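The net effect of this hunk, sketched below with a hypothetical buffer name and offset: the layout offset is no longer folded into the generated source as a literal, so the same generated (and cached) code can serve views with different offsets, and the actual value arrives as the matching runtime argument instead.

arg_name, layout_offset = "X", 128   # hypothetical values

old_expr = arg_name if layout_offset == 0 else f"{arg_name} + {layout_offset}"
new_expr = f"{arg_name} + {arg_name}_offset"

print(old_expr)  # X + 128       -- offset baked into the kernel source
print(new_expr)  # X + X_offset  -- offset supplied as a kernel argument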
@@ -43,7 +43,7 @@ class ArgInfo:
 class CUDATemplate(KernelTemplate):
     index_counter = itertools.count()
     # dict of cache key to (code, size_args)
-    code_cache: dict[str, tuple[str, tuple[int, ...]]] = {}
+    code_cache: dict[str, tuple[str, tuple[int, ...], tuple[int, ...]]] = {}
     cache_clear = staticmethod(code_cache.clear)
 
     def __init__(
@@ -113,8 +113,12 @@ class CUDATemplate(KernelTemplate):
         key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr)
 
         if key is not None and key in self.code_cache:
-            code, size_args = self.code_cache[key]
-            extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs))
+            code, size_args, offset_args = self.code_cache[key]
+            extra_args = tuple(
+                list(size_args)
+                + list(offset_args)
+                + list(self.get_runtime_arg_values(**kwargs))
+            )
             return code, extra_args
 
         kernel_name = str(Placeholder.KERNEL_NAME)
@@ -148,12 +152,15 @@
         )
         V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :]))
         size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args())
+        offset_args = V.graph.sizevars.size_hints(kernel.get_offset_args())
 
         if key is not None:
-            self.code_cache[key] = code, size_args
+            self.code_cache[key] = code, size_args, offset_args
 
         # extra args has runtime params, which shouldn't be cached
-        extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs))
+        extra_args = tuple(
+            list(size_args) + list(offset_args) + self.get_runtime_arg_values(**kwargs)
+        )
 
         return code, extra_args
 
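A hedged sketch of a cache entry under the new three-element layout, and of how the uncached runtime values are appended on retrieval (all concrete values below are invented):

code = "// generated CUTLASS C++ source ..."   # placeholder
size_args = (2048, 2048, 2048)                 # hinted dynamic-shape values
offset_args = (0, 0, 128)                      # hinted layout offsets, one per named node
runtime_arg_values = [4]                       # e.g. a swizzle value; never cached

code_cache = {"cache-key": (code, size_args, offset_args)}

cached_code, cached_sizes, cached_offsets = code_cache["cache-key"]
extra_args = tuple(list(cached_sizes) + list(cached_offsets) + runtime_arg_values)
print(extra_args)  # (2048, 2048, 2048, 0, 0, 128, 4)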
@@ -255,7 +255,8 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
             return f"{{{', '.join([render_stride(x) for x in stride])}}}"
 
         elif issubclass(arg_ty, ctypes.c_void_p):
-            return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) {arg_renames.new_name(node.get_name())}"
+            name = arg_renames.new_name(node.get_name())
+            return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) ({name} + {name}_offset)"
         elif (
             arg_ty in _CUTLASS_C_DTYPES
         ):  # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently
@@ -1317,7 +1317,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
             f"(({arg_type}){arg_name}_data.get())"
             for arg_type, arg_name in zip(arg_types, arg_names)
         ]
-        return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"  # noqa: B950
+        return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, 0, 0, 0, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"  # noqa: B950
 
     def _render_evt(
         self,
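In this standalone benchmark call string, the three added literal zeros line up with the new per-buffer offset parameters, which sit between the size arguments and the runtime swizzle argument in the signature built above; passing zeros here presumably reflects that the standalone run does not go through offset views.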