# pytorch/test/inductor/test_cutlass_evt.py
# Owner(s): ["module: inductor"]
import unittest

import sympy

import torch
from torch._dynamo.test_case import TestCase
from torch._inductor.codegen.cuda.cutlass_utils import (
    torch_dtype_to_cutlass_type,
    try_import_cutlass,
)
from torch._inductor.ir import ComputedBuffer, FixedLayout, PermuteView, Pointwise
from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import OrderedSet
from torch.testing._internal.common_cuda import SM90OrLater
from torch.testing._internal.inductor_utils import (
    HAS_CPU,
    HAS_CUDA_AND_TRITON,
    MockGraphHandler,
)


if try_import_cutlass():
    import cutlass_library as cutlass_lib
    from cutlass_library import EpilogueScheduleType

    LayoutType = cutlass_lib.LayoutType
    DataType = cutlass_lib.DataType
    from cutlass_cppgen.backend.evt.ir.tensor import Tensor as CutlassTensor

    from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
        _render_argument_type,
        _trace,
        trace,
    )
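
    # BIAS_CODE is written in the small Python DSL that the CUTLASS EVT
    # (Epilogue Visitor Tree) tracer consumes: `accum` is the GEMM accumulator,
    # and the remaining arguments are the auxiliary inputs/outputs described by
    # EXAMPLE_TENSORS below.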
    BIAS_CODE = """def example_epilogue(accum, C, aux, bias):
    F = accum + C + aux
    E = relu(F) + bias
    D = E + F
    return D, F"""

    TYPE_C = DataType.f32
    M = 4224
    N = 2048
    BIAS = CutlassTensor(shape=(M, 1), element=TYPE_C, layout_tag=LayoutType.RowMajor)

    EXAMPLE_TENSORS = {
        "accum": CutlassTensor(
            element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
        ),
        "bias": BIAS,
        # "beta": 0.5, TODO: mlazos support scalars
        # "alpha": 0.5, TODO: mlazos support scalars
        "D": CutlassTensor(
            element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
        ),
        "C": CutlassTensor(
            element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
        ),
        "F": CutlassTensor(
            element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
        ),
        "aux": CutlassTensor(
            element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
        ),
    }
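

# Lightweight stand-ins for Inductor scheduler and IR objects, so the EVT
# codegen paths under test can run without constructing a real GraphLowering
# or Scheduler.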
class MockTileDescription:
    threadblock_shape = (128, 128, 8)


def _create_mock_buffer_name_map(example_tensors):
    name_to_buffer = {}
    for name, tensor in example_tensors.items():
        if isinstance(tensor, CutlassTensor):
            name_to_buffer[name] = MockComputedBuffer(
                name, None, torch.float32, tensor.shape, tensor.stride
            )
    return name_to_buffer


class MockSchedulerNode(BaseSchedulerNode):
    def __init__(self, node, last_usage=None):
        self.node = node
        self.last_usage = last_usage or OrderedSet()


class MockComputedBuffer(ComputedBuffer):
    def __init__(self, name, inner_fn, dtype, size, strides=None):
        self.name = name
        ranges = [sympy.Integer(x) for x in size]
        self.data = Pointwise(
            device=None, dtype=dtype, inner_fn=inner_fn, ranges=ranges
        )
        self.layout = FixedLayout(None, dtype, ranges, strides)

    def get_name(self):
        return self.name

    def num_reads(self):
        # Needed to not inline in ComputedBuffer
        return 1
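

# The tests below cover two layers of the CUTLASS EVT integration:
# CutlassEVTCodegen, which lowers Inductor IR into EVT's Python DSL, and the
# evt_extensions helpers (_trace/trace/_render_argument_type), which turn that
# DSL into CUTLASS C++ argument structs and type declarations.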
class TestCutlassEVT(TestCase):
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_py_codegen_accumulator_return(self):
        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
        from torch._inductor.virtualized import V

        size = (100, 300, 200)
        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

        # buf0 is acc
        # buf1 is external
        def inner_fn_buf3(index):
            tmp0 = buf0.make_loader()(index)
            tmp1 = buf1.make_loader()(index)
            tmp2 = buf2.make_loader()(index)
            return tmp0 * tmp1 + tmp2

        def inner_fn_buf4(index):
            tmp0 = buf0.make_loader()(index)
            tmp3 = buf3.make_loader()(index)
            return tmp0 + tmp3

        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
        buf4 = MockComputedBuffer("buf4", inner_fn_buf4, torch.float32, size)
        with V.set_graph_handler(
            MockGraphHandler(
                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
            )
        ):
            reads, writes, renames, code = CutlassEVTCodegen.ir_to_evt_python_code(
                "buf0",
                [
                    MockSchedulerNode(buf3),
                    MockSchedulerNode(buf4, last_usage=OrderedSet(["buf3"])),
                ],
                OrderedSet([]),
            )

            self.assertExpectedInline(reads, """['buf1', 'buf2']""")
            self.assertExpectedInline(writes, """['buf0', 'buf3', 'buf4']""")
            self.assertExpectedInline(
                renames,
                """{'accum': 'buf0', 'tmp_0': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'tmp_2': 'buf3', 'D': 'buf4'}""",
            )

            self.assertExpectedInline(
                code,
                """\
def fn(accum, buf1, buf2):
    tmp_0 = accum
    tmp_1 = tmp_0 * buf1
    tmp_2 = tmp_1 + buf2
    D = tmp_0 + tmp_2
    return tmp_0, tmp_2, D""",
            )
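
    # Reading the accumulator through a PermuteView produces an index whose
    # strides disagree with buf0's layout strides; EVT codegen cannot express
    # such a load and should raise NotImplementedError.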
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_py_codegen_disjoint_read_indexing(self):
        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
        from torch._inductor.virtualized import V

        size = (100, 300, 200)
        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
        permuted_buf_0 = PermuteView.create(buf0, [1, 0, 2])
        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

        # buf0 is acc
        # buf1 is external
        def inner_fn_buf3(index):
            tmp0 = permuted_buf_0.make_loader()(index)
            tmp1 = buf1.make_loader()(index)
            tmp2 = buf2.make_loader()(index)
            return tmp0 * tmp1 + tmp2

        def inner_fn_buf4(index):
            tmp0 = buf0.make_loader()(index)
            tmp3 = buf3.make_loader()(index)
            return tmp0 + tmp3

        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
        buf4 = MockComputedBuffer("buf4", inner_fn_buf4, torch.float32, size)
        with V.set_graph_handler(
            MockGraphHandler(
                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
            )
        ):
            result = None
            try:
                CutlassEVTCodegen.ir_to_evt_python_code(
                    "buf0",
                    [MockSchedulerNode(buf3), MockSchedulerNode(buf4)],
                    OrderedSet([]),
                )
            except NotImplementedError as e:
                result = e

            self.assertExpectedInline(
                str(result),
                """Unsupported indexing for buf0 with index 200*i0 + 60000*i1 + i2, \
index strides [200, 60000, 1], and layout stride [60000, 200, 1]""",
            )
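
    # Like the accumulator-return test, but buf4 squares tmp3 and is
    # materialized as (100, 300, 1), exercising the broadcast handling in the
    # store path.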
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_py_codegen_broadcasting(self):
        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
        from torch._inductor.virtualized import V

        size = (100, 300, 200)
        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

        # buf0 is acc
        # buf1 is external
        def inner_fn_buf3(index):
            tmp0 = buf0.make_loader()(index)
            tmp1 = buf1.make_loader()(index)
            tmp2 = buf2.make_loader()(index)
            return tmp0 * tmp1 + tmp2

        def inner_fn_buf4(index):
            tmp0 = buf0.make_loader()(index)
            tmp3 = buf3.make_loader()(index)
            return tmp0 + tmp3 * tmp3

        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
        buf4 = MockComputedBuffer(
            "buf4", inner_fn_buf4, torch.float32, (100, 300, 1)
        )  # broadcast
        with V.set_graph_handler(
            MockGraphHandler(
                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
            )
        ):
            reads, writes, renames, code = CutlassEVTCodegen.ir_to_evt_python_code(
                "buf0",
                [
                    MockSchedulerNode(buf3),
                    MockSchedulerNode(buf4, last_usage=OrderedSet(["buf0"])),
                ],
                OrderedSet([]),
            )

            self.assertExpectedInline(reads, """['buf1', 'buf2']""")
            self.assertExpectedInline(writes, """['buf0', 'buf3', 'buf4']""")
            self.assertExpectedInline(
                renames,
                """{'accum': 'buf0', 'tmp_0': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'tmp_2': 'buf3', 'D': 'buf4'}""",
            )

            self.assertExpectedInline(
                code,
                """\
def fn(accum, buf1, buf2):
    tmp_0 = accum
    tmp_1 = tmp_0 * buf1
    tmp_2 = tmp_1 + buf2
    tmp_3 = tmp_2 * tmp_2
    D = tmp_0 + tmp_3
    return tmp_0, tmp_2, D""",
            )
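
    # Here buf0 is passed in the final OrderedSet argument, so its write is
    # elided: it is absent from `writes` and the generated function returns
    # only tmp_1 and D.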
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_py_codegen(self):
        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
        from torch._inductor.virtualized import V

        size = (100, 300, 200)
        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

        # buf0 is acc
        # buf1 is external
        def inner_fn_buf3(index):
            tmp0 = buf0.make_loader()(index)
            tmp1 = buf1.make_loader()(index)
            tmp2 = buf2.make_loader()(index)
            return tmp0 * tmp1 + tmp2

        def inner_fn_buf4(index):
            tmp0 = buf0.make_loader()(index)
            tmp3 = buf3.make_loader()(index)
            return tmp0 + tmp3

        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
        buf4 = MockComputedBuffer("buf4", inner_fn_buf4, torch.float32, size)
        with V.set_graph_handler(
            MockGraphHandler(
                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
            )
        ):
            reads, writes, renames, code = CutlassEVTCodegen.ir_to_evt_python_code(
                "buf0",
                [
                    MockSchedulerNode(buf3),
                    MockSchedulerNode(buf4),
                ],
                OrderedSet(["buf0"]),
            )

            self.assertExpectedInline(reads, """['buf1', 'buf2']""")
            self.assertExpectedInline(writes, """['buf3', 'buf4']""")
            self.assertExpectedInline(
                renames,
                """{'accum': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'tmp_1': 'buf3', 'D': 'buf4'}""",
            )

            self.assertExpectedInline(
                code,
                """\
def fn(accum, buf1, buf2):
    tmp_0 = accum * buf1
    tmp_1 = tmp_0 + buf2
    D = accum + tmp_1
    return tmp_1, D""",
            )
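
    # create_example_tensors should carry each buffer's shape/stride/dtype into
    # a CUTLASS tensor descriptor, resolving renames ("acc" -> "buf0") and
    # preserving row- vs. column-major strides.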
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_example_tensor_creation(self):
        from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
            create_example_tensors,
        )
        from torch._inductor.virtualized import V

        with V.set_graph_handler(MockGraphHandler({})):
            row_major_buf0 = MockComputedBuffer(
                "buf0", None, torch.float32, (3, 4, 1), (4, 1, 0)
            )
            col_major_buf1 = MockComputedBuffer(
                "buf1", None, torch.float32, (3, 2, 1), (1, 3, 0)
            )
            buffer_renames = {"buf0": "buf0", "buf1": "buf1", "acc": "buf0"}
            name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1}
            result = create_example_tensors(
                buffer_renames, name_to_buffer, lambda x: int(x)
            )
            self.assertEqual(result["acc"].shape, (3, 4, 1))
            self.assertEqual(result["acc"].stride, (4, 1, 0))
            self.assertEqual(
                result["acc"].element, torch_dtype_to_cutlass_type(torch.float32)
            )

            self.assertEqual(result["buf1"].shape, (3, 2, 1))
            self.assertEqual(result["buf1"].stride, (1, 3, 0))
            self.assertEqual(
                result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32)
            )
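
    # _render_argument_type emits the nested C++ initializer for the epilogue
    # argument struct: loads/stores (aux, F, bias) carry pointers and stride
    # information, while pure compute nodes render as empty initializers.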
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_evt_argument_codegen(self):
        from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch

        cuda_arch = int(get_cuda_arch())  # type: ignore[arg-type]
        epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS, cuda_arch)

        self.assertExpectedInline(
            _render_argument_type(
                epilogue_functor,
                _create_mock_buffer_name_map(EXAMPLE_TENSORS),
                lambda x: int(x),
            )[0],
            """\
{ /* thread */
    { /* F */
        { /* compute_1 */
            { /* compute_0 */
                {}, /* accum */
                {}, /* C */
                {}, /* compute_0 */
            },
            {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */
            {}, /* compute_1 */
        },
        {/* ptr_aux */ (float*) (ptr_1 + ptr_1_offset), /* dAux */ {2048, _1{}, _0{}}}, /* F */
    },
    {/* ptr_col */ (float*) (ptr_2 + ptr_2_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */
    {}, /* compute_2 */
    {}, /* compute_3 */
    {}, /* compute_4 */
}
""",
        )
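
    # Returning the accumulator directly (E = accum) should still emit an
    # aux-store argument for E alongside the bias column broadcast.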
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_evt_argument_codegen_return_accumulator(self):
        from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch

        code = """
def fn(accum, bias):
    E = accum
    D = E + bias
    return D, E
"""
        example_tensors = {
            "accum": CutlassTensor(
                element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
            ),
            "bias": BIAS,
            # "beta": 0.5, TODO: mlazos support scalars
            # "alpha": 0.5, TODO: mlazos support scalars
            "D": CutlassTensor(
                element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
            ),
            "E": CutlassTensor(
                element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
            ),
        }
        cuda_arch = int(get_cuda_arch())  # type: ignore[arg-type]
        epilogue_functor = _trace(code, example_tensors, cuda_arch)

        self.assertExpectedInline(
            _render_argument_type(
                epilogue_functor,
                _create_mock_buffer_name_map(example_tensors),
                lambda x: int(x),
            )[0],
            """\
{ /* thread */
    { /* E */
        {}, /* accum */
        {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* dAux */ {2048, _1{}, _0{}}}, /* E */
    },
    {/* ptr_col */ (float*) (ptr_1 + ptr_1_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */
    {}, /* compute_0 */
}
""",
        )
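
    # End-to-end trace of BIAS_CODE: the expected output is the full set of
    # C++ type declarations for the epilogue visitor tree, ending in the
    # Sm90TopologicalVisitor that wires the DAG together.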
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_evt_codegen(self):
        _, _, code, _ = trace(
            BIAS_CODE,
            EXAMPLE_TENSORS,
            DataType.f32,
            DataType.f32,
            MockTileDescription(),
            EpilogueScheduleType.ScheduleAuto,
            _create_mock_buffer_name_map(EXAMPLE_TENSORS),
            lambda x: x,  # static shapes
        )
        self.assertExpectedInline(
            code,
            """\
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
    cute::Shape<_128, _128, _8>, cutlass::epilogue::collective::EpilogueTileAuto,
    float, float,
    cutlass::epilogue::collective::EpilogueScheduleAuto
>;
using ElementC = float;
using StrideC = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
using TensorC = cutlass::epilogue::fusion::Sm90SrcFetch<float>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using AuxDescriptor = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, \
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float>;
using Aux = cutlass::epilogue::fusion::Sm90AuxLoad<
    AuxDescriptor::Stages, typename AuxDescriptor::EpilogueTile, float,
    cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename AuxDescriptor::SmemLayoutAtom, \
typename AuxDescriptor::CopyOpS2R
>;
using Bias = cutlass::epilogue::fusion::Sm90ColBroadcast<
    0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, float,
    cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>
>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<
    Compute0,
    Accum,
    TensorC>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<
    Compute1,
    EVTCompute0,
    Aux>;
using FDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
    EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float
>;
using F = cutlass::epilogue::fusion::Sm90AuxStore<
    FDescriptor::Stages, typename FDescriptor::EpilogueTile, float,
    cutlass::FloatRoundStyle::round_to_nearest, cute::Stride<int64_t, cute::Int<1>, \
cute::Int<0>>, typename FDescriptor::SmemLayoutAtom,
    typename FDescriptor::CopyOpR2S
>;
using EVTF = cutlass::epilogue::fusion::Sm90EVT<
    F,
    EVTCompute1>;
using Compute2 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::epilogue::thread::ReLu, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;
using Compute3 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;
using Compute4 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;
using DagCompute4 = cutlass::epilogue::fusion::Sm90TopologicalVisitor<
    float,
    cute::tuple<
        cute::seq<>,
        cute::seq<>,
        cute::seq<0>,
        cute::seq<2, 1>,
        cute::seq<3, 0>,
    >,
    EVTF,
    Bias,
    Compute2,
    Compute3,
    Compute4
>;
using ElementD = float;
using StrideD = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if HAS_CPU or HAS_CUDA_AND_TRITON:
        run_tests(needs="filelock")