[Cutlass] Implement EVT example tensor creation (#150904)

This PR implements a translation layer from inductor IR to "example tensors" the expected arguments of the EVT tracer. These tensors basically store the name, shape, stride, and dtype of the tensor and allow an ast-based python parse to generate the EVT C++.

udpates to example tensor creation

Previously merged:
* https://github.com/pytorch/pytorch/pull/150903
* https://github.com/pytorch/pytorch/pull/150346
* https://github.com/pytorch/pytorch/pull/150345
* https://github.com/pytorch/pytorch/pull/150344

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150904
Approved by: https://github.com/eellison
This commit is contained in:
Michael Lazos 2025-04-24 17:09:09 -07:00 committed by PyTorch MergeBot
parent dda0c952e7
commit a936d596f6
3 changed files with 122 additions and 22 deletions

View File

@ -3,7 +3,10 @@ import unittest
import torch
from torch._dynamo.test_case import TestCase
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
from torch._inductor.codegen.cuda.cutlass_utils import (
torch_dtype_to_cutlass_type,
try_import_cutlass,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
@ -55,33 +58,64 @@ if try_import_cutlass():
class MockTileDescription:
threadblock_shape = (128, 128, 8)
class MockNode:
def __init__(self, name, shape, stride, dtype):
self.name = name
self.dtype = dtype
self.shape = shape
self.stride = stride
def get_layout(self):
class MockLayout:
def __init__(self, shape, stride, dtype):
self.size = shape
self.stride = stride
self.dtype = dtype
return MockLayout(self.shape, self.stride, self.dtype)
def get_name(self):
return self.name
def _create_mock_buffer_name_map(example_tensors):
class MockNode:
def __init__(self, name, stride, dtype):
self.name = name
self.dtype = dtype
self.stride = stride
def get_layout(self):
class MockLayout:
def __init__(self, stride, dtype):
self.dtype = dtype
self.stride = stride
return MockLayout(self.stride, self.dtype)
def get_name(self):
return self.name
name_to_buffer = {}
for name, tensor in example_tensors.items():
if isinstance(tensor, CutlassTensor):
name_to_buffer[name] = MockNode(name, tensor.stride, torch.float32)
name_to_buffer[name] = MockNode(
name, tensor.shape, tensor.stride, torch.float32
)
return name_to_buffer
class TestCutlassEVT(TestCase):
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_example_tensor_creation(self):
from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
create_example_tensors,
)
row_major_buf0 = MockNode("buf0", (3, 4, 1), (4, 1, 0), torch.float32)
col_major_buf1 = MockNode("buf1", (3, 2, 1), (1, 3, 0), torch.float32)
read_names = ["buf0"]
write_names = ["buf1"]
buffer_renames = {"buf0": "acc"}
name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1}
result = create_example_tensors(
read_names, write_names, buffer_renames, name_to_buffer
)
self.assertEqual(result["acc"].shape, (3, 4, 1))
self.assertEqual(result["acc"].stride, (4, 1, 0))
self.assertEqual(
result["acc"].element, torch_dtype_to_cutlass_type(torch.float32)
)
self.assertEqual(result["buf1"].shape, (3, 2, 1))
self.assertEqual(result["buf1"].stride, (1, 3, 0))
self.assertEqual(
result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32)
)
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_evt_argument_codegen(self):
epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS)

View File

@ -1,9 +1,13 @@
from typing import Any, Union
from torch._inductor.ir import ComputedBuffer, InputBuffer
from torch._inductor.ir import (
ComputedBuffer,
InputBuffer,
is_contiguous_strides_for_shape,
)
from torch.utils._ordered_set import OrderedSet
from ..cutlass_utils import try_import_cutlass
from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass
EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace
@ -19,6 +23,7 @@ if try_import_cutlass():
import ast
import ctypes
import textwrap
from typing import Union
from cutlass.backend.c_types import ( # type: ignore[import-untyped, import-not-found]
EmptyByte,
@ -41,13 +46,74 @@ if try_import_cutlass():
from cutlass.backend.evt.ir.tensor import ( # type: ignore[import-untyped, import-not-found]
Tensor as CutlassTensor,
)
from cutlass_library import DataType, EpilogueScheduleType, TileDescription
from cutlass_library import (
DataType,
EpilogueScheduleType,
LayoutType,
TileDescription,
)
import torch
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.utils import IndentedBuffer
_CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated]
TORCH_TO_CUTLASS_DTYPE = {
torch.float32: DataType.f32,
torch.float16: DataType.f16,
torch.bfloat16: DataType.bf16,
}
def create_example_tensors(
read_names: list[str],
write_names: list[str],
buffer_renames: dict[str, str],
name_to_buffer: dict[str, Buffer],
) -> dict[str, CutlassTensor]:
example_tensors = {}
def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor:
shape = buffer.get_layout().size
stride = buffer.get_layout().stride
assert all(isinstance(x, int) for x in buffer.get_layout().stride), (
f"{buffer.get_name()}'s shape {shape} contains symints which aren't supported for cutlass EVT"
)
assert all(isinstance(x, int) for x in buffer.get_layout().stride), (
f"{buffer.get_name()}'s stride {stride} contains symints which aren't supported for cutlass EVT"
)
shape = tuple(int(x) for x in shape)
stride = tuple(int(x) for x in stride)
is_row_major = is_contiguous_strides_for_shape(stride, shape)
is_column_major = is_contiguous_strides_for_shape(stride[::-1], shape[::-1])
if not is_row_major and not is_column_major:
raise RuntimeError(
f"Cannot create example tensor for {buffer.get_name()} with \
non-contiguous layout, recieved stride: {stride} and shape: {shape}"
)
return CutlassTensor(
shape=shape,
layout_tag=LayoutType.RowMajor
if is_row_major
else LayoutType.ColumnMajor,
element=torch_dtype_to_cutlass_type(buffer.get_layout().dtype),
)
for name in read_names + write_names:
key = name
if name in buffer_renames:
key = buffer_renames[
name
] # Need to rewrite some special args (e.g. acc is a required arg name)
example_tensors[key] = cutlass_tensor_from_buffer(name_to_buffer[name])
return example_tensors
def trace(
fn_src: str,
example_tensors: dict[str, CutlassTensor],