mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Cutlass] Fixes for e2e compilation in arg rendering (#151405)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151405
Approved by: https://github.com/eellison
ghstack dependencies: #152305, #152306, #150905
This commit is contained in:
parent
a0ce5ce6e4
commit
a1f6d85b36
|
|
@ -301,38 +301,37 @@ return D, tmp_2""",
|
|||
epilogue_functor, _create_mock_buffer_name_map(EXAMPLE_TENSORS)
|
||||
),
|
||||
"""\
|
||||
{{
|
||||
{ /* thread */
|
||||
{ /* thread */
|
||||
{ /* F */
|
||||
{ /* compute_1 */
|
||||
{ /* compute_0 */
|
||||
{}, /* accum */
|
||||
{}, /* C */
|
||||
{}, /* compute_0 */
|
||||
},
|
||||
{/* ptr_aux */ aux.get(), /* null_default */ float, /* dAux */ {2048, _1{}, _0{}}}, /* aux */
|
||||
{}, /* compute_1 */
|
||||
{ /* compute_1 */
|
||||
{ /* compute_0 */
|
||||
{}, /* accum */
|
||||
{}, /* C */
|
||||
{}, /* compute_0 */
|
||||
},
|
||||
{/* ptr_aux */ F.get(), /* dAux */ {2048, _1{}, _0{}}}, /* F */
|
||||
{/* ptr_aux */ (float*) aux, /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */
|
||||
{}, /* compute_1 */
|
||||
},
|
||||
{/* ptr_aux */ (float*) F, /* dAux */ {2048, _1{}, _0{}}}, /* F */
|
||||
},
|
||||
{/* ptr_col */ bias.get(), /* null_default */ float, /* dCol */ {}}, /* bias */
|
||||
{/* ptr_col */ (float*) bias, /* null_default */ float(0), /* dCol */ {}}, /* bias */
|
||||
{}, /* compute_2 */
|
||||
{}, /* compute_3 */
|
||||
{}, /* compute_4 */
|
||||
},
|
||||
}};
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
||||
def test_evt_codegen(self):
|
||||
_, code = trace(
|
||||
_, _, code = trace(
|
||||
BIAS_CODE,
|
||||
EXAMPLE_TENSORS,
|
||||
DataType.f32,
|
||||
DataType.f32,
|
||||
MockTileDescription(),
|
||||
EpilogueScheduleType.ScheduleAuto,
|
||||
_create_mock_buffer_name_map(EXAMPLE_TENSORS),
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
code,
|
||||
|
|
|
|||
|
|
@ -115,8 +115,9 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
|
|||
output_type: DataType,
|
||||
tile_description: TileDescription,
|
||||
epilogue_schedule: EpilogueScheduleType,
|
||||
name_to_buffer: dict[str, Buffer],
|
||||
**kwargs: dict[str, Any],
|
||||
) -> tuple[str, str]:
|
||||
) -> tuple[str, str, str]:
|
||||
cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type]
|
||||
assert cuda_arch >= 90, "Only SM90+ is supported for EVT"
|
||||
epilogue_functor = _trace(fn_src, example_tensors, **kwargs)
|
||||
|
|
@ -129,8 +130,9 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
|
|||
output_type,
|
||||
fusion_callbacks,
|
||||
)
|
||||
|
||||
return collective_epilogue.emit()
|
||||
evt_name, evt_code = collective_epilogue.emit()
|
||||
evt_args = _render_argument_type(epilogue_functor, name_to_buffer)
|
||||
return evt_name, evt_args, evt_code
|
||||
|
||||
# Based off of
|
||||
# https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117
|
||||
|
|
@ -167,33 +169,42 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
|
|||
)
|
||||
|
||||
buffer = IndentedBuffer()
|
||||
with buffer.set_tabwidth(2):
|
||||
|
||||
def render_argument_type(name: str, t: CutlassArgType) -> None:
|
||||
if issubclass(t, ctypes.c_byte):
|
||||
buffer.writeline(f"{{}}, /* {name} */")
|
||||
else:
|
||||
fields = [
|
||||
(fname, _get_arg_from_node(ty, name_to_buffer[name]))
|
||||
for fname, ty in t._fields_
|
||||
]
|
||||
field_strs = [f"/* {fname} */ {str(field)}" for fname, field in fields]
|
||||
buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */")
|
||||
def render_argument_type(name: str, t: CutlassArgType) -> None:
|
||||
if issubclass(t, ctypes.c_byte):
|
||||
buffer.writeline(f"{{}}, /* {name} */")
|
||||
else:
|
||||
fields = [
|
||||
(fname, _get_arg_from_node(ty, name_to_buffer[name]))
|
||||
for fname, ty in t._fields_
|
||||
]
|
||||
field_strs = [
|
||||
f"/* {fname} */ {str(field)}" for fname, field in fields
|
||||
]
|
||||
buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */")
|
||||
|
||||
def render_thread_type(name: str, t: CutlassArgType) -> None:
|
||||
if is_nested_visitor_type(t):
|
||||
buffer.writeline(f"{{ /* {name} */")
|
||||
with buffer.indent():
|
||||
for name, inner_t in t._fields_:
|
||||
render_thread_type(name, inner_t)
|
||||
buffer.writeline("},")
|
||||
else:
|
||||
render_argument_type(name, t)
|
||||
def render_thread_type(name: str, t: CutlassArgType) -> None:
|
||||
if is_nested_visitor_type(t):
|
||||
buffer.writeline(f"{{ /* {name} */")
|
||||
with buffer.indent():
|
||||
for name, inner_t in t._fields_:
|
||||
render_thread_type(name, inner_t)
|
||||
buffer.writeline("},")
|
||||
else:
|
||||
render_argument_type(name, t)
|
||||
|
||||
buffer.writeline("{{")
|
||||
with buffer.indent():
|
||||
render_thread_type("thread", epilogue_thread_type)
|
||||
|
||||
buffer.writeline("}};")
|
||||
# unroll the recursion once to address special case formatting
|
||||
# namely, no ending comma and no indentation for the outermost thread type
|
||||
buffer.writeline("{ /* thread */")
|
||||
with buffer.indent(3):
|
||||
if is_nested_visitor_type(epilogue_thread_type):
|
||||
with buffer.indent():
|
||||
for name, inner_t in epilogue_thread_type._fields_:
|
||||
render_thread_type(name, inner_t)
|
||||
else:
|
||||
render_argument_type("thread", epilogue_thread_type)
|
||||
buffer.writeline("}")
|
||||
|
||||
return buffer.getvalue()
|
||||
|
||||
|
|
@ -225,11 +236,11 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
|
|||
return f"{{{', '.join([render_stride(x) for x in stride])}}}"
|
||||
|
||||
elif issubclass(arg_ty, ctypes.c_void_p):
|
||||
return f"{node.get_name()}.get()"
|
||||
return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) {node.get_name()}"
|
||||
elif (
|
||||
arg_ty in _CUTLASS_C_DTYPES
|
||||
): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently
|
||||
return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]
|
||||
return f"{CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}(0)"
|
||||
elif issubclass(arg_ty, EmptyByte):
|
||||
return "{}"
|
||||
|
||||
|
|
|
|||
|
|
@ -1165,6 +1165,15 @@ class IndentedBuffer:
|
|||
self._lines: list[Union[DeferredLineBase, LineContext, str]] = []
|
||||
self._indent = initial_indent
|
||||
|
||||
@contextlib.contextmanager
def set_tabwidth(self, tabwidth: int) -> Iterator[None]:
    """Temporarily set ``self.tabwidth`` to ``tabwidth`` for a ``with`` block.

    The value that was in place on entry is put back when the block exits,
    including when the body raises.
    """
    saved_tabwidth = self.tabwidth
    try:
        self.tabwidth = tabwidth
        yield
    finally:
        # Runs on normal exit and on exceptions alike.
        self.tabwidth = saved_tabwidth
|
||||
|
||||
def getvaluewithlinemap(self) -> ValueWithLineMap:
|
||||
buf = StringIO()
|
||||
p = 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user