[Cutlass] Fixes for e2e compilation in arg rendering (#151405)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151405
Approved by: https://github.com/eellison
ghstack dependencies: #152305, #152306, #150905
This commit is contained in:
Michael Lazos 2025-04-29 10:01:20 -07:00 committed by PyTorch MergeBot
parent a0ce5ce6e4
commit a1f6d85b36
3 changed files with 63 additions and 44 deletions

View File

@ -301,38 +301,37 @@ return D, tmp_2""",
epilogue_functor, _create_mock_buffer_name_map(EXAMPLE_TENSORS)
),
"""\
{{
{ /* thread */
{ /* thread */
{ /* F */
{ /* compute_1 */
{ /* compute_0 */
{}, /* accum */
{}, /* C */
{}, /* compute_0 */
},
{/* ptr_aux */ aux.get(), /* null_default */ float, /* dAux */ {2048, _1{}, _0{}}}, /* aux */
{}, /* compute_1 */
{ /* compute_1 */
{ /* compute_0 */
{}, /* accum */
{}, /* C */
{}, /* compute_0 */
},
{/* ptr_aux */ F.get(), /* dAux */ {2048, _1{}, _0{}}}, /* F */
{/* ptr_aux */ (float*) aux, /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */
{}, /* compute_1 */
},
{/* ptr_aux */ (float*) F, /* dAux */ {2048, _1{}, _0{}}}, /* F */
},
{/* ptr_col */ bias.get(), /* null_default */ float, /* dCol */ {}}, /* bias */
{/* ptr_col */ (float*) bias, /* null_default */ float(0), /* dCol */ {}}, /* bias */
{}, /* compute_2 */
{}, /* compute_3 */
{}, /* compute_4 */
},
}};
}
""",
)
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_evt_codegen(self):
_, code = trace(
_, _, code = trace(
BIAS_CODE,
EXAMPLE_TENSORS,
DataType.f32,
DataType.f32,
MockTileDescription(),
EpilogueScheduleType.ScheduleAuto,
_create_mock_buffer_name_map(EXAMPLE_TENSORS),
)
self.assertExpectedInline(
code,

View File

@ -115,8 +115,9 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
output_type: DataType,
tile_description: TileDescription,
epilogue_schedule: EpilogueScheduleType,
name_to_buffer: dict[str, Buffer],
**kwargs: dict[str, Any],
) -> tuple[str, str]:
) -> tuple[str, str, str]:
cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type]
assert cuda_arch >= 90, "Only SM90+ is supported for EVT"
epilogue_functor = _trace(fn_src, example_tensors, **kwargs)
@ -129,8 +130,9 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
output_type,
fusion_callbacks,
)
return collective_epilogue.emit()
evt_name, evt_code = collective_epilogue.emit()
evt_args = _render_argument_type(epilogue_functor, name_to_buffer)
return evt_name, evt_args, evt_code
# Based off of
# https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117
@ -167,33 +169,42 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
)
buffer = IndentedBuffer()
with buffer.set_tabwidth(2):
def render_argument_type(name: str, t: CutlassArgType) -> None:
if issubclass(t, ctypes.c_byte):
buffer.writeline(f"{{}}, /* {name} */")
else:
fields = [
(fname, _get_arg_from_node(ty, name_to_buffer[name]))
for fname, ty in t._fields_
]
field_strs = [f"/* {fname} */ {str(field)}" for fname, field in fields]
buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */")
def render_argument_type(name: str, t: CutlassArgType) -> None:
if issubclass(t, ctypes.c_byte):
buffer.writeline(f"{{}}, /* {name} */")
else:
fields = [
(fname, _get_arg_from_node(ty, name_to_buffer[name]))
for fname, ty in t._fields_
]
field_strs = [
f"/* {fname} */ {str(field)}" for fname, field in fields
]
buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */")
def render_thread_type(name: str, t: CutlassArgType) -> None:
if is_nested_visitor_type(t):
buffer.writeline(f"{{ /* {name} */")
with buffer.indent():
for name, inner_t in t._fields_:
render_thread_type(name, inner_t)
buffer.writeline("},")
else:
render_argument_type(name, t)
def render_thread_type(name: str, t: CutlassArgType) -> None:
if is_nested_visitor_type(t):
buffer.writeline(f"{{ /* {name} */")
with buffer.indent():
for name, inner_t in t._fields_:
render_thread_type(name, inner_t)
buffer.writeline("},")
else:
render_argument_type(name, t)
buffer.writeline("{{")
with buffer.indent():
render_thread_type("thread", epilogue_thread_type)
buffer.writeline("}};")
# unroll the recursion once to address special case formatting
# namely, no ending comma and no indentation for the outermost thread type
buffer.writeline("{ /* thread */")
with buffer.indent(3):
if is_nested_visitor_type(epilogue_thread_type):
with buffer.indent():
for name, inner_t in epilogue_thread_type._fields_:
render_thread_type(name, inner_t)
else:
render_argument_type("thread", epilogue_thread_type)
buffer.writeline("}")
return buffer.getvalue()
@ -225,11 +236,11 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
return f"{{{', '.join([render_stride(x) for x in stride])}}}"
elif issubclass(arg_ty, ctypes.c_void_p):
return f"{node.get_name()}.get()"
return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) {node.get_name()}"
elif (
arg_ty in _CUTLASS_C_DTYPES
): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently
return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]
return f"{CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}(0)"
elif issubclass(arg_ty, EmptyByte):
return "{}"

View File

@ -1165,6 +1165,15 @@ class IndentedBuffer:
self._lines: list[Union[DeferredLineBase, LineContext, str]] = []
self._indent = initial_indent
# Context manager that temporarily overrides this buffer's `tabwidth`.
# Usage in this change: `with buffer.set_tabwidth(2): ...`.
# The previous value is captured on entry and put back on exit.
@contextlib.contextmanager
def set_tabwidth(self, tabwidth: int) -> Iterator[None]:
# Remember the current width so it can be restored afterwards.
prev = self.tabwidth
try:
self.tabwidth = tabwidth
# Hand control to the body of the `with` block.
yield
finally:
# Runs even if the `with` body raises, so the override never leaks.
self.tabwidth = prev
def getvaluewithlinemap(self) -> ValueWithLineMap:
buf = StringIO()
p = 1