[cutlass backend] check against arch >= 100 (#145812)

Summary:
Want to add a guard against silent fallback to SM90.

GenerateSM100 was just added 3 days ago. https://github.com/NVIDIA/cutlass/blame/main/python/cutlass_library/generator.py#L8896

It should show up in CUTLASS 3.8 (not pinned yet).

Test Plan: ci

Differential Revision: D68748705

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145812
Approved by: https://github.com/chenyang78, https://github.com/ColinPeppler, https://github.com/Aidyn-A
This commit is contained in:
Henry Tsang 2025-02-10 22:41:08 +00:00 committed by PyTorch MergeBot
parent bab35eb26a
commit 0d5fb0941f

View File

@ -114,8 +114,19 @@ def try_import_cutlass() -> bool:
return False
@functools.lru_cache(8)
def _normalize_cuda_arch(arch: str) -> str:
if int(arch) >= 90:
if int(arch) >= 100:
log.warning(
"Detected CUDA architecture >= 100: %s. We will generate operations with "
"GenerateSM100 (if available) and GenerateSM90. Please file an "
"issue for any problems and feedback. ",
arch,
)
if int(arch) >= 100:
return "100"
elif int(arch) >= 90:
return "90"
elif int(arch) >= 80:
return "80"
@ -186,7 +197,15 @@ def _gen_ops_cached(arch, version) -> list[Any]:
)
manifest = cutlass_manifest.Manifest(args)
if arch == "90":
if arch == "100":
try:
from cutlass_generator import GenerateSM100 # type: ignore[import]
GenerateSM100(manifest, args.cuda_version)
except ImportError:
log.warning("Cannot find GenerateSM100. Only GenerateSM90 will be used. ")
cutlass_generator.GenerateSM90(manifest, args.cuda_version)
elif arch == "90":
cutlass_generator.GenerateSM90(manifest, args.cuda_version)
cutlass_generator.GenerateSM80(manifest, args.cuda_version)
else: