mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[cutlass backend] Global filter ops before situation based filter ops (#157866)
The idea of this PR is that some op filtering does not depend on node-specific information; for example, we always filter out simt ops. So this PR groups those checks into a global filtering function that runs once over the whole op list. This also helps shrink the config space: 20s -> 6s for instantiation 3332.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157866
Approved by: https://github.com/ColinPeppler
commit ff7dd1776f (parent 2a8795a981)
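Before the diff, a minimal sketch of the idea (not the actual inductor code; Op, global_filter, matches_node, and select_ops below are hypothetical stand-ins): node-independent checks run once over the full candidate list, and only the survivors go through the per-node filtering pass.

# Sketch: run node-independent filters once, then per-node filters on the survivors.
# `Op`, `global_filter`, `matches_node`, and `select_ops` are illustrative stand-ins,
# not inductor APIs.
from dataclasses import dataclass


@dataclass
class Op:
    name: str
    opcode_class: str          # e.g. "Simt" or "TensorOp"
    layouts: tuple[str, str]   # (A layout, B layout)


def global_filter(ops: list[Op]) -> list[Op]:
    # Checks that never look at the GEMM node being lowered: apply them once.
    ops = [op for op in ops if op.opcode_class != "Simt"]
    ops = [op for op in ops if op.layouts == ("RowMajor", "ColumnMajor")]
    return ops


def matches_node(op: Op, node_dtype: str) -> bool:
    # Placeholder for node-specific checks (dtypes, alignment, shapes, ...).
    return True


def select_ops(ops: list[Op], node_dtype: str) -> list[Op]:
    ops = global_filter(ops)  # done once, shrinks the candidate set up front
    return [op for op in ops if matches_node(op, node_dtype)]

Because global_filter runs once rather than per candidate per node, pruning the list there directly reduces the work done in the per-node pass.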
@@ -15,6 +15,7 @@ import sympy
 import torch
 from torch._inductor.runtime.runtime_utils import dynamo_timed
 from torch._inductor.utils import clear_on_fresh_cache
+from torch.utils._ordered_set import OrderedSet
 
 from ... import config
 from ...ir import Layout
@@ -27,6 +28,10 @@ from .cuda_env import get_cuda_arch, get_cuda_version
 log = logging.getLogger(__name__)
 
 CUTLASS_OPERATION_KIND: str = "gemm"
+ACCUMULATOR_DTYPES: OrderedSet[torch.dtype] = OrderedSet([torch.float, torch.int32])
+XW_DTYPES: OrderedSet[torch.dtype] = OrderedSet(
+    [torch.half, torch.bfloat16, torch.float8_e4m3fn, torch.int8]
+)
 
 
 @atexit.register
@@ -348,6 +353,10 @@ def get_accumulator_dtype(
     Given a pair of input torch dtypes, returns the inferred accumulator torch dtype.
     """
 
+    assert OrderedSet(input_torch_dtypes) <= XW_DTYPES, (
+        f"{input_torch_dtypes=} is not supported"
+    )
+
     if len(input_torch_dtypes) != 2:
         return None
 
@@ -368,10 +377,16 @@ def get_accumulator_dtype(
         torch_dtype = dtype0
 
     if torch_dtype in (torch.float16, torch.bfloat16, torch.float, torch.float8_e4m3fn):
-        return torch.float
-    if torch_dtype == torch.int8:
-        return torch.int32
-    raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}")
+        accumulator_dtype = torch.float
+    elif torch_dtype == torch.int8:
+        accumulator_dtype = torch.int32
+    else:
+        raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}")
+
+    assert accumulator_dtype in ACCUMULATOR_DTYPES, (
+        f"{accumulator_dtype=} is not supported"
+    )
+    return accumulator_dtype
 
 
 @functools.lru_cache(32)
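For reference, the accumulator mapping introduced above can be exercised in isolation. The sketch below is a standalone restatement, not the inductor function itself: accumulator_dtype_for is a hypothetical helper, and it assumes both inputs share one dtype (how torch_dtype is chosen from the input pair lies outside the hunks shown here).

import torch

# Standalone sketch of the dtype -> accumulator mapping shown in the hunk above.
# `accumulator_dtype_for` is illustrative; mixed-dtype selection is not covered here.
ACCUMULATOR_DTYPES = {torch.float, torch.int32}


def accumulator_dtype_for(torch_dtype: torch.dtype) -> torch.dtype:
    if torch_dtype in (torch.float16, torch.bfloat16, torch.float, torch.float8_e4m3fn):
        accumulator_dtype = torch.float
    elif torch_dtype == torch.int8:
        accumulator_dtype = torch.int32
    else:
        raise NotImplementedError(f"Unsupported data type: {torch_dtype=}")
    assert accumulator_dtype in ACCUMULATOR_DTYPES
    return accumulator_dtype


assert accumulator_dtype_for(torch.bfloat16) == torch.float  # half/bfloat16/float/fp8 accumulate in fp32
assert accumulator_dtype_for(torch.int8) == torch.int32      # int8 accumulates in int32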
@@ -35,7 +35,12 @@ from .cuda_kernel import CUDATemplateKernel
 from .cuda_template import CUTLASSTemplate
 from .cutlass_presets import gen_cutlass_presets
 from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt
-from .cutlass_utils import torch_dtype_to_cutlass_type
+from .cutlass_utils import (
+    ACCUMULATOR_DTYPES,
+    dtype_match,
+    torch_dtype_to_cutlass_type,
+    XW_DTYPES,
+)
 
 
 GemmOperation = Any
@@ -659,6 +664,15 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
         """Helper Method: Determines whether a given torch layout matches a given Cutlass layout"""
         return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout
 
+    @staticmethod
+    def set_layout(tensor_desc: "TensorDescription", torch_layout: ir.Layout) -> None:  # type: ignore[name-defined]  # noqa: F821
+        """
+        Helper method: Sets the layout of a given tensor description to match the given torch layout
+        """
+        if CUTLASSGemmTemplate.layout_match(torch_layout, tensor_desc.layout):
+            return
+        tensor_desc.layout = CUTLASSGemmTemplate.cutlass_layout(torch_layout)
+
     @staticmethod
     def set_alignment(torch_layout, op_element) -> bool:
         """
@@ -800,6 +814,53 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
 
         return True
 
+    @classmethod
+    def global_filter_ops(
+        cls,
+        ops: list["cutlass_library.gemm_op.GemmOperation"],  # type: ignore[name-defined]  # noqa: F821
+    ) -> list["cutlass_library.gemm_op.GemmOperation"]:  # type: ignore[name-defined]  # noqa: F821
+        """
+        Filter ops without using information about the torch op, input nodes and output node.
+        """
+        assert cutlass_utils.try_import_cutlass()
+        import cutlass_library.library as cutlass_lib  # type: ignore[import]
+
+        # Skip simt kernels
+        ops = [
+            op
+            for op in ops
+            if op.tile_description.math_instruction.opcode_class
+            != cutlass_lib.OpcodeClass.Simt
+        ]
+
+        # only keep the set of row x column ops
+        # for other layout, we modify in place in filter_op, after deepcopy
+        ops = [
+            op
+            for op in ops
+            if op.A.layout.name == "RowMajor" and op.B.layout.name == "ColumnMajor"
+        ]
+
+        # filter by supported accumulator types
+        ops = [
+            op
+            for op in ops
+            if any(
+                dtype_match(torch_dtype, op.accumulator_type())
+                for torch_dtype in ACCUMULATOR_DTYPES
+            )
+        ]
+
+        # check if dtypes of A and B are supported
+        ops = [
+            op
+            for op in ops
+            if any(dtype_match(torch_dtype, op.A.element) for torch_dtype in XW_DTYPES)
+            and any(dtype_match(torch_dtype, op.B.element) for torch_dtype in XW_DTYPES)
+        ]
+
+        return ops
+
     def filter_op(
         self,
         op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
@@ -817,16 +878,6 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
         have been mutated.
         """
 
-        assert cutlass_utils.try_import_cutlass()
-        import cutlass_library.library as cutlass_lib  # type: ignore[import]
-
-        # Skip simt kernels
-        if (
-            op.tile_description.math_instruction.opcode_class
-            == cutlass_lib.OpcodeClass.Simt
-        ):
-            return None
-
         if op.gemm_kind not in self._get_supported_ops():
             return None
 
@@ -841,13 +892,6 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
         if not self._dtype_match(op):
             return None
 
-        # Filter ops by input layouts.
-        if not (
-            self.layout_match(X.get_layout(), op.A.layout)
-            and self.layout_match(W.get_layout(), op.B.layout)
-        ):
-            return None
-
         # Filter ops by alignment.
         if not self._alignment_match(op):
             log.debug(
@@ -858,6 +902,10 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
         # Update op.
        op = copy.deepcopy(op)
 
+        # set layouts for X and W
+        self.set_layout(op.A, X.get_layout())
+        self.set_layout(op.B, W.get_layout())
+
         # Set output layout.
         op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout())
 
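A note on how the hunks above fit together: global_filter_ops keeps only the canonical RowMajor x ColumnMajor candidates, and filter_op then deep-copies a surviving op and rewrites its layouts via set_layout to match the actual node, so no real candidates are lost. A minimal sketch of that pattern (CandidateOp and specialize_for_node are hypothetical stand-ins, not the inductor API):

import copy

# Sketch: keep one canonical (RowMajor x ColumnMajor) candidate globally and
# specialize a deepcopy per node, mirroring the deepcopy + set_layout flow above.
# `CandidateOp` and `specialize_for_node` are illustrative, not the inductor API.
class CandidateOp:
    def __init__(self) -> None:
        self.a_layout = "RowMajor"
        self.b_layout = "ColumnMajor"


def specialize_for_node(op: CandidateOp, node_layouts: tuple[str, str]) -> CandidateOp:
    op = copy.deepcopy(op)  # never mutate the shared, globally filtered candidate
    op.a_layout, op.b_layout = node_layouts
    return op


canonical = CandidateOp()
specialized = specialize_for_node(canonical, ("ColumnMajor", "RowMajor"))
assert (canonical.a_layout, canonical.b_layout) == ("RowMajor", "ColumnMajor")
assert (specialized.a_layout, specialized.b_layout) == ("ColumnMajor", "RowMajor")

The deepcopy keeps the shared, globally filtered candidates immutable; only the per-node copy is specialized.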
@@ -954,6 +1002,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
             log.debug("Using cached ops from cache")
             ops = maybe_ops
 
+        ops = self.global_filter_ops(ops)
+
         res: dict[str, cutlass_gemm_op.GemmOperation] = {}
         start_time = time.time()
         for op in ops: