# Owner(s): ["module: inductor"]
import unittest

import sympy

import torch
from torch._dynamo.test_case import TestCase
from torch._inductor.codegen.cuda.cutlass_utils import (
    torch_dtype_to_cutlass_type,
    try_import_cutlass,
)
from torch._inductor.ir import ComputedBuffer, FixedLayout, PermuteView, Pointwise
from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import OrderedSet
from torch.testing._internal.common_cuda import SM90OrLater
from torch.testing._internal.inductor_utils import (
    HAS_CPU,
    HAS_CUDA_AND_TRITON,
    MockGraphHandler,
)


if try_import_cutlass():
    import cutlass_library as cutlass_lib
    from cutlass_library import EpilogueScheduleType

    LayoutType = cutlass_lib.LayoutType
    DataType = cutlass_lib.DataType
    from cutlass_cppgen.backend.evt.ir.tensor import Tensor as CutlassTensor

    from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
        _render_argument_type,
        _trace,
        trace,
    )

BIAS_CODE = """def example_epilogue(accum, C, aux, bias):
    F = accum + C + aux
    E = relu(F) + bias
    D = E + F
    return D, F"""

TYPE_C = DataType.f32
M = 4224
N = 2048
BIAS = CutlassTensor(shape=(M, 1), element=TYPE_C, layout_tag=LayoutType.RowMajor)

EXAMPLE_TENSORS = {
    "accum": CutlassTensor(
        element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
    ),
    "bias": BIAS,
    # "beta": 0.5, TODO: mlazos support scalars
    # "alpha": 0.5, TODO: mlazos support scalars
    "D": CutlassTensor(
        element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
    ),
    "C": CutlassTensor(
        element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
    ),
    "F": CutlassTensor(
        element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
    ),
    "aux": CutlassTensor(
        element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
    ),
}


class MockTileDescription:
    threadblock_shape = (128, 128, 8)


def _create_mock_buffer_name_map(example_tensors):
    name_to_buffer = {}
    for name, tensor in example_tensors.items():
        if isinstance(tensor, CutlassTensor):
            name_to_buffer[name] = MockComputedBuffer(
                name, None, torch.float32, tensor.shape, tensor.stride
            )
    return name_to_buffer


class MockSchedulerNode(BaseSchedulerNode):
    def __init__(self, node, last_usage=None):
        self.node = node
        self.last_usage = last_usage or OrderedSet()

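# Minimal stand-in for an Inductor ComputedBuffer: bundles a Pointwise body with
# a FixedLayout so EVT codegen can query names, dtypes, sizes, and strides
# without building a real graph.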
class MockComputedBuffer(ComputedBuffer):
    def __init__(self, name, inner_fn, dtype, size, strides=None):
        self.name = name
        ranges = [sympy.Integer(x) for x in size]
        self.data = Pointwise(
            device=None, dtype=dtype, inner_fn=inner_fn, ranges=ranges
        )
        self.layout = FixedLayout(None, dtype, ranges, strides)

    def get_name(self):
        return self.name

    def num_reads(self):
        # Needed to not inline in ComputedBuffer
        return 1


class TestCutlassEVT(TestCase):
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_py_codegen_accumulator_return(self):
        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
        from torch._inductor.virtualized import V

        size = (100, 300, 200)
        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

        # buf0 is acc
        # buf1 is external
        def inner_fn_buf3(index):
            tmp0 = buf0.make_loader()(index)
            tmp1 = buf1.make_loader()(index)
            tmp2 = buf2.make_loader()(index)
            return tmp0 * tmp1 + tmp2

        def inner_fn_buf4(index):
            tmp0 = buf0.make_loader()(index)
            tmp3 = buf3.make_loader()(index)
            return tmp0 + tmp3

        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
        buf4 = MockComputedBuffer("buf4", inner_fn_buf4, torch.float32, size)
        with V.set_graph_handler(
            MockGraphHandler(
                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
            )
        ):
            reads, writes, renames, code = CutlassEVTCodegen.ir_to_evt_python_code(
                "buf0",
                [
                    MockSchedulerNode(buf3),
                    MockSchedulerNode(buf4, last_usage=OrderedSet(["buf3"])),
                ],
                OrderedSet([]),
            )

            self.assertExpectedInline(reads, """['buf1', 'buf2']""")
            self.assertExpectedInline(writes, """['buf0', 'buf3', 'buf4']""")
            self.assertExpectedInline(
                renames,
                """{'accum': 'buf0', 'tmp_0': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'tmp_2': 'buf3', 'D': 'buf4'}""",
            )
            self.assertExpectedInline(
                code,
                """\
def fn(accum, buf1, buf2):
    tmp_0 = accum
    tmp_1 = tmp_0 * buf1
    tmp_2 = tmp_1 + buf2
    D = tmp_0 + tmp_2
    return tmp_0, tmp_2, D""",
            )

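    # Reads must index buffers the same way the accumulator is laid out: loading
    # buf0 through a PermuteView produces index strides that disagree with the
    # layout strides, so the EVT codegen raises NotImplementedError.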
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_py_codegen_disjoint_read_indexing(self):
        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
        from torch._inductor.virtualized import V

        size = (100, 300, 200)
        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
        permuted_buf_0 = PermuteView.create(buf0, [1, 0, 2])
        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

        # buf0 is acc
        # buf1 is external
        def inner_fn_buf3(index):
            tmp0 = permuted_buf_0.make_loader()(index)
            tmp1 = buf1.make_loader()(index)
            tmp2 = buf2.make_loader()(index)
            return tmp0 * tmp1 + tmp2

        def inner_fn_buf4(index):
            tmp0 = buf0.make_loader()(index)
            tmp3 = buf3.make_loader()(index)
            return tmp0 + tmp3

        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
        buf4 = MockComputedBuffer("buf4", inner_fn_buf4, torch.float32, size)
        with V.set_graph_handler(
            MockGraphHandler(
                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
            )
        ):
            result = None
            try:
                CutlassEVTCodegen.ir_to_evt_python_code(
                    "buf0",
                    [MockSchedulerNode(buf3), MockSchedulerNode(buf4)],
                    OrderedSet([]),
                )
            except NotImplementedError as e:
                result = e

            self.assertExpectedInline(
                str(result),
                """Unsupported indexing for buf0 with index 200*i0 + 60000*i1 + i2, \
index strides [200, 60000, 1], and layout stride [60000, 200, 1]""",
            )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_py_codegen_broadcasting(self):
        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
        from torch._inductor.virtualized import V

        size = (100, 300, 200)
        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

        # buf0 is acc
        # buf1 is external
        def inner_fn_buf3(index):
            tmp0 = buf0.make_loader()(index)
            tmp1 = buf1.make_loader()(index)
            tmp2 = buf2.make_loader()(index)
            return tmp0 * tmp1 + tmp2

        def inner_fn_buf4(index):
            tmp0 = buf0.make_loader()(index)
            tmp3 = buf3.make_loader()(index)
            return tmp0 + tmp3 * tmp3

        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
        buf4 = MockComputedBuffer(
            "buf4", inner_fn_buf4, torch.float32, (100, 300, 1)
        )  # broadcast
        with V.set_graph_handler(
            MockGraphHandler(
                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
            )
        ):
            reads, writes, renames, code = CutlassEVTCodegen.ir_to_evt_python_code(
                "buf0",
                [
                    MockSchedulerNode(buf3),
                    MockSchedulerNode(buf4, last_usage=OrderedSet(["buf0"])),
                ],
                OrderedSet([]),
            )

            self.assertExpectedInline(reads, """['buf1', 'buf2']""")
            self.assertExpectedInline(writes, """['buf0', 'buf3', 'buf4']""")
            self.assertExpectedInline(
                renames,
                """{'accum': 'buf0', 'tmp_0': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'tmp_2': 'buf3', 'D': 'buf4'}""",
            )
            self.assertExpectedInline(
                code,
                """\
def fn(accum, buf1, buf2):
    tmp_0 = accum
    tmp_1 = tmp_0 * buf1
    tmp_2 = tmp_1 + buf2
    tmp_3 = tmp_2 * tmp_2
    D = tmp_0 + tmp_3
    return tmp_0, tmp_2, D""",
            )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_py_codegen(self):
        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
        from torch._inductor.virtualized import V

        size = (100, 300, 200)
        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

        # buf0 is acc
        # buf1 is external
        def inner_fn_buf3(index):
            tmp0 = buf0.make_loader()(index)
            tmp1 = buf1.make_loader()(index)
            tmp2 = buf2.make_loader()(index)
            return tmp0 * tmp1 + tmp2

        def inner_fn_buf4(index):
            tmp0 = buf0.make_loader()(index)
            tmp3 = buf3.make_loader()(index)
            return tmp0 + tmp3

        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
        buf4 = MockComputedBuffer("buf4", inner_fn_buf4, torch.float32, size)
        with V.set_graph_handler(
            MockGraphHandler(
                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
            )
        ):
            reads, writes, renames, code = CutlassEVTCodegen.ir_to_evt_python_code(
                "buf0",
                [
                    MockSchedulerNode(buf3),
                    MockSchedulerNode(buf4),
                ],
                OrderedSet(["buf0"]),
            )

            self.assertExpectedInline(reads, """['buf1', 'buf2']""")
            self.assertExpectedInline(writes, """['buf3', 'buf4']""")
            self.assertExpectedInline(
                renames,
                """{'accum': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'tmp_1': 'buf3', 'D': 'buf4'}""",
            )
            self.assertExpectedInline(
                code,
                """\
def fn(accum, buf1, buf2):
    tmp_0 = accum * buf1
    tmp_1 = tmp_0 + buf2
    D = accum + tmp_1
    return tmp_1, D""",
            )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_example_tensor_creation(self):
        from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
            create_example_tensors,
        )
        from torch._inductor.virtualized import V

        with V.set_graph_handler(MockGraphHandler({})):
            row_major_buf0 = MockComputedBuffer(
                "buf0", None, torch.float32, (3, 4, 1), (4, 1, 0)
            )
            col_major_buf1 = MockComputedBuffer(
                "buf1", None, torch.float32, (3, 2, 1), (1, 3, 0)
            )
            buffer_renames = {"buf0": "buf0", "buf1": "buf1", "acc": "buf0"}
            name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1}
            result = create_example_tensors(
                buffer_renames, name_to_buffer, lambda x: int(x)
            )
            self.assertEqual(result["acc"].shape, (3, 4, 1))
            self.assertEqual(result["acc"].stride, (4, 1, 0))
            self.assertEqual(
                result["acc"].element, torch_dtype_to_cutlass_type(torch.float32)
            )

            self.assertEqual(result["buf1"].shape, (3, 2, 1))
            self.assertEqual(result["buf1"].stride, (1, 3, 0))
            self.assertEqual(
                result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32)
            )

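    # The tests below exercise the cutlass-facing half: _trace runs an epilogue
    # definition through the cutlass EVT tracer, and _render_argument_type
    # renders the nested C++ argument initializer checked inline.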
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_evt_argument_codegen(self):
        from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch

        cuda_arch = int(get_cuda_arch())  # type: ignore[arg-type]
        epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS, cuda_arch)

        self.assertExpectedInline(
            _render_argument_type(
                epilogue_functor,
                _create_mock_buffer_name_map(EXAMPLE_TENSORS),
                lambda x: int(x),
            )[0],
            """\
{ /* thread */
    { /* F */
        { /* compute_1 */
            { /* compute_0 */
                {}, /* accum */
                {}, /* C */
                {}, /* compute_0 */
            },
            {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */
            {}, /* compute_1 */
        },
        {/* ptr_aux */ (float*) (ptr_1 + ptr_1_offset), /* dAux */ {2048, _1{}, _0{}}}, /* F */
    },
    {/* ptr_col */ (float*) (ptr_2 + ptr_2_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */
    {}, /* compute_2 */
    {}, /* compute_3 */
    {}, /* compute_4 */
}
""",
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_evt_argument_codegen_return_accumulator(self):
        from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch

        code = """
def fn(accum, bias):
    E = accum
    D = E + bias
    return D, E
"""
        example_tensors = {
            "accum": CutlassTensor(
                element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
            ),
            "bias": BIAS,
            # "beta": 0.5, TODO: mlazos support scalars
            # "alpha": 0.5, TODO: mlazos support scalars
            "D": CutlassTensor(
                element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
            ),
            "E": CutlassTensor(
                element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
            ),
        }
        cuda_arch = int(get_cuda_arch())  # type: ignore[arg-type]
        epilogue_functor = _trace(code, example_tensors, cuda_arch)

        self.assertExpectedInline(
            _render_argument_type(
                epilogue_functor,
                _create_mock_buffer_name_map(example_tensors),
                lambda x: int(x),
            )[0],
            """\
{ /* thread */
    { /* E */
        {}, /* accum */
        {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* dAux */ {2048, _1{}, _0{}}}, /* E */
    },
    {/* ptr_col */ (float*) (ptr_1 + ptr_1_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */
    {}, /* compute_0 */
}
""",
        )

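    # End-to-end check that trace() emits the full set of C++ type declarations
    # for BIAS_CODE: aux loads/stores, compute nodes, and the topological
    # visitor DAG that wires them together.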
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_evt_codegen(self):
        _, _, code, _ = trace(
            BIAS_CODE,
            EXAMPLE_TENSORS,
            DataType.f32,
            DataType.f32,
            MockTileDescription(),
            EpilogueScheduleType.ScheduleAuto,
            _create_mock_buffer_name_map(EXAMPLE_TENSORS),
            lambda x: x,  # static shapes
        )
        self.assertExpectedInline(
            code,
            """\
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
    cute::Shape<_128, _128, _8>, cutlass::epilogue::collective::EpilogueTileAuto,
    float, float,
    cutlass::epilogue::collective::EpilogueScheduleAuto
>;

using ElementC = float;
using StrideC = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
using TensorC = cutlass::epilogue::fusion::Sm90SrcFetch<float>;

using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

using AuxDescriptor = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float>;

using Aux = cutlass::epilogue::fusion::Sm90AuxLoad<
    AuxDescriptor::Stages, typename AuxDescriptor::EpilogueTile, float,
    cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename AuxDescriptor::SmemLayoutAtom, \
typename AuxDescriptor::CopyOpS2R
>;

using Bias = cutlass::epilogue::fusion::Sm90ColBroadcast<
    0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, float,
    cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>
>;

using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<
    Compute0,
    Accum,
    TensorC>;

using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<
    Compute1,
    EVTCompute0,
    Aux>;

using FDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
    EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float
>;

using F = cutlass::epilogue::fusion::Sm90AuxStore<
    FDescriptor::Stages, typename FDescriptor::EpilogueTile, float,
    cutlass::FloatRoundStyle::round_to_nearest, cute::Stride<int64_t, cute::Int<1>, \
cute::Int<0>>, typename FDescriptor::SmemLayoutAtom, typename FDescriptor::CopyOpR2S
>;

using EVTF = cutlass::epilogue::fusion::Sm90EVT<
    F,
    EVTCompute1>;

using Compute2 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::epilogue::thread::ReLu, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using Compute3 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using Compute4 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using DagCompute4 = cutlass::epilogue::fusion::Sm90TopologicalVisitor<
    float,
    cute::tuple<
        cute::seq<>,
        cute::seq<>,
        cute::seq<0>,
        cute::seq<2, 1>,
        cute::seq<3, 0>,
    >,
    EVTF,
    Bias,
    Compute2,
    Compute3,
    Compute4
>;

using ElementD = float;
using StrideD = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if HAS_CPU or HAS_CUDA_AND_TRITON:
        run_tests(needs="filelock")