[Cutlass] Implement Epilogue Argument emitter (#150903)

This implements epilogue visitor tree (EVT) argument generation (an example of the target argument type is [here](3fe62887d8/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp (L332))).

Details:
The codegen task here is to implement a function that generates a tree of C++ structs, extracting the relevant properties from Inductor buffers and writing them into the correct locations of the generated struct. To implement this with a minimum of code, I generate the cutlass DAGIR (the EVT internal representation), which has a dedicated pass, [pass_argument_type.py](5e497243f7/python/cutlass/backend/evt/passes/pass_argument_type.py (L4)), that generates a nested tree of custom argument types for each node in the DAGIR. This nested tree of constructors is then called with kwargs to fill in the proper values, with each node's name selecting the matching value from the kwarg dictionary. This, however, is not customizable: the nested tree of EVT args is a nested tree of ctypes that expects *actual values*, so that the resulting object can be passed directly to the cutlass-python C++ runner. Inductor, on the other hand, needs to fill this struct with C++ expression strings representing the values (or expressions that extract the values from the kernel launcher args). `_render_argument_type` implements exactly this: it iterates over the tree of types created by pass_argument_type.py and generates a string representing the nested structs, filling in C++ expressions for the individual fields.
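
To make the shape of that rendering concrete, here is a minimal, self-contained sketch of the idea. The `AuxArgs`/`ThreadArgs` structs, the `_is_visitor` marker, and the `exprs` dictionary are invented for illustration; the real `_render_argument_type` walks the ctypes argument types produced by pass_argument_type.py and derives each expression from the corresponding Inductor buffer.

```python
# Illustrative sketch only: hand-written stand-ins for the nested ctypes
# argument types that cutlass' pass_argument_type.py generates.
import ctypes


class AuxArgs(ctypes.Structure):
    # Mimics a leaf argument type: a data pointer plus a stride tuple.
    _fields_ = [("ptr_aux", ctypes.c_void_p), ("dAux", ctypes.c_int64 * 3)]


class ThreadArgs(ctypes.Structure):
    # Mimics a nested visitor type: one field per DAGIR node, keyed by node name.
    # (The real code detects these by the generated class' __qualname__.)
    _is_visitor = True
    _fields_ = [("accum", ctypes.c_byte), ("aux", AuxArgs)]


def render(name: str, t: type, exprs: dict[str, dict[str, str]], depth: int = 0) -> str:
    """Emit a C++ aggregate initializer, filling fields with expression strings."""
    pad = "  " * depth
    if getattr(t, "_is_visitor", False):  # nested visitor: recurse into children
        inner = "".join(render(f, ft, exprs, depth + 1) for f, ft in t._fields_)
        return f"{pad}{{ /* {name} */\n{inner}{pad}}},\n"
    if issubclass(t, ctypes.c_byte):  # node that takes no runtime arguments
        return f"{pad}{{}}, /* {name} */\n"
    # Leaf node: fill each field with a C++ expression string instead of a value.
    fields = ", ".join(f"/* {f} */ {exprs[name][f]}" for f, _ in t._fields_)
    return f"{pad}{{{fields}}}, /* {name} */\n"


print(render("thread", ThreadArgs, {"aux": {"ptr_aux": "aux.get()", "dAux": "{2048, _1{}, _0{}}"}}))
```

The expected inline string in the new test below (the `{{ { /* thread */ ... } }};` block) is exactly this kind of rendering, produced from the real generated argument types.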

Long term plan:
Long term, I will ask NVIDIA to provide an overridable [visitor_factory](5e497243f7/python/cutlass/backend/evt/passes/pass_argument_type.py (L82)), which would let us override the behavior of pass_argument_type.py so that it generates the strings we want directly during DAGIR generation.
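
Purely to illustrate the kind of hook this would enable (nothing below is an existing cutlass API; `value_factory`, `string_factory`, and `make_aux_args` are hypothetical), the idea is to let the caller supply the factory that turns a DAGIR node into an argument type, so Inductor could inject a string-emitting variant instead of post-processing the ctypes tree:

```python
# Hypothetical sketch only: cutlass' pass_argument_type.py does not currently
# expose visitor_factory as an overridable parameter.
import ctypes
from typing import Callable


def value_factory(node_name: str, fields: list[tuple[str, type]]) -> type:
    # Roughly today's behavior: build a ctypes struct that stores actual values.
    return type(f"{node_name}_args", (ctypes.Structure,), {"_fields_": fields})


def string_factory(node_name: str, fields: list[tuple[str, type]]) -> type:
    # What Inductor would plug in: a type whose "arguments" are C++ expression
    # strings, rendered directly as an aggregate initializer.
    class StringArgs:
        def __init__(self, **exprs: str) -> None:
            self.exprs = exprs

        def render(self) -> str:
            body = ", ".join(f"/* {f} */ {self.exprs[f]}" for f, _ in fields)
            return f"{{{body}}}, /* {node_name} */"

    return StringArgs


def make_aux_args(factory: Callable[[str, list[tuple[str, type]]], type]) -> type:
    # If the argument-type pass accepted the factory as a parameter, Inductor
    # could pass string_factory and skip the post-hoc _render_argument_type walk.
    return factory("aux", [("ptr_aux", ctypes.c_void_p)])


print(make_aux_args(string_factory)(ptr_aux="aux.get()").render())
# -> {/* ptr_aux */ aux.get()}, /* aux */
```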

Previously merged:
* #150346
* #150345
* #150344

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150903
Approved by: https://github.com/henrylhtsang, https://github.com/eellison
Michael Lazos 2025-04-17 00:16:37 -07:00 committed by PyTorch MergeBot
parent 8e0f9fbccf
commit 4f62dccbda
2 changed files with 231 additions and 107 deletions


@@ -1,6 +1,7 @@
# Owner(s): ["module: inductor"]
import unittest
import torch
from torch._dynamo.test_case import TestCase
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
@@ -13,54 +14,111 @@ if try_import_cutlass():
LayoutType = cutlass_lib.LayoutType
DataType = cutlass_lib.DataType
from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
_render_argument_type,
_trace,
CutlassTensor,
trace,
)
BIAS_CODE = """def example_epilogue(accum, C, aux, bias):
F = accum + C + aux
E = relu(F) + bias
D = E + F
return D, F"""
TYPE_C = DataType.f32
M = 4224
N = 2048
BIAS = CutlassTensor(shape=(M, 1), element=TYPE_C, layout_tag=LayoutType.RowMajor)
EXAMPLE_TENSORS = {
"accum": CutlassTensor(
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
),
"bias": BIAS,
# "beta": 0.5, TODO: mlazos support scalars
# "alpha": 0.5, TODO: mlazos support scalars
"D": CutlassTensor(
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
),
"C": CutlassTensor(
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
),
"F": CutlassTensor(
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
),
"aux": CutlassTensor(
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
),
}
class MockTileDescription:
threadblock_shape = (128, 128, 8)
def _create_mock_buffer_name_map(example_tensors):
class MockNode:
def __init__(self, name, stride, dtype):
self.name = name
self.dtype = dtype
self.stride = stride
def get_layout(self):
class MockLayout:
def __init__(self, stride, dtype):
self.dtype = dtype
self.stride = stride
return MockLayout(self.stride, self.dtype)
def get_name(self):
return self.name
name_to_buffer = {}
for name, tensor in example_tensors.items():
if isinstance(tensor, CutlassTensor):
name_to_buffer[name] = MockNode(name, tensor.stride, torch.float32)
return name_to_buffer
class TestCutlassEVT(TestCase):
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_evt_codegen(self):
bias_code = """def example_epilogue(accum, alpha, C, beta, aux, bias):
F = alpha * accum + (beta * C + aux)
E = relu(F + 1) + bias
D = E + F
return D, F"""
def test_evt_argument_codegen(self):
epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS)
type_C = DataType.f32
m = 4224
n = 2048
bias = CutlassTensor(
shape=(m, 1), element=type_C, layout_tag=LayoutType.RowMajor
self.assertExpectedInline(
_render_argument_type(
epilogue_functor, _create_mock_buffer_name_map(EXAMPLE_TENSORS)
),
"""\
{{
{ /* thread */
{ /* F */
{ /* compute_1 */
{ /* compute_0 */
{}, /* accum */
{}, /* C */
{}, /* compute_0 */
},
{/* ptr_aux */ aux.get(), /* null_default */ float, /* dAux */ {2048, _1{}, _0{}}}, /* aux */
{}, /* compute_1 */
},
{/* ptr_aux */ F.get(), /* dAux */ {2048, _1{}, _0{}}}, /* F */
},
{/* ptr_col */ bias.get(), /* null_default */ float, /* dCol */ {}}, /* bias */
{}, /* compute_2 */
{}, /* compute_3 */
{}, /* compute_4 */
},
}};
""",
)
examples_tensors = {
"accum": CutlassTensor(
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
),
"bias": bias,
"beta": 0.5,
"alpha": 0.5,
"D": CutlassTensor(
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
),
"C": CutlassTensor(
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
),
"F": CutlassTensor(
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
),
"aux": CutlassTensor(
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
),
}
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_evt_codegen(self):
_, code = trace(
bias_code,
examples_tensors,
BIAS_CODE,
EXAMPLE_TENSORS,
DataType.f32,
DataType.f32,
MockTileDescription(),
@@ -82,20 +140,13 @@ using TensorC = cutlass::epilogue::fusion::Sm90SrcFetch<float>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Alpha = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
>;
using AuxDescriptor = cutlass::epilogue::collective::detail::AuxLoadDescriptor\
<EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float>;
using AuxDescriptor = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, \
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float>;
using Aux = cutlass::epilogue::fusion::Sm90AuxLoad<
AuxDescriptor::Stages, typename AuxDescriptor::EpilogueTile, float,
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename AuxDescriptor::SmemLayoutAtom, typename AuxDescriptor::CopyOpS2R
>;
using Beta = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename AuxDescriptor::SmemLayoutAtom, \
typename AuxDescriptor::CopyOpS2R
>;
using Bias = cutlass::epilogue::fusion::Sm90ColBroadcast<
@@ -104,44 +155,24 @@ using Bias = cutlass::epilogue::fusion::Sm90ColBroadcast<
>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::plus, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<
Compute0,
Alpha,
Accum>;
Accum,
TensorC>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::plus, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<
Compute1,
Beta,
TensorC>;
using Compute2 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::plus, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute2 = cutlass::epilogue::fusion::Sm90EVT<
Compute2,
EVTCompute1,
Aux>;
using Compute3 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::plus, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute3 = cutlass::epilogue::fusion::Sm90EVT<
Compute3,
EVTCompute0,
EVTCompute2>;
Aux>;
using FDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float
@@ -149,17 +180,23 @@ using FDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
using F = cutlass::epilogue::fusion::Sm90AuxStore<
FDescriptor::Stages, typename FDescriptor::EpilogueTile, float,
cutlass::FloatRoundStyle::round_to_nearest, \
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename FDescriptor::SmemLayoutAtom,
cutlass::FloatRoundStyle::round_to_nearest, cute::Stride<int64_t, cute::Int<1>, \
cute::Int<0>>, typename FDescriptor::SmemLayoutAtom,
typename FDescriptor::CopyOpR2S
>;
using EVTF = cutlass::epilogue::fusion::Sm90EVT<
F,
EVTCompute3>;
EVTCompute1>;
using Imm10 = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
using Compute2 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::ReLu, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using Compute3 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::plus, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using Compute4 = cutlass::epilogue::fusion::Sm90Compute<
@@ -167,39 +204,20 @@ using Compute4 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::FloatRoundStyle::round_to_nearest
>;
using Compute5 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::ReLu, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using Compute6 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::plus, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using Compute7 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::plus, float, float,
cutlass::FloatRoundStyle::round_to_nearest
>;
using DagCompute7 = cutlass::epilogue::fusion::Sm90TopologicalVisitor<
using DagCompute4 = cutlass::epilogue::fusion::Sm90TopologicalVisitor<
float,
cute::tuple<
cute::seq<>,
cute::seq<>,
cute::seq<>,
cute::seq<0, 2>,
cute::seq<3>,
cute::seq<4, 1>,
cute::seq<5, 0>,
cute::seq<0>,
cute::seq<2, 1>,
cute::seq<3, 0>,
>,
EVTF,
Bias,
Imm10,
Compute4,
Compute5,
Compute6,
Compute7
Compute2,
Compute3,
Compute4
>;
using ElementD = float;


@@ -1,12 +1,31 @@
# mypy: allow-untyped-defs
from typing import Any, Union
from torch._inductor.ir import ComputedBuffer, InputBuffer
from torch.utils._ordered_set import OrderedSet
from ..cutlass_utils import try_import_cutlass
EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace
Buffer = Union[ComputedBuffer, InputBuffer]
CutlassTupleType = Any # cutlass.backend.c_types.tuple_factory_.<locals>.TupleType
CutlassVisitorType = Any # cutlass.backend.c_types.visitor_factory.<locals>.VisitorType
CutlassArgType = (
Any # Can be a CutlassTupleType, CutlassVisitorType, EmptyByte, or ctype.c_void_p
)
if try_import_cutlass():
import ast
import ctypes
import textwrap
from cutlass.backend.c_types import ( # type: ignore[import-untyped, import-not-found]
EmptyByte,
)
from cutlass.backend.epilogue import ( # type: ignore[import-untyped, import-not-found]
dtype2ctype,
)
from cutlass.backend.evt import ( # type: ignore[import-untyped, import-not-found]
EpilogueFunctorVisitor,
)
@@ -25,6 +44,9 @@ if try_import_cutlass():
from cutlass_library import DataType, EpilogueScheduleType, TileDescription
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.utils import IndentedBuffer
_CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated]
def trace(
fn_src: str,
@@ -33,8 +55,8 @@ if try_import_cutlass():
output_type: DataType,
tile_description: TileDescription,
epilogue_schedule: EpilogueScheduleType,
**kwargs,
):
**kwargs: dict[str, Any],
) -> tuple[str, str]:
cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type]
assert cuda_arch >= 90, "Only SM90+ is supported for EVT"
epilogue_functor = _trace(fn_src, example_tensors, **kwargs)
@@ -54,13 +76,15 @@ if try_import_cutlass():
# https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117
# This is modified to enable directly passing the source code of the epilogue vs getting it from a bona-fide python function
# The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval
def _trace(fn_src, example_tensors, **kwargs):
def _trace(
fn_src: str, example_tensors: dict[str, CutlassTensor], **kwargs: Any
) -> EpilogueFunctor:
class EpilogueFunctor(PythonASTFrontend):
def __init__(self, **kwargs):
def __init__(self, **kwargs: dict[str, Any]):
self.source = textwrap.dedent(fn_src)
super().__init__(**kwargs)
def parse(self, example_inputs):
def parse(self, example_inputs: dict[str, CutlassTensor]) -> None:
self.example_inputs = example_inputs
self.ast = ast.parse(self.source)
self.visit(self.ast)
@@ -68,3 +92,85 @@ if try_import_cutlass():
epilogue_functor = EpilogueFunctor(**kwargs)
epilogue_functor.trace(example_tensors)
return epilogue_functor
def _render_argument_type(
epilogue_functor: EpilogueFunctor,
name_to_buffer: dict[str, Buffer],
) -> str:
epilogue_thread_type = epilogue_functor.epilogue_thread_type
# Fragile, but this is the only way to guarantee t is expected type because t is a local class
def is_nested_visitor_type(t: type) -> bool:
return (
".".join([t.__module__, t.__qualname__])
== "cutlass.backend.c_types.visitor_factory.<locals>.VisitorType"
)
buffer = IndentedBuffer()
def render_argument_type(name: str, t: CutlassArgType) -> None:
if issubclass(t, ctypes.c_byte):
buffer.writeline(f"{{}}, /* {name} */")
else:
fields = [
(fname, _get_arg_from_node(ty, name_to_buffer[name]))
for fname, ty in t._fields_
]
field_strs = [f"/* {fname} */ {str(field)}" for fname, field in fields]
buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */")
def render_thread_type(name: str, t: CutlassArgType) -> None:
if is_nested_visitor_type(t):
buffer.writeline(f"{{ /* {name} */")
with buffer.indent():
for name, inner_t in t._fields_:
render_thread_type(name, inner_t)
buffer.writeline("},")
else:
render_argument_type(name, t)
buffer.writeline("{{")
with buffer.indent():
render_thread_type("thread", epilogue_thread_type)
buffer.writeline("}};")
return buffer.getvalue()
def _get_arg_from_node(arg_ty: type, node: Buffer) -> str:
from ..cuda_template import CUTLASSTemplate
# Today, arguments are either a pointer to the
# node's memory, a stride tuple, the datatype
# Once again, need to check for local class type for stride tuple
if (
str(arg_ty)
== "<class 'cutlass.backend.c_types.tuple_factory_.<locals>.TupleType'>"
):
DEFAULT_STRIDE_LEN = 3
assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN
stride = [int(x) for x in node.get_layout().stride]
for _ in range(DEFAULT_STRIDE_LEN - len(stride)):
stride.append(0)
def render_stride(x: int) -> str:
# Handle EBO for 0 and 1
if x == 0:
return "_0{}"
elif x == 1:
return "_1{}"
else:
return str(x)
return f"{{{', '.join([render_stride(x) for x in stride])}}}"
elif issubclass(arg_ty, ctypes.c_void_p):
return f"{node.get_name()}.get()"
elif (
arg_ty in _CUTLASS_C_DTYPES
): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently
return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]
elif issubclass(arg_ty, EmptyByte):
return "{}"
raise NotImplementedError(f"Unsupported arg type: {arg_ty}")