mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Cutlass] Implement EVT example tensor creation (#150904)
This PR implements a translation layer from inductor IR to "example tensors", the expected arguments of the EVT tracer. These tensors basically store the name, shape, stride, and dtype of the tensor and allow an AST-based Python parse to generate the EVT C++. Updates to example tensor creation. Previously merged: * https://github.com/pytorch/pytorch/pull/150903 * https://github.com/pytorch/pytorch/pull/150346 * https://github.com/pytorch/pytorch/pull/150345 * https://github.com/pytorch/pytorch/pull/150344 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150904 Approved by: https://github.com/eellison
This commit is contained in:
parent
dda0c952e7
commit
a936d596f6
|
|
@ -3,7 +3,10 @@ import unittest
|
|||
|
||||
import torch
|
||||
from torch._dynamo.test_case import TestCase
|
||||
from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass
|
||||
from torch._inductor.codegen.cuda.cutlass_utils import (
|
||||
torch_dtype_to_cutlass_type,
|
||||
try_import_cutlass,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
|
||||
|
||||
|
||||
|
|
@ -55,33 +58,64 @@ if try_import_cutlass():
|
|||
class MockTileDescription:
    """Test stub that exposes only a fixed ``threadblock_shape`` class attribute."""

    threadblock_shape = (
        128,
        128,
        8,
    )
|
||||
|
||||
class MockNode:
    """Lightweight stand-in for a buffer node, used by the tests in this file.

    It records a name, shape, stride and dtype and exposes them through
    ``get_name()`` and ``get_layout()``.
    """

    def __init__(self, name, shape, stride, dtype):
        self.name, self.dtype = name, dtype
        self.shape, self.stride = shape, stride

    def get_name(self):
        """Return the name supplied at construction time."""
        return self.name

    def get_layout(self):
        """Return a new layout object carrying ``size``, ``stride`` and ``dtype``.

        The node's ``shape`` is exposed on the layout under the attribute
        name ``size``.
        """

        class MockLayout:
            def __init__(self, size, stride, dtype):
                self.size, self.stride, self.dtype = size, stride, dtype

        layout = MockLayout(self.shape, self.stride, self.dtype)
        return layout
|
||||
|
||||
def _create_mock_buffer_name_map(example_tensors):
    """Build a ``name -> MockNode`` mapping from a dict of example tensors.

    Only values that are ``CutlassTensor`` instances get an entry; any other
    value is skipped. Each mock node takes the tensor's own ``shape`` and
    ``stride`` and is always given dtype ``torch.float32``.

    The module-level ``MockNode`` (name, shape, stride, dtype) is used. A
    stale nested three-argument ``MockNode`` must not be redefined here,
    because it would shadow the four-argument class and make the call below
    raise ``TypeError``.
    """
    return {
        name: MockNode(name, tensor.shape, tensor.stride, torch.float32)
        for name, tensor in example_tensors.items()
        if isinstance(tensor, CutlassTensor)
    }
|
||||
|
||||
|
||||
class TestCutlassEVT(TestCase):
|
||||
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
def test_example_tensor_creation(self):
    """Check that ``create_example_tensors`` preserves shape, stride and dtype.

    ``buf0`` is read and renamed to ``acc``; ``buf1`` is written and keeps
    its own name. Both results must report the shape and stride of the
    mock buffer they came from and the cutlass type for ``torch.float32``.
    """
    from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
        create_example_tensors,
    )

    # buf0 uses row-major strides, buf1 column-major strides.
    buffers = {
        "buf0": MockNode("buf0", (3, 4, 1), (4, 1, 0), torch.float32),
        "buf1": MockNode("buf1", (3, 2, 1), (1, 3, 0), torch.float32),
    }
    tensors = create_example_tensors(["buf0"], ["buf1"], {"buf0": "acc"}, buffers)

    expectations = (
        ("acc", (3, 4, 1), (4, 1, 0)),
        ("buf1", (3, 2, 1), (1, 3, 0)),
    )
    for key, want_shape, want_stride in expectations:
        tensor = tensors[key]
        self.assertEqual(tensor.shape, want_shape)
        self.assertEqual(tensor.stride, want_stride)
        self.assertEqual(tensor.element, torch_dtype_to_cutlass_type(torch.float32))
|
||||
|
||||
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
||||
def test_evt_argument_codegen(self):
|
||||
epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
from typing import Any, Union
|
||||
|
||||
from torch._inductor.ir import ComputedBuffer, InputBuffer
|
||||
from torch._inductor.ir import (
|
||||
ComputedBuffer,
|
||||
InputBuffer,
|
||||
is_contiguous_strides_for_shape,
|
||||
)
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from ..cutlass_utils import try_import_cutlass
|
||||
from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass
|
||||
|
||||
|
||||
EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace
|
||||
|
|
@ -19,6 +23,7 @@ if try_import_cutlass():
|
|||
import ast
|
||||
import ctypes
|
||||
import textwrap
|
||||
from typing import Union
|
||||
|
||||
from cutlass.backend.c_types import ( # type: ignore[import-untyped, import-not-found]
|
||||
EmptyByte,
|
||||
|
|
@ -41,13 +46,74 @@ if try_import_cutlass():
|
|||
from cutlass.backend.evt.ir.tensor import ( # type: ignore[import-untyped, import-not-found]
|
||||
Tensor as CutlassTensor,
|
||||
)
|
||||
from cutlass_library import DataType, EpilogueScheduleType, TileDescription
|
||||
from cutlass_library import (
|
||||
DataType,
|
||||
EpilogueScheduleType,
|
||||
LayoutType,
|
||||
TileDescription,
|
||||
)
|
||||
|
||||
import torch
|
||||
from torch._inductor.codegen.cuda import cuda_env
|
||||
from torch._inductor.utils import IndentedBuffer
|
||||
|
||||
_CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated]
|
||||
|
||||
TORCH_TO_CUTLASS_DTYPE = {
|
||||
torch.float32: DataType.f32,
|
||||
torch.float16: DataType.f16,
|
||||
torch.bfloat16: DataType.bf16,
|
||||
}
|
||||
|
||||
def create_example_tensors(
    read_names: list[str],
    write_names: list[str],
    buffer_renames: dict[str, str],
    name_to_buffer: dict[str, Buffer],
) -> dict[str, CutlassTensor]:
    """Create one ``CutlassTensor`` per buffer that is read or written.

    Args:
        read_names: Names of the buffers that are read.
        write_names: Names of the buffers that are written.
        buffer_renames: Optional per-buffer override of the key used in the
            result (e.g. ``acc`` is a required argument name).
        name_to_buffer: Lookup from buffer name to the buffer object. Each
            buffer must provide ``get_layout()`` (with ``size``, ``stride``
            and ``dtype``) and ``get_name()``.

    Returns:
        A dict keyed by the (possibly renamed) buffer name. Each value is a
        ``CutlassTensor`` built from the buffer's shape, layout tag and
        element type.

    Raises:
        AssertionError: If a buffer's shape or stride has a non-``int`` entry.
        RuntimeError: If a buffer's strides are neither row-major nor
            column-major contiguous for its shape.
        KeyError: If a requested name is missing from ``name_to_buffer``.
    """
    example_tensors = {}

    def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor:
        layout = buffer.get_layout()
        shape = layout.size
        stride = layout.stride
        # Validate the shape itself: this check previously iterated over the
        # stride, so a symbolic shape was never caught by its own assertion.
        assert all(isinstance(x, int) for x in shape), (
            f"{buffer.get_name()}'s shape {shape} contains symints which aren't supported for cutlass EVT"
        )
        assert all(isinstance(x, int) for x in stride), (
            f"{buffer.get_name()}'s stride {stride} contains symints which aren't supported for cutlass EVT"
        )
        shape = tuple(int(x) for x in shape)
        stride = tuple(int(x) for x in stride)

        # Column-major is detected as "row-major once both the strides and
        # the shape are reversed".
        is_row_major = is_contiguous_strides_for_shape(stride, shape)
        is_column_major = is_contiguous_strides_for_shape(stride[::-1], shape[::-1])

        if not is_row_major and not is_column_major:
            raise RuntimeError(
                f"Cannot create example tensor for {buffer.get_name()} with "
                f"non-contiguous layout, recieved stride: {stride} and shape: {shape}"
            )

        # When both layouts match, row-major is chosen.
        return CutlassTensor(
            shape=shape,
            layout_tag=LayoutType.RowMajor
            if is_row_major
            else LayoutType.ColumnMajor,
            element=torch_dtype_to_cutlass_type(layout.dtype),
        )

    for name in read_names + write_names:
        # Need to rewrite some special args (e.g. acc is a required arg name)
        key = buffer_renames.get(name, name)
        example_tensors[key] = cutlass_tensor_from_buffer(name_to_buffer[name])

    return example_tensors
|
||||
|
||||
def trace(
|
||||
fn_src: str,
|
||||
example_tensors: dict[str, CutlassTensor],
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user