Skip symmetric memory tests calling _scaled_mm on CCC < 8.9 (#164251)
This avoids them failing on e.g. A100 GPUs with

> RuntimeError: torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164251
Approved by: https://github.com/Skylion007, https://github.com/kwen2501
This commit is contained in:
parent fa90090735
commit 8bb71c07c4
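For context: an A100 reports compute capability 8.0, while the fp8 paths behind torch._scaled_mm require 8.9 (Ada) or 9.0 (Hopper), hence the RuntimeError quoted above. Below is a minimal, self-contained sketch of the skipUnless gating pattern this commit applies. The diff itself uses the SM89OrLater flag from torch.testing._internal.common_cuda; the local _cc_at_least helper here is an illustrative stand-in, and the torch._scaled_mm call assumes the signature of recent PyTorch releases (it is a private op whose signature has changed over time), so treat both as assumptions rather than the repository's implementation.

import unittest

import torch

def _cc_at_least(major: int, minor: int) -> bool:
    # Stand-in for torch.testing._internal.common_cuda.SM89OrLater:
    # compare the current device's (major, minor) compute capability.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (major, minor)

# True on Ada (8.9) and Hopper (9.0+), False on A100 (8.0)
SM89_OR_LATER = _cc_at_least(8, 9)

class ScaledMMGateTest(unittest.TestCase):
    @unittest.skipUnless(SM89_OR_LATER, "Requires compute capability >= 8.9")
    def test_scaled_mm_smoke(self):
        # fp8 matmul; mat2 must be column-major, hence the .t().
        # Call shape assumes a recent PyTorch where _scaled_mm takes
        # required scale_a/scale_b and returns a single tensor.
        a = torch.randn(16, 16, device="cuda").to(torch.float8_e4m3fn)
        b = torch.randn(16, 16, device="cuda").to(torch.float8_e4m3fn).t()
        scale = torch.tensor(1.0, device="cuda")
        out = torch._scaled_mm(a, b, scale_a=scale, scale_b=scale,
                               out_dtype=torch.bfloat16)
        self.assertEqual(out.shape, (16, 16))

if __name__ == "__main__":
    unittest.main()

Run on an A100, the case is reported as skipped instead of crashing with the RuntimeError; on Ada- or Hopper-class hardware it executes normally.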
@@ -4,7 +4,7 @@ import itertools
 import os
 import random
 from contextlib import nullcontext
-from unittest import skip, skipIf
+from unittest import skip, skipIf, skipUnless
 
 import torch
 import torch.distributed as dist
@@ -25,6 +25,7 @@ from torch.distributed._symmetric_memory import (
 from torch.testing._internal.common_cuda import (
     _get_torch_cuda_version,
     SM100OrLater,
+    SM89OrLater,
     SM90OrLater,
     xfailIfSM100OrLater,
 )
@@ -430,6 +431,7 @@ class AsyncTPTest(MultiProcContinuousTest):
         not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
     )
     @skip_if_lt_x_gpu(2)
+    @skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
     @parametrize("gather_dim", [0, 1])
     @parametrize(
         "scale_mode", ["tensor-wise", "row-wise-replicated", "row-wise-sharded"]
@@ -545,6 +547,7 @@ class AsyncTPTest(MultiProcContinuousTest):
 
     @skip_if_rocm_multiprocess  # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
     @skip_if_lt_x_gpu(2)
+    @skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
     @parametrize("scatter_dim", [0, 1])
     @parametrize("rowwise", [True, False])
     def test_fused_scaled_matmul_reduce_scatter(