mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Cutlass] Implement Epilogue Argument emitter (#150903)
This implements epilogue visitor tree argument generation (example type [here](3fe62887d8/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp (L332))). Details: The codegen task here is to implement a function which can generate a tree of C++ structs and properly extract the correct properties from Inductor buffers and write them to the correct locations in the generated struct. To implement this with the minimum amount of code, I generate the cutlass DAGIR (the EVT internal represenation) which specifically has a pass, [pass_argument_type.py ](5e497243f7/python/cutlass/backend/evt/passes/pass_argument_type.py (L4)) which generates a nested tree of custom argument types for each node in the DAGIR. This nested tree of constructors is then passed kwargs to fill in the proper values, where the node's name is used to differentiate between different values in the kwarg dictionary. This however is non-customizable; the nested tree of EVT args is a nested tree of ctypes which looks for *actual values* so that this object can be passed directly to the cutlass-python C++ runner. Inductor on the other hand needs to fill this struct with string C++ expressions representing the values (or extracting the values from kernel launcher args). So `_render_argument_type` implements this: it iterates over the tree of types created by pass_argument_type.py and generates a string representing the nested structs, filling in C++ expressions representing the different fields. Long term plan: Long term, I will ask the nvidia to provide an overridable [visitor_factory](5e497243f7/python/cutlass/backend/evt/passes/pass_argument_type.py (L82)) which could allow us to override the behavior of pass_argument_type.py to generate the string we would like during DAGIR generation. Previously merged: * #150346 * #150345 * #150344 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150903 Approved by: https://github.com/henrylhtsang, https://github.com/eellison
This commit is contained in:
parent
8e0f9fbccf
commit
4f62dccbda
|
|
@ -1,6 +1,7 @@
|
|||
# Owner(s): ["module: inductor"]
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch._dynamo.test_case import TestCase
|
||||
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
|
||||
|
|
@ -13,54 +14,111 @@ if try_import_cutlass():
|
|||
LayoutType = cutlass_lib.LayoutType
|
||||
DataType = cutlass_lib.DataType
|
||||
from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
|
||||
_render_argument_type,
|
||||
_trace,
|
||||
CutlassTensor,
|
||||
trace,
|
||||
)
|
||||
|
||||
BIAS_CODE = """def example_epilogue(accum, C, aux, bias):
|
||||
F = accum + C + aux
|
||||
E = relu(F) + bias
|
||||
D = E + F
|
||||
return D, F"""
|
||||
|
||||
TYPE_C = DataType.f32
|
||||
M = 4224
|
||||
N = 2048
|
||||
BIAS = CutlassTensor(shape=(M, 1), element=TYPE_C, layout_tag=LayoutType.RowMajor)
|
||||
|
||||
EXAMPLE_TENSORS = {
|
||||
"accum": CutlassTensor(
|
||||
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
|
||||
),
|
||||
"bias": BIAS,
|
||||
# "beta": 0.5, TODO: mlazos support scalars
|
||||
# "alpha": 0.5, TODO: mlazos support scalars
|
||||
"D": CutlassTensor(
|
||||
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
|
||||
),
|
||||
"C": CutlassTensor(
|
||||
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
|
||||
),
|
||||
"F": CutlassTensor(
|
||||
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
|
||||
),
|
||||
"aux": CutlassTensor(
|
||||
element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
|
||||
),
|
||||
}
|
||||
|
||||
class MockTileDescription:
|
||||
threadblock_shape = (128, 128, 8)
|
||||
|
||||
def _create_mock_buffer_name_map(example_tensors):
|
||||
class MockNode:
|
||||
def __init__(self, name, stride, dtype):
|
||||
self.name = name
|
||||
self.dtype = dtype
|
||||
self.stride = stride
|
||||
|
||||
def get_layout(self):
|
||||
class MockLayout:
|
||||
def __init__(self, stride, dtype):
|
||||
self.dtype = dtype
|
||||
self.stride = stride
|
||||
|
||||
return MockLayout(self.stride, self.dtype)
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
|
||||
name_to_buffer = {}
|
||||
for name, tensor in example_tensors.items():
|
||||
if isinstance(tensor, CutlassTensor):
|
||||
name_to_buffer[name] = MockNode(name, tensor.stride, torch.float32)
|
||||
|
||||
return name_to_buffer
|
||||
|
||||
|
||||
class TestCutlassEVT(TestCase):
|
||||
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
||||
def test_evt_codegen(self):
|
||||
bias_code = """def example_epilogue(accum, alpha, C, beta, aux, bias):
|
||||
F = alpha * accum + (beta * C + aux)
|
||||
E = relu(F + 1) + bias
|
||||
D = E + F
|
||||
return D, F"""
|
||||
def test_evt_argument_codegen(self):
|
||||
epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS)
|
||||
|
||||
type_C = DataType.f32
|
||||
m = 4224
|
||||
n = 2048
|
||||
bias = CutlassTensor(
|
||||
shape=(m, 1), element=type_C, layout_tag=LayoutType.RowMajor
|
||||
self.assertExpectedInline(
|
||||
_render_argument_type(
|
||||
epilogue_functor, _create_mock_buffer_name_map(EXAMPLE_TENSORS)
|
||||
),
|
||||
"""\
|
||||
{{
|
||||
{ /* thread */
|
||||
{ /* F */
|
||||
{ /* compute_1 */
|
||||
{ /* compute_0 */
|
||||
{}, /* accum */
|
||||
{}, /* C */
|
||||
{}, /* compute_0 */
|
||||
},
|
||||
{/* ptr_aux */ aux.get(), /* null_default */ float, /* dAux */ {2048, _1{}, _0{}}}, /* aux */
|
||||
{}, /* compute_1 */
|
||||
},
|
||||
{/* ptr_aux */ F.get(), /* dAux */ {2048, _1{}, _0{}}}, /* F */
|
||||
},
|
||||
{/* ptr_col */ bias.get(), /* null_default */ float, /* dCol */ {}}, /* bias */
|
||||
{}, /* compute_2 */
|
||||
{}, /* compute_3 */
|
||||
{}, /* compute_4 */
|
||||
},
|
||||
}};
|
||||
""",
|
||||
)
|
||||
|
||||
examples_tensors = {
|
||||
"accum": CutlassTensor(
|
||||
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
|
||||
),
|
||||
"bias": bias,
|
||||
"beta": 0.5,
|
||||
"alpha": 0.5,
|
||||
"D": CutlassTensor(
|
||||
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
|
||||
),
|
||||
"C": CutlassTensor(
|
||||
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
|
||||
),
|
||||
"F": CutlassTensor(
|
||||
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
|
||||
),
|
||||
"aux": CutlassTensor(
|
||||
element=DataType.f32, shape=(m, n), layout_tag=LayoutType.RowMajor
|
||||
),
|
||||
}
|
||||
|
||||
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
||||
def test_evt_codegen(self):
|
||||
_, code = trace(
|
||||
bias_code,
|
||||
examples_tensors,
|
||||
BIAS_CODE,
|
||||
EXAMPLE_TENSORS,
|
||||
DataType.f32,
|
||||
DataType.f32,
|
||||
MockTileDescription(),
|
||||
|
|
@ -82,20 +140,13 @@ using TensorC = cutlass::epilogue::fusion::Sm90SrcFetch<float>;
|
|||
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
using Alpha = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
|
||||
float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
|
||||
>;
|
||||
|
||||
using AuxDescriptor = cutlass::epilogue::collective::detail::AuxLoadDescriptor\
|
||||
<EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float>;
|
||||
using AuxDescriptor = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, \
|
||||
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float>;
|
||||
|
||||
using Aux = cutlass::epilogue::fusion::Sm90AuxLoad<
|
||||
AuxDescriptor::Stages, typename AuxDescriptor::EpilogueTile, float,
|
||||
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename AuxDescriptor::SmemLayoutAtom, typename AuxDescriptor::CopyOpS2R
|
||||
>;
|
||||
|
||||
using Beta = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
|
||||
float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
|
||||
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename AuxDescriptor::SmemLayoutAtom, \
|
||||
typename AuxDescriptor::CopyOpS2R
|
||||
>;
|
||||
|
||||
using Bias = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
|
|
@ -104,44 +155,24 @@ using Bias = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
|||
>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::plus, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<
|
||||
Compute0,
|
||||
Alpha,
|
||||
Accum>;
|
||||
Accum,
|
||||
TensorC>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::plus, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<
|
||||
Compute1,
|
||||
Beta,
|
||||
TensorC>;
|
||||
|
||||
using Compute2 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::plus, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using EVTCompute2 = cutlass::epilogue::fusion::Sm90EVT<
|
||||
Compute2,
|
||||
EVTCompute1,
|
||||
Aux>;
|
||||
|
||||
using Compute3 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::plus, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using EVTCompute3 = cutlass::epilogue::fusion::Sm90EVT<
|
||||
Compute3,
|
||||
EVTCompute0,
|
||||
EVTCompute2>;
|
||||
Aux>;
|
||||
|
||||
using FDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
||||
EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float
|
||||
|
|
@ -149,17 +180,23 @@ using FDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
|||
|
||||
using F = cutlass::epilogue::fusion::Sm90AuxStore<
|
||||
FDescriptor::Stages, typename FDescriptor::EpilogueTile, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest, \
|
||||
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename FDescriptor::SmemLayoutAtom,
|
||||
cutlass::FloatRoundStyle::round_to_nearest, cute::Stride<int64_t, cute::Int<1>, \
|
||||
cute::Int<0>>, typename FDescriptor::SmemLayoutAtom,
|
||||
typename FDescriptor::CopyOpR2S
|
||||
>;
|
||||
|
||||
using EVTF = cutlass::epilogue::fusion::Sm90EVT<
|
||||
F,
|
||||
EVTCompute3>;
|
||||
EVTCompute1>;
|
||||
|
||||
using Imm10 = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
|
||||
float, cute::Stride<cute::Int<0>, cute::Int<0>, cute::Int<0>>, 1, cutlass::multiplies
|
||||
using Compute2 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::epilogue::thread::ReLu, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using Compute3 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::plus, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using Compute4 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
|
|
@ -167,39 +204,20 @@ using Compute4 = cutlass::epilogue::fusion::Sm90Compute<
|
|||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using Compute5 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::epilogue::thread::ReLu, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using Compute6 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::plus, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using Compute7 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::plus, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest
|
||||
>;
|
||||
|
||||
using DagCompute7 = cutlass::epilogue::fusion::Sm90TopologicalVisitor<
|
||||
using DagCompute4 = cutlass::epilogue::fusion::Sm90TopologicalVisitor<
|
||||
float,
|
||||
cute::tuple<
|
||||
cute::seq<>,
|
||||
cute::seq<>,
|
||||
cute::seq<>,
|
||||
cute::seq<0, 2>,
|
||||
cute::seq<3>,
|
||||
cute::seq<4, 1>,
|
||||
cute::seq<5, 0>,
|
||||
cute::seq<0>,
|
||||
cute::seq<2, 1>,
|
||||
cute::seq<3, 0>,
|
||||
>,
|
||||
EVTF,
|
||||
Bias,
|
||||
Imm10,
|
||||
Compute4,
|
||||
Compute5,
|
||||
Compute6,
|
||||
Compute7
|
||||
Compute2,
|
||||
Compute3,
|
||||
Compute4
|
||||
>;
|
||||
|
||||
using ElementD = float;
|
||||
|
|
|
|||
|
|
@ -1,12 +1,31 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Union
|
||||
|
||||
from torch._inductor.ir import ComputedBuffer, InputBuffer
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from ..cutlass_utils import try_import_cutlass
|
||||
|
||||
|
||||
EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace
|
||||
Buffer = Union[ComputedBuffer, InputBuffer]
|
||||
CutlassTupleType = Any # cutlass.backend.c_types.tuple_factory_.<locals>.TupleType
|
||||
CutlassVisitorType = Any # cutlass.backend.c_types.visitor_factory.<locals>.VisitorType
|
||||
CutlassArgType = (
|
||||
Any # Can be a CutlassTupleType, CutlassVisitorType, EmptyByte, or ctype.c_void_p
|
||||
)
|
||||
|
||||
|
||||
if try_import_cutlass():
|
||||
import ast
|
||||
import ctypes
|
||||
import textwrap
|
||||
|
||||
from cutlass.backend.c_types import ( # type: ignore[import-untyped, import-not-found]
|
||||
EmptyByte,
|
||||
)
|
||||
from cutlass.backend.epilogue import ( # type: ignore[import-untyped, import-not-found]
|
||||
dtype2ctype,
|
||||
)
|
||||
from cutlass.backend.evt import ( # type: ignore[import-untyped, import-not-found]
|
||||
EpilogueFunctorVisitor,
|
||||
)
|
||||
|
|
@ -25,6 +44,9 @@ if try_import_cutlass():
|
|||
from cutlass_library import DataType, EpilogueScheduleType, TileDescription
|
||||
|
||||
from torch._inductor.codegen.cuda import cuda_env
|
||||
from torch._inductor.utils import IndentedBuffer
|
||||
|
||||
_CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated]
|
||||
|
||||
def trace(
|
||||
fn_src: str,
|
||||
|
|
@ -33,8 +55,8 @@ if try_import_cutlass():
|
|||
output_type: DataType,
|
||||
tile_description: TileDescription,
|
||||
epilogue_schedule: EpilogueScheduleType,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: dict[str, Any],
|
||||
) -> tuple[str, str]:
|
||||
cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type]
|
||||
assert cuda_arch >= 90, "Only SM90+ is supported for EVT"
|
||||
epilogue_functor = _trace(fn_src, example_tensors, **kwargs)
|
||||
|
|
@ -54,13 +76,15 @@ if try_import_cutlass():
|
|||
# https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117
|
||||
# This is modified to enable directly passing the source code of the epilogue vs getting it from a bona-fide python function
|
||||
# The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval
|
||||
def _trace(fn_src, example_tensors, **kwargs):
|
||||
def _trace(
|
||||
fn_src: str, example_tensors: dict[str, CutlassTensor], **kwargs: Any
|
||||
) -> EpilogueFunctor:
|
||||
class EpilogueFunctor(PythonASTFrontend):
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs: dict[str, Any]):
|
||||
self.source = textwrap.dedent(fn_src)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def parse(self, example_inputs):
|
||||
def parse(self, example_inputs: dict[str, CutlassTensor]) -> None:
|
||||
self.example_inputs = example_inputs
|
||||
self.ast = ast.parse(self.source)
|
||||
self.visit(self.ast)
|
||||
|
|
@ -68,3 +92,85 @@ if try_import_cutlass():
|
|||
epilogue_functor = EpilogueFunctor(**kwargs)
|
||||
epilogue_functor.trace(example_tensors)
|
||||
return epilogue_functor
|
||||
|
||||
def _render_argument_type(
|
||||
epilogue_functor: EpilogueFunctor,
|
||||
name_to_buffer: dict[str, Buffer],
|
||||
) -> str:
|
||||
epilogue_thread_type = epilogue_functor.epilogue_thread_type
|
||||
|
||||
# Fragile, but this is the only way to guarantee t is expected type because t is a local class
|
||||
def is_nested_visitor_type(t: type) -> bool:
|
||||
return (
|
||||
".".join([t.__module__, t.__qualname__])
|
||||
== "cutlass.backend.c_types.visitor_factory.<locals>.VisitorType"
|
||||
)
|
||||
|
||||
buffer = IndentedBuffer()
|
||||
|
||||
def render_argument_type(name: str, t: CutlassArgType) -> None:
|
||||
if issubclass(t, ctypes.c_byte):
|
||||
buffer.writeline(f"{{}}, /* {name} */")
|
||||
else:
|
||||
fields = [
|
||||
(fname, _get_arg_from_node(ty, name_to_buffer[name]))
|
||||
for fname, ty in t._fields_
|
||||
]
|
||||
field_strs = [f"/* {fname} */ {str(field)}" for fname, field in fields]
|
||||
buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */")
|
||||
|
||||
def render_thread_type(name: str, t: CutlassArgType) -> None:
|
||||
if is_nested_visitor_type(t):
|
||||
buffer.writeline(f"{{ /* {name} */")
|
||||
with buffer.indent():
|
||||
for name, inner_t in t._fields_:
|
||||
render_thread_type(name, inner_t)
|
||||
buffer.writeline("},")
|
||||
else:
|
||||
render_argument_type(name, t)
|
||||
|
||||
buffer.writeline("{{")
|
||||
with buffer.indent():
|
||||
render_thread_type("thread", epilogue_thread_type)
|
||||
|
||||
buffer.writeline("}};")
|
||||
|
||||
return buffer.getvalue()
|
||||
|
||||
def _get_arg_from_node(arg_ty: type, node: Buffer) -> str:
|
||||
from ..cuda_template import CUTLASSTemplate
|
||||
|
||||
# Today, arguments are either a pointer to the
|
||||
# node's memory, a stride tuple, the datatype
|
||||
# Once again, need to check for local class type for stride tuple
|
||||
if (
|
||||
str(arg_ty)
|
||||
== "<class 'cutlass.backend.c_types.tuple_factory_.<locals>.TupleType'>"
|
||||
):
|
||||
DEFAULT_STRIDE_LEN = 3
|
||||
assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN
|
||||
stride = [int(x) for x in node.get_layout().stride]
|
||||
for _ in range(DEFAULT_STRIDE_LEN - len(stride)):
|
||||
stride.append(0)
|
||||
|
||||
def render_stride(x: int) -> str:
|
||||
# Handle EBO for 0 and 1
|
||||
if x == 0:
|
||||
return "_0{}"
|
||||
elif x == 1:
|
||||
return "_1{}"
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
return f"{{{', '.join([render_stride(x) for x in stride])}}}"
|
||||
|
||||
elif issubclass(arg_ty, ctypes.c_void_p):
|
||||
return f"{node.get_name()}.get()"
|
||||
elif (
|
||||
arg_ty in _CUTLASS_C_DTYPES
|
||||
): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently
|
||||
return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]
|
||||
elif issubclass(arg_ty, EmptyByte):
|
||||
return "{}"
|
||||
|
||||
raise NotImplementedError(f"Unsupported arg type: {arg_ty}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user