# Owner(s): ["module: inductor"]

import unittest

import sympy

import torch
from torch._dynamo.test_case import TestCase
from torch._inductor.codegen.cuda.cutlass_utils import (
    torch_dtype_to_cutlass_type,
    try_import_cutlass,
)
from torch._inductor.ir import ComputedBuffer, FixedLayout, PermuteView, Pointwise
from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import OrderedSet
from torch.testing._internal.common_cuda import SM90OrLater
from torch.testing._internal.inductor_utils import (
    HAS_CPU,
    HAS_CUDA_AND_TRITON,
    MockGraphHandler,
)


if try_import_cutlass():
    import cutlass_library as cutlass_lib
    from cutlass_library import EpilogueScheduleType

    LayoutType = cutlass_lib.LayoutType
    DataType = cutlass_lib.DataType
    from cutlass_cppgen.backend.evt.ir.tensor import Tensor as CutlassTensor

    from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
        _render_argument_type,
        _trace,
        trace,
    )
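
    # BIAS_CODE and EXAMPLE_TENSORS drive the EVT tracing tests below: the
    # example epilogue reads the accumulator, C, an aux input, and a column
    # bias, and returns both D and the intermediate F.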
    BIAS_CODE = """def example_epilogue(accum, C, aux, bias):
    F = accum + C + aux
    E = relu(F) + bias
    D = E + F
    return D, F"""

    TYPE_C = DataType.f32
    M = 4224
    N = 2048
    BIAS = CutlassTensor(shape=(M, 1), element=TYPE_C, layout_tag=LayoutType.RowMajor)

    EXAMPLE_TENSORS = {
        "accum": CutlassTensor(
            element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
        ),
        "bias": BIAS,
        # "beta": 0.5, TODO: mlazos support scalars
        # "alpha": 0.5, TODO: mlazos support scalars
        "D": CutlassTensor(
            element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
        ),
        "C": CutlassTensor(
            element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
        ),
        "F": CutlassTensor(
            element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
        ),
        "aux": CutlassTensor(
            element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
        ),
    }
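
    # Minimal stand-ins for inductor IR and scheduler objects, so the EVT
    # codegen paths can be exercised without building a full inductor graph.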
    class MockTileDescription:
        threadblock_shape = (128, 128, 8)

    def _create_mock_buffer_name_map(example_tensors):
        name_to_buffer = {}
        for name, tensor in example_tensors.items():
            if isinstance(tensor, CutlassTensor):
                name_to_buffer[name] = MockComputedBuffer(
                    name, None, torch.float32, tensor.shape, tensor.stride
                )

        return name_to_buffer


class MockSchedulerNode(BaseSchedulerNode):
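    """BaseSchedulerNode stand-in whose __init__ skips the scheduler's full
    setup and only records the wrapped IR node and its last_usage set."""
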
    def __init__(self, node, last_usage=None):
        self.node = node
        self.last_usage = last_usage or OrderedSet()


class MockComputedBuffer(ComputedBuffer):
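    """ComputedBuffer stand-in built directly from an inner_fn, dtype, size,
    and optional strides, bypassing the normal ComputedBuffer constructor."""
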
    def __init__(self, name, inner_fn, dtype, size, strides=None):
        self.name = name
        ranges = [sympy.Integer(x) for x in size]
        self.data = Pointwise(
            device=None, dtype=dtype, inner_fn=inner_fn, ranges=ranges
        )
        self.layout = FixedLayout(None, dtype, ranges, strides)

    def get_name(self):
        return self.name

    def num_reads(self):
        # Needed so ComputedBuffer does not inline this buffer
        return 1


class TestCutlassEVT(TestCase):
    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_py_codegen_accumulator_return(self):
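        """buf0 is the accumulator and is read again by buf4, so the generated
        fn keeps a tmp_0 alias and returns it alongside tmp_2 and D."""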
        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
        from torch._inductor.virtualized import V

        size = (100, 300, 200)
        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

        # buf0 is acc
        # buf1 is external
        def inner_fn_buf3(index):
            tmp0 = buf0.make_loader()(index)
            tmp1 = buf1.make_loader()(index)
            tmp2 = buf2.make_loader()(index)
            return tmp0 * tmp1 + tmp2

        def inner_fn_buf4(index):
            tmp0 = buf0.make_loader()(index)
            tmp3 = buf3.make_loader()(index)
            return tmp0 + tmp3

        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
        buf4 = MockComputedBuffer("buf4", inner_fn_buf4, torch.float32, size)
        with V.set_graph_handler(
            MockGraphHandler(
                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
            )
        ):
            reads, writes, renames, code = CutlassEVTCodegen.ir_to_evt_python_code(
                "buf0",
                [
                    MockSchedulerNode(buf3),
                    MockSchedulerNode(buf4, last_usage=OrderedSet(["buf3"])),
                ],
                OrderedSet([]),
            )
        self.assertExpectedInline(reads, """['buf1', 'buf2']""")
        self.assertExpectedInline(writes, """['buf0', 'buf3', 'buf4']""")
        self.assertExpectedInline(
            renames,
            """{'accum': 'buf0', 'tmp_0': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'tmp_2': 'buf3', 'D': 'buf4'}""",
        )
        self.assertExpectedInline(
            code,
            """\
def fn(accum, buf1, buf2):
    tmp_0 = accum
    tmp_1 = tmp_0 * buf1
    tmp_2 = tmp_1 + buf2
    D = tmp_0 + tmp_2

    return tmp_0, tmp_2, D""",
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_py_codegen_disjoint_read_indexing(self):
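        """Reading buf0 through a permuted view gives index strides that
        disagree with the buffer's layout stride, so tracing raises
        NotImplementedError."""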
        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
        from torch._inductor.virtualized import V

        size = (100, 300, 200)
        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
        permuted_buf_0 = PermuteView.create(buf0, [1, 0, 2])
        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

        # buf0 is acc
        # buf1 is external
        def inner_fn_buf3(index):
            tmp0 = permuted_buf_0.make_loader()(index)
            tmp1 = buf1.make_loader()(index)
            tmp2 = buf2.make_loader()(index)
            return tmp0 * tmp1 + tmp2

        def inner_fn_buf4(index):
            tmp0 = buf0.make_loader()(index)
            tmp3 = buf3.make_loader()(index)
            return tmp0 + tmp3

        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
        buf4 = MockComputedBuffer("buf4", inner_fn_buf4, torch.float32, size)

        with V.set_graph_handler(
            MockGraphHandler(
                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
            )
        ):
            result = None
            try:
                CutlassEVTCodegen.ir_to_evt_python_code(
                    "buf0",
                    [MockSchedulerNode(buf3), MockSchedulerNode(buf4)],
                    OrderedSet([]),
                )
            except NotImplementedError as e:
                result = e

            self.assertExpectedInline(
                str(result),
                """Unsupported indexing for buf0 with index 200*i0 + 60000*i1 + i2, \
index strides [200, 60000, 1], and layout stride [60000, 200, 1]""",
            )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_py_codegen_broadcasting(self):
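        """buf4 stores to the broadcast shape (100, 300, 1); the fusion still
        succeeds and the squared term tmp_3 = tmp_2 * tmp_2 appears in the
        generated code."""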
        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
        from torch._inductor.virtualized import V

        size = (100, 300, 200)
        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

        # buf0 is acc
        # buf1 is external
        def inner_fn_buf3(index):
            tmp0 = buf0.make_loader()(index)
            tmp1 = buf1.make_loader()(index)
            tmp2 = buf2.make_loader()(index)
            return tmp0 * tmp1 + tmp2

        def inner_fn_buf4(index):
            tmp0 = buf0.make_loader()(index)
            tmp3 = buf3.make_loader()(index)
            return tmp0 + tmp3 * tmp3

        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
        buf4 = MockComputedBuffer(
            "buf4", inner_fn_buf4, torch.float32, (100, 300, 1)
        )  # broadcast
        with V.set_graph_handler(
            MockGraphHandler(
                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
            )
        ):
            reads, writes, renames, code = CutlassEVTCodegen.ir_to_evt_python_code(
                "buf0",
                [
                    MockSchedulerNode(buf3),
                    MockSchedulerNode(buf4, last_usage=OrderedSet(["buf0"])),
                ],
                OrderedSet([]),
            )
        self.assertExpectedInline(reads, """['buf1', 'buf2']""")
        self.assertExpectedInline(writes, """['buf0', 'buf3', 'buf4']""")
        self.assertExpectedInline(
            renames,
            """{'accum': 'buf0', 'tmp_0': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'tmp_2': 'buf3', 'D': 'buf4'}""",
        )
        self.assertExpectedInline(
            code,
            """\
def fn(accum, buf1, buf2):
    tmp_0 = accum
    tmp_1 = tmp_0 * buf1
    tmp_2 = tmp_1 + buf2
    tmp_3 = tmp_2 * tmp_2
    D = tmp_0 + tmp_3

    return tmp_0, tmp_2, D""",
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_py_codegen(self):
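        """buf0 is passed in the trailing OrderedSet argument here; the
        accumulator is then not written back, so writes omits buf0 and fn
        returns only tmp_1 and D."""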
        from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen
        from torch._inductor.virtualized import V

        size = (100, 300, 200)
        buf0 = MockComputedBuffer("buf0", None, torch.float32, size)
        buf1 = MockComputedBuffer("buf1", None, torch.float32, size)
        buf2 = MockComputedBuffer("buf2", None, torch.float32, size)

        # buf0 is acc
        # buf1 is external
        def inner_fn_buf3(index):
            tmp0 = buf0.make_loader()(index)
            tmp1 = buf1.make_loader()(index)
            tmp2 = buf2.make_loader()(index)
            return tmp0 * tmp1 + tmp2

        def inner_fn_buf4(index):
            tmp0 = buf0.make_loader()(index)
            tmp3 = buf3.make_loader()(index)
            return tmp0 + tmp3

        buf3 = MockComputedBuffer("buf3", inner_fn_buf3, torch.float32, size)
        buf4 = MockComputedBuffer("buf4", inner_fn_buf4, torch.float32, size)
        with V.set_graph_handler(
            MockGraphHandler(
                {"buf0": buf0, "buf1": buf1, "buf2": buf2, "buf3": buf3, "buf4": buf4}
            )
        ):
            reads, writes, renames, code = CutlassEVTCodegen.ir_to_evt_python_code(
                "buf0",
                [
                    MockSchedulerNode(buf3),
                    MockSchedulerNode(buf4),
                ],
                OrderedSet(["buf0"]),
            )
        self.assertExpectedInline(reads, """['buf1', 'buf2']""")
        self.assertExpectedInline(writes, """['buf3', 'buf4']""")
        self.assertExpectedInline(
            renames,
            """{'accum': 'buf0', 'buf1': 'buf1', 'buf2': 'buf2', 'tmp_1': 'buf3', 'D': 'buf4'}""",
        )
        self.assertExpectedInline(
            code,
            """\
def fn(accum, buf1, buf2):
    tmp_0 = accum * buf1
    tmp_1 = tmp_0 + buf2
    D = accum + tmp_1

    return tmp_1, D""",
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_example_tensor_creation(self):
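        """create_example_tensors carries each buffer's shape, stride, and
        dtype through to the resulting CUTLASS tensors, honoring the
        acc -> buf0 rename."""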
        from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
            create_example_tensors,
        )
        from torch._inductor.virtualized import V

        with V.set_graph_handler(MockGraphHandler({})):
            row_major_buf0 = MockComputedBuffer(
                "buf0", None, torch.float32, (3, 4, 1), (4, 1, 0)
            )
            col_major_buf1 = MockComputedBuffer(
                "buf1", None, torch.float32, (3, 2, 1), (1, 3, 0)
            )
            buffer_renames = {"buf0": "buf0", "buf1": "buf1", "acc": "buf0"}
            name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1}
            result = create_example_tensors(
                buffer_renames, name_to_buffer, lambda x: int(x)
            )
            self.assertEqual(result["acc"].shape, (3, 4, 1))
            self.assertEqual(result["acc"].stride, (4, 1, 0))
            self.assertEqual(
                result["acc"].element, torch_dtype_to_cutlass_type(torch.float32)
            )

            self.assertEqual(result["buf1"].shape, (3, 2, 1))
            self.assertEqual(result["buf1"].stride, (1, 3, 0))
            self.assertEqual(
                result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32)
            )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_evt_argument_codegen(self):
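        """_render_argument_type renders the traced BIAS_CODE epilogue as the
        nested C++ aggregate initializer for the kernel's thread arguments."""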
        from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch

        cuda_arch = int(get_cuda_arch())  # type: ignore[arg-type]
        epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS, cuda_arch)

        self.assertExpectedInline(
            _render_argument_type(
                epilogue_functor,
                _create_mock_buffer_name_map(EXAMPLE_TENSORS),
                lambda x: int(x),
            )[0],
            """\
{ /* thread */
    { /* F */
        { /* compute_1 */
            { /* compute_0 */
                {}, /* accum */
                {}, /* C */
                {}, /* compute_0 */
            },
            {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */
            {}, /* compute_1 */
        },
        {/* ptr_aux */ (float*) (ptr_1 + ptr_1_offset), /* dAux */ {2048, _1{}, _0{}}}, /* F */
    },
    {/* ptr_col */ (float*) (ptr_2 + ptr_2_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */
    {}, /* compute_2 */
    {}, /* compute_3 */
    {}, /* compute_4 */
}
""",
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_evt_argument_codegen_return_accumulator(self):
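        """E is a direct alias of the accumulator; returning it materializes an
        aux-store entry (ptr_aux) for E in the rendered arguments."""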
        from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch

        code = """
def fn(accum, bias):
    E = accum
    D = E + bias
    return D, E
"""
        example_tensors = {
            "accum": CutlassTensor(
                element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
            ),
            "bias": BIAS,
            # "beta": 0.5, TODO: mlazos support scalars
            # "alpha": 0.5, TODO: mlazos support scalars
            "D": CutlassTensor(
                element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
            ),
            "E": CutlassTensor(
                element=DataType.f32, shape=(M, N), layout_tag=LayoutType.RowMajor
            ),
        }

        cuda_arch = int(get_cuda_arch())  # type: ignore[arg-type]
        epilogue_functor = _trace(code, example_tensors, cuda_arch)

        self.assertExpectedInline(
            _render_argument_type(
                epilogue_functor,
                _create_mock_buffer_name_map(example_tensors),
                lambda x: int(x),
            )[0],
            """\
{ /* thread */
    { /* E */
        {}, /* accum */
        {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* dAux */ {2048, _1{}, _0{}}}, /* E */
    },
    {/* ptr_col */ (float*) (ptr_1 + ptr_1_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */
    {}, /* compute_0 */
}
""",
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(not try_import_cutlass(), "requires cutlass")
    def test_evt_codegen(self):
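        """End-to-end: trace BIAS_CODE into the CUTLASS SM90 EVT C++ type
        definitions (aux load, column-broadcast bias, compute nodes, aux
        store, and the final topological visitor)."""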
        _, _, code, _ = trace(
            BIAS_CODE,
            EXAMPLE_TENSORS,
            DataType.f32,
            DataType.f32,
            MockTileDescription(),
            EpilogueScheduleType.ScheduleAuto,
            _create_mock_buffer_name_map(EXAMPLE_TENSORS),
            lambda x: x,  # static shapes
        )
        self.assertExpectedInline(
            code,
            """\

using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
    cute::Shape<_128, _128, _8>, cutlass::epilogue::collective::EpilogueTileAuto,
    float, float,
    cutlass::epilogue::collective::EpilogueScheduleAuto
>;

using ElementC = float;
using StrideC = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
using TensorC = cutlass::epilogue::fusion::Sm90SrcFetch<float>;

using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

using AuxDescriptor = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, \
cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float>;

using Aux = cutlass::epilogue::fusion::Sm90AuxLoad<
    AuxDescriptor::Stages, typename AuxDescriptor::EpilogueTile, float,
    cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, typename AuxDescriptor::SmemLayoutAtom, \
typename AuxDescriptor::CopyOpS2R
>;

using Bias = cutlass::epilogue::fusion::Sm90ColBroadcast<
    0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, float,
    cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>
>;

using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<
    Compute0,
    Accum,
    TensorC>;

using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<
    Compute1,
    EVTCompute0,
    Aux>;

using FDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
    EpilogueDescriptor, cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>, float
>;

using F = cutlass::epilogue::fusion::Sm90AuxStore<
    FDescriptor::Stages, typename FDescriptor::EpilogueTile, float,
    cutlass::FloatRoundStyle::round_to_nearest, cute::Stride<int64_t, cute::Int<1>, \
cute::Int<0>>, typename FDescriptor::SmemLayoutAtom,
    typename FDescriptor::CopyOpR2S
>;

using EVTF = cutlass::epilogue::fusion::Sm90EVT<
    F,
    EVTCompute1>;

using Compute2 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::epilogue::thread::ReLu, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using Compute3 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using Compute4 = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus, float, float,
    cutlass::FloatRoundStyle::round_to_nearest
>;

using DagCompute4 = cutlass::epilogue::fusion::Sm90TopologicalVisitor<
    float,
    cute::tuple<
        cute::seq<>,
        cute::seq<>,
        cute::seq<0>,
        cute::seq<2, 1>,
        cute::seq<3, 0>,
    >,
    EVTF,
    Bias,
    Compute2,
    Compute3,
    Compute4
>;

using ElementD = float;
using StrideD = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;

""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if HAS_CPU or HAS_CUDA_AND_TRITON:
        run_tests(needs="filelock")