[cutlass backend][ez] Ban FP32 output dtype from using CUTLASS GEMM backend (#151279)

FP32 output is not supported by the CUTLASS GEMM backend; see https://github.com/pytorch/pytorch/issues/145952

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151279
Approved by: https://github.com/ColinPeppler
Author: henrylhtsang (2025-04-15 13:41:28 -07:00), committed by PyTorch MergeBot
parent 8780d18f64
commit 532025fbd0


@@ -1324,6 +1324,12 @@ def use_max_autotune() -> bool:
 def _use_template_for_gpu(
     layout: Layout, allowed_layout_dtypes: list[torch.dtype]
 ) -> bool:
+    if layout.dtype not in allowed_layout_dtypes:
+        log.debug(
+            "Not using template since dtype %s is not in allowed layout dtypes %s",
+            layout.dtype,
+            allowed_layout_dtypes,
+        )
     return (
         is_gpu(layout.device.type)
         and layout.dtype in allowed_layout_dtypes
@@ -1416,7 +1422,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
     if torch.version.hip:
         return False
-    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
+    # output dtype
+    # FP32 not supported: https://github.com/pytorch/pytorch/issues/145952
+    layout_dtypes = [torch.float16, torch.bfloat16, torch.int32]
     res = (
         _use_template_for_gpu(layout, layout_dtypes)
         and use_max_autotune()
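
For context (not part of the commit), here is a minimal standalone sketch of the dtype gate this change tightens: with torch.float32 dropped from layout_dtypes, an FP32 output layout fails the membership check in _use_template_for_gpu, so Inductor falls back to other GEMM backends for that kernel. Only the membership check is reproduced; Layout, is_gpu, and use_max_autotune from torch._inductor are omitted for brevity.

import torch

# Allowed output dtypes after this commit (FP32 removed, per the diff above)
layout_dtypes = [torch.float16, torch.bfloat16, torch.int32]

for dtype in (torch.float16, torch.bfloat16, torch.float32, torch.int32):
    # The same membership check _use_template_for_gpu performs on layout.dtype
    eligible = dtype in layout_dtypes
    print(f"{dtype}: CUTLASS GEMM eligible -> {eligible}")

# Expected output:
# torch.float16: CUTLASS GEMM eligible -> True
# torch.bfloat16: CUTLASS GEMM eligible -> True
# torch.float32: CUTLASS GEMM eligible -> False
# torch.int32: CUTLASS GEMM eligible -> True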