[cutlass backend] Global filter ops before situation-based filter ops (#157866)

The idea of this PR is that sometimes we filter ops based on criteria that do not depend on node-specific information. For example, we always filter out simt ops. So I want to group those checks together into a global filtering function.

This also helps shrink the config space: 20s -> 6s for instantiation level 3332.
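
To illustrate the split, here is a minimal, self-contained Python sketch (not the actual inductor code; FakeGemmOp and the alignment check below are made-up stand-ins): node-independent checks run once over the candidate list, while node-specific checks stay in the per-node filter.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeGemmOp:
    # Illustrative stand-in for cutlass_library's GemmOperation.
    opcode_class: str  # e.g. "Simt" or "TensorOp"
    a_layout: str      # e.g. "RowMajor"
    b_layout: str      # e.g. "ColumnMajor"
    alignment: int


def global_filter_ops(ops: list[FakeGemmOp]) -> list[FakeGemmOp]:
    # Node-independent checks: they never look at the input/output nodes,
    # so they run once per op list instead of once per GEMM node.
    return [
        op
        for op in ops
        if op.opcode_class != "Simt"  # always skip simt kernels
        and (op.a_layout, op.b_layout) == ("RowMajor", "ColumnMajor")
    ]


def filter_op(op: FakeGemmOp, required_alignment: int) -> Optional[FakeGemmOp]:
    # Node-specific checks (dtypes, alignment, layouts of the actual inputs)
    # still run per candidate op for each GEMM node.
    if op.alignment % required_alignment != 0:
        return None
    return op


if __name__ == "__main__":
    ops = [
        FakeGemmOp("Simt", "RowMajor", "ColumnMajor", 8),
        FakeGemmOp("TensorOp", "RowMajor", "ColumnMajor", 8),
        FakeGemmOp("TensorOp", "ColumnMajor", "ColumnMajor", 4),
    ]
    candidates = global_filter_ops(ops)  # shared, node-independent pass
    kept = [op for op in candidates if filter_op(op, 8) is not None]
    print(len(kept))  # 1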

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157866
Approved by: https://github.com/ColinPeppler
henrylhtsang 2025-07-11 10:08:36 -07:00 committed by PyTorch MergeBot
parent 2a8795a981
commit ff7dd1776f
2 changed files with 87 additions and 22 deletions


@@ -15,6 +15,7 @@ import sympy
import torch
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch._inductor.utils import clear_on_fresh_cache
from torch.utils._ordered_set import OrderedSet
from ... import config
from ...ir import Layout
@@ -27,6 +28,10 @@ from .cuda_env import get_cuda_arch, get_cuda_version
log = logging.getLogger(__name__)
CUTLASS_OPERATION_KIND: str = "gemm"
ACCUMULATOR_DTYPES: OrderedSet[torch.dtype] = OrderedSet([torch.float, torch.int32])
XW_DTYPES: OrderedSet[torch.dtype] = OrderedSet(
[torch.half, torch.bfloat16, torch.float8_e4m3fn, torch.int8]
)
@atexit.register
@@ -348,6 +353,10 @@ def get_accumulator_dtype(
Given a pair of input torch dtypes, returns the inferred accumulator torch dtype.
"""
assert OrderedSet(input_torch_dtypes) <= XW_DTYPES, (
f"{input_torch_dtypes=} is not supported"
)
if len(input_torch_dtypes) != 2:
return None
@@ -368,10 +377,16 @@ def get_accumulator_dtype(
torch_dtype = dtype0
if torch_dtype in (torch.float16, torch.bfloat16, torch.float, torch.float8_e4m3fn):
return torch.float
if torch_dtype == torch.int8:
return torch.int32
raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}")
accumulator_dtype = torch.float
elif torch_dtype == torch.int8:
accumulator_dtype = torch.int32
else:
raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}")
assert accumulator_dtype in ACCUMULATOR_DTYPES, (
f"{accumulator_dtype=} is not supported"
)
return accumulator_dtype
@functools.lru_cache(32)


@@ -35,7 +35,12 @@ from .cuda_kernel import CUDATemplateKernel
from .cuda_template import CUTLASSTemplate
from .cutlass_presets import gen_cutlass_presets
from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt
from .cutlass_utils import torch_dtype_to_cutlass_type
from .cutlass_utils import (
ACCUMULATOR_DTYPES,
dtype_match,
torch_dtype_to_cutlass_type,
XW_DTYPES,
)
GemmOperation = Any
@@ -659,6 +664,15 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
"""Helper Method: Determines whether a given torch layout matches a given Cutlass layout"""
return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout
@staticmethod
def set_layout(tensor_desc: "TensorDescription", torch_layout: ir.Layout) -> None: # type: ignore[name-defined] # noqa: F821
"""
Helper method: Sets the layout of a given tensor description to match the given torch layout
"""
if CUTLASSGemmTemplate.layout_match(torch_layout, tensor_desc.layout):
return
tensor_desc.layout = CUTLASSGemmTemplate.cutlass_layout(torch_layout)
@staticmethod
def set_alignment(torch_layout, op_element) -> bool:
"""
@@ -800,6 +814,53 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
return True
@classmethod
def global_filter_ops(
cls,
ops: list["cutlass_library.gemm_op.GemmOperation"], # type: ignore[name-defined] # noqa: F821
) -> list["cutlass_library.gemm_op.GemmOperation"]: # type: ignore[name-defined] # noqa: F821
"""
Filter ops without using information about the torch op, input nodes and output node.
"""
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib # type: ignore[import]
# Skip simt kernels
ops = [
op
for op in ops
if op.tile_description.math_instruction.opcode_class
!= cutlass_lib.OpcodeClass.Simt
]
# only keep the set of RowMajor x ColumnMajor ops
# for other layouts, we modify the op in place in filter_op, after a deepcopy
ops = [
op
for op in ops
if op.A.layout.name == "RowMajor" and op.B.layout.name == "ColumnMajor"
]
# filter by supported accumulator types
ops = [
op
for op in ops
if any(
dtype_match(torch_dtype, op.accumulator_type())
for torch_dtype in ACCUMULATOR_DTYPES
)
]
# check if dtypes of A and B are supported
ops = [
op
for op in ops
if any(dtype_match(torch_dtype, op.A.element) for torch_dtype in XW_DTYPES)
and any(dtype_match(torch_dtype, op.B.element) for torch_dtype in XW_DTYPES)
]
return ops
def filter_op(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
@@ -817,16 +878,6 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
have been mutated.
"""
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib # type: ignore[import]
# Skip simt kernels
if (
op.tile_description.math_instruction.opcode_class
== cutlass_lib.OpcodeClass.Simt
):
return None
if op.gemm_kind not in self._get_supported_ops():
return None
@@ -841,13 +892,6 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
if not self._dtype_match(op):
return None
# Filter ops by input layouts.
if not (
self.layout_match(X.get_layout(), op.A.layout)
and self.layout_match(W.get_layout(), op.B.layout)
):
return None
# Filter ops by alignment.
if not self._alignment_match(op):
log.debug(
@@ -858,6 +902,10 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
# Update op.
op = copy.deepcopy(op)
# set layouts for X and W
self.set_layout(op.A, X.get_layout())
self.set_layout(op.B, W.get_layout())
# Set output layout.
op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout())
@@ -954,6 +1002,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
log.debug("Using cached ops from cache")
ops = maybe_ops
ops = self.global_filter_ops(ops)
res: dict[str, cutlass_gemm_op.GemmOperation] = {}
start_time = time.time()
for op in ops: