mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[cutlass backend] check against arch >= 100 (#145812)
Summary: Want to add a guard against silent fallback to SM90. GenerateSM100 was just added 3 days ago. https://github.com/NVIDIA/cutlass/blame/main/python/cutlass_library/generator.py#L8896 It should show up in CUTLASS 3.8 (not pinned yet). Test Plan: ci Differential Revision: D68748705 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145812 Approved by: https://github.com/chenyang78, https://github.com/ColinPeppler, https://github.com/Aidyn-A
This commit is contained in:
parent
bab35eb26a
commit
0d5fb0941f
|
|
@ -114,8 +114,19 @@ def try_import_cutlass() -> bool:
|
|||
return False
|
||||
|
||||
|
||||
@functools.lru_cache(8)
|
||||
def _normalize_cuda_arch(arch: str) -> str:
|
||||
if int(arch) >= 90:
|
||||
if int(arch) >= 100:
|
||||
log.warning(
|
||||
"Detected CUDA architecture >= 100: %s. We will generate operations with "
|
||||
"GenerateSM100 (if available) and GenerateSM90. Please file an "
|
||||
"issue for any problems and feedback. ",
|
||||
arch,
|
||||
)
|
||||
|
||||
if int(arch) >= 100:
|
||||
return "100"
|
||||
elif int(arch) >= 90:
|
||||
return "90"
|
||||
elif int(arch) >= 80:
|
||||
return "80"
|
||||
|
|
@ -186,7 +197,15 @@ def _gen_ops_cached(arch, version) -> list[Any]:
|
|||
)
|
||||
manifest = cutlass_manifest.Manifest(args)
|
||||
|
||||
if arch == "90":
|
||||
if arch == "100":
|
||||
try:
|
||||
from cutlass_generator import GenerateSM100 # type: ignore[import]
|
||||
|
||||
GenerateSM100(manifest, args.cuda_version)
|
||||
except ImportError:
|
||||
log.warning("Cannot find GenerateSM100. Only GenerateSM90 will be used. ")
|
||||
cutlass_generator.GenerateSM90(manifest, args.cuda_version)
|
||||
elif arch == "90":
|
||||
cutlass_generator.GenerateSM90(manifest, args.cuda_version)
|
||||
cutlass_generator.GenerateSM80(manifest, args.cuda_version)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user