[cutlass backend][ez] Ban FP32 output dtype from using CUTLASS GEMM backend (#151279)
FP32 not supported: https://github.com/pytorch/pytorch/issues/145952

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151279
Approved by: https://github.com/ColinPeppler
parent 8780d18f64
commit 532025fbd0
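The change in the diff below boils down to a dtype allow-list check on the GEMM output layout. As a rough standalone illustration (the helper name and constant here are made up for this sketch and are not part of torch._inductor), the gate behaves roughly like:

import torch

# Allowed CUTLASS GEMM output dtypes after this change; float32 is removed
# (FP32 output not supported, see https://github.com/pytorch/pytorch/issues/145952).
CUTLASS_OUTPUT_DTYPES = [torch.float16, torch.bfloat16, torch.int32]

def would_use_cutlass(output_dtype: torch.dtype, is_hip: bool = False) -> bool:
    # Illustrative stand-in for the gating in use_cutlass_template /
    # _use_template_for_gpu; not the real implementation.
    if is_hip:  # ROCm builds never take the CUTLASS path
        return False
    return output_dtype in CUTLASS_OUTPUT_DTYPES

assert would_use_cutlass(torch.bfloat16)       # still eligible
assert not would_use_cutlass(torch.float32)    # banned by this commit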
@@ -1324,6 +1324,12 @@ def use_max_autotune() -> bool:
 def _use_template_for_gpu(
     layout: Layout, allowed_layout_dtypes: list[torch.dtype]
 ) -> bool:
+    if layout.dtype not in allowed_layout_dtypes:
+        log.debug(
+            "Not using template since dtype %s is not in allowed layout dtypes %s",
+            layout.dtype,
+            allowed_layout_dtypes,
+        )
     return (
         is_gpu(layout.device.type)
         and layout.dtype in allowed_layout_dtypes
@@ -1416,7 +1422,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
     if torch.version.hip:
         return False
 
-    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
+    # output dtype
+    # FP32 not supported: https://github.com/pytorch/pytorch/issues/145952
+    layout_dtypes = [torch.float16, torch.bfloat16, torch.int32]
     res = (
         _use_template_for_gpu(layout, layout_dtypes)
         and use_max_autotune()
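A hedged sketch of the user-visible effect, assuming a CUDA build of PyTorch with the CUTLASS max-autotune GEMM backend available and that the torch._inductor.config knobs below exist in your version: under max-autotune, an fp16 matmul can still be served by CUTLASS GEMM templates, while an fp32 matmul is filtered out and falls back to the ATen/Triton choices.

import torch
import torch._inductor.config as inductor_config

# Ask max-autotune to consider CUTLASS alongside the default GEMM backends.
inductor_config.max_autotune_gemm_backends = "CUTLASS,ATEN,TRITON"

@torch.compile(mode="max-autotune")
def mm(a, b):
    return a @ b

x16 = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)
mm(x16, x16)   # fp16 output layout: eligible for CUTLASS templates

x32 = torch.randn(2048, 2048, device="cuda", dtype=torch.float32)
mm(x32, x32)   # fp32 output layout: CUTLASS is skipped after this change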