Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[cutlass backend][ez] Ban FP32 output dtype from using CUTLASS GEMM backend (#151279)
FP32 not supported: https://github.com/pytorch/pytorch/issues/145952
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151279
Approved by: https://github.com/ColinPeppler
This commit is contained in:
parent 8780d18f64
commit 532025fbd0
@@ -1324,6 +1324,12 @@ def use_max_autotune() -> bool:
 def _use_template_for_gpu(
     layout: Layout, allowed_layout_dtypes: list[torch.dtype]
 ) -> bool:
+    if layout.dtype not in allowed_layout_dtypes:
+        log.debug(
+            "Not using template since dtype %s is not in allowed layout dtypes %s",
+            layout.dtype,
+            allowed_layout_dtypes,
+        )
     return (
         is_gpu(layout.device.type)
         and layout.dtype in allowed_layout_dtypes
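For reference, the log.debug call added above is only visible when inductor debug logging is enabled. A minimal sketch of turning it on (assuming the standard TORCH_LOGS / torch._logging.set_logs interface; verify against your PyTorch version):

    # Either via the environment (the "+" prefix requests DEBUG level):
    #   TORCH_LOGS="+inductor" python my_script.py
    # or programmatically:
    import logging
    import torch._logging

    torch._logging.set_logs(inductor=logging.DEBUG)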
@@ -1416,7 +1422,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
     if torch.version.hip:
         return False
 
-    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
+    # output dtype
+    # FP32 not supported: https://github.com/pytorch/pytorch/issues/145952
+    layout_dtypes = [torch.float16, torch.bfloat16, torch.int32]
     res = (
         _use_template_for_gpu(layout, layout_dtypes)
         and use_max_autotune()
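To make the effect of the dtype gate concrete, here is a minimal standalone sketch of the allowed-dtype check after this change (not the actual torch._inductor code; the helper name cutlass_supports_output_dtype is hypothetical):

    import logging

    import torch

    log = logging.getLogger(__name__)

    # Allowed GEMM output dtypes for the CUTLASS template after this change.
    # FP32 is excluded: https://github.com/pytorch/pytorch/issues/145952
    CUTLASS_LAYOUT_DTYPES = [torch.float16, torch.bfloat16, torch.int32]

    def cutlass_supports_output_dtype(dtype: torch.dtype) -> bool:
        """Return True if this output dtype may use the CUTLASS GEMM template."""
        if dtype not in CUTLASS_LAYOUT_DTYPES:
            log.debug(
                "Not using template since dtype %s is not in allowed layout dtypes %s",
                dtype,
                CUTLASS_LAYOUT_DTYPES,
            )
            return False
        return True

    # fp16/bf16 outputs remain eligible, fp32 outputs are now rejected.
    assert cutlass_supports_output_dtype(torch.bfloat16)
    assert not cutlass_supports_output_dtype(torch.float32)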