[cutlass backend] Forward fix for less aligned gemm shapes (#148521)

Differential Revision: [D70600093](https://our.internmc.facebook.com/intern/diff/D70600093/)

1. Check whether config-name filtering still works.
Tested; it still works.

2. Do we get a C++ compile error?
Yes; potentially we need to filter those configs out manually.

Here we get this.
```
static_assert(threads_minor == 0 || (TileSizeK % threads_minor == 0));
```
We need to move some assertions to gemm_template.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148521
Approved by: https://github.com/ColinPeppler
This commit is contained in:
henrylhtsang 2025-03-07 15:54:16 -08:00 committed by PyTorch MergeBot
parent aac230a511
commit b47d81682d
2 changed files with 107 additions and 44 deletions

View File

@ -898,7 +898,9 @@ class TestCutlassBackend(TestCase):
"torch._inductor.kernel.mm.autotune_select_algorithm",
wraps=select_no_algorithm,
) as sa:
with self.assertRaises(InductorError):
with self.assertRaisesRegex(
InductorError, r".*NoValidChoicesError.*"
):
torch.compile(my_addmm, dynamic=False)(x, a, b, 1.0, 2.0)
args, _ = sa.call_args
op_name, choices, _, __ = args
@ -944,7 +946,9 @@ class TestCutlassBackend(TestCase):
"torch._inductor.kernel.mm.autotune_select_algorithm",
wraps=select_no_algorithm,
) as sa:
with self.assertRaises(InductorError):
with self.assertRaisesRegex(
InductorError, r".*NoValidChoicesError.*"
):
torch.compile(addmm, dynamic=False)(x, a, b, 1.0, 1.0)
args, _ = sa.call_args
op_name, choices, _, __ = args
@ -961,6 +965,80 @@ class TestCutlassBackend(TestCase):
cuda_template_count += 1
assert cuda_template_count > 0, "No CUDATemplateCaller choices"
@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_cutlass_backend_shape_coverage_mm(self):
    """
    Check that the cutlass backend produces at least one candidate op for a
    variety of (M, N, K) matmul shapes, including less-aligned ones.

    This test only inspects the autotuner's collected choices; it does not
    compile or check numerical correctness of the generated ops.
    NOTE: K has to be even.
    """
    # Pairs of fp16 CUDA operands (A, B) for torch.mm with varying alignment.
    inputs = [
        (torch.randn(128, 500).cuda().half(), torch.randn(500, 576).cuda().half()),
        (torch.randn(500, 128).cuda().half(), torch.randn(128, 576).cuda().half()),
        (torch.randn(128, 250).cuda().half(), torch.randn(250, 576).cuda().half()),
        (torch.randn(250, 128).cuda().half(), torch.randn(128, 576).cuda().half()),
        (torch.randn(125, 128).cuda().half(), torch.randn(128, 576).cuda().half()),
    ]

    def select_no_algorithm(*args, **kwargs):
        # Abort autotuning once choices are collected so nothing is compiled.
        raise NoValidChoicesError

    with fresh_inductor_cache(), config.patch(
        {
            "max_autotune": True,
            "max_autotune_gemm_backends": "CUTLASS",
            "cuda.cutlass_max_profiling_configs": 2,
            "autotune_fallback_to_aten": False,
        }
    ), mock.patch(
        "torch._inductor.kernel.mm.autotune_select_algorithm",
        wraps=select_no_algorithm,
    ) as sa:
        # Unpack the operands directly rather than `for input in inputs`,
        # which shadows the `input` builtin.
        for lhs, rhs in inputs:
            M, K = lhs.shape
            _, N = rhs.shape

            # Compilation is expected to fail via our injected
            # NoValidChoicesError, surfaced as an InductorError.
            with self.assertRaisesRegex(InductorError, r".*NoValidChoicesError.*"):
                torch.compile(torch.mm, dynamic=False)(lhs, rhs)

            self.assertTrue(
                sa.called,
                f"autotune_select_algorithm was not called with shape M={M}, N={N}, K={K}",
            )

            args, _ = sa.call_args
            op_name, choices, _, __ = args
            assert op_name == "mm"

            # Count CUTLASS-backed choices; at least one must survive the
            # alignment filtering for every shape above.
            cuda_template_count = 0
            for choice in choices:
                if isinstance(choice, CUDATemplateCaller):
                    choice_info = choice.info_dict()
                    op_conf_name = choice_info.get("op_conf_name", "")
                    assert isinstance(op_conf_name, str)
                    cuda_template_count += 1

            self.assertGreater(
                cuda_template_count,
                0,
                "No CUDATemplateCaller choices found for matmul with shape "
                f"M={M}, N={N}, K={K}",
            )
@unittest.skipIf(not SM80OrLater, "need sm_80")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_get_max_alignment(self):

View File

@ -841,6 +841,16 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
# Set epilogue.
# TODO: update epilogue functor according to epilogues.
op.element_epilogue = op.accumulator_type()
# Set bias layout and alignment.
status = self._set_bias_layout_and_alignment(op)
if not status:
log.debug(
"Skipping due to bias layout and alignment setting failure. op: %s", op
)
return None
# Apply regex filters at the end when configuration name doesn't change anymore
if inductor_cuda_config.cutlass_op_allowlist_regex is not None:
if not re.search(
inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name()
@ -852,14 +862,6 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
):
return None
# Set bias layout and alignment.
status = self._set_bias_layout_and_alignment(op)
if not status:
log.debug(
"Skipping due to bias layout and alignment setting failure. op: %s", op
)
return None
return op
def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: ignore[name-defined] # noqa: F821
@ -1212,46 +1214,29 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> bool:
import cutlass_library.library as cutlass_lib
has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None
if has_bias:
bias = self.input_nodes[2]
bias_layout = CUTLASSGemmTemplate.cutlass_layout(bias.get_layout())
Bias = self.input_nodes[2]
# bias dtype
op.C.element = cutlass_utils.torch_dtype_to_cutlass_type(
Bias.get_layout().dtype
)
assert op.C.element == op.D.element, (
f"Expect C and D to have the same dtype, found {op.C.element} and {op.D.element}"
)
# Bias layout
bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout())
op.C.layout = bias_layout
status = self.set_alignment(bias.get_layout(), op.C)
# Bias alignment
status = self.set_alignment(Bias.get_layout(), op.C)
if not status:
return False
return True
def _dtype_match(
self,
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821
) -> bool:
"""
Checking dtypes of C (i.e. bias) here, since that is the one not checked in the base class.
"""
if not super()._dtype_match(op):
return False
assert cutlass_utils.try_import_cutlass()
from cutlass_library.library import DataType # type: ignore[import]
has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None
if op.C.element == DataType.void:
if has_bias:
# op expects no bias, but bias exists
return False
else:
# op expects bias. Needs to check if bias exists and is of the right dtype
if not (
has_bias
and cutlass_utils.dtype_match(
self.input_nodes[2].get_dtype(), op.C.element
)
):
return False
op.C.element = cutlass_lib.DataType.void
return True
def _define_gemm_instance(