mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Skip symmetric memory tests calling _scaled_mm on CCC < 8.9 (#164251)
This avoids them failing on e.g. A100 GPUs with > RuntimeError: torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+ Pull Request resolved: https://github.com/pytorch/pytorch/pull/164251 Approved by: https://github.com/Skylion007, https://github.com/kwen2501
This commit is contained in:
parent
fa90090735
commit
8bb71c07c4
|
|
@ -4,7 +4,7 @@ import itertools
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from unittest import skip, skipIf
|
from unittest import skip, skipIf, skipUnless
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
@ -25,6 +25,7 @@ from torch.distributed._symmetric_memory import (
|
||||||
from torch.testing._internal.common_cuda import (
|
from torch.testing._internal.common_cuda import (
|
||||||
_get_torch_cuda_version,
|
_get_torch_cuda_version,
|
||||||
SM100OrLater,
|
SM100OrLater,
|
||||||
|
SM89OrLater,
|
||||||
SM90OrLater,
|
SM90OrLater,
|
||||||
xfailIfSM100OrLater,
|
xfailIfSM100OrLater,
|
||||||
)
|
)
|
||||||
|
|
@ -430,6 +431,7 @@ class AsyncTPTest(MultiProcContinuousTest):
|
||||||
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
|
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
|
||||||
)
|
)
|
||||||
@skip_if_lt_x_gpu(2)
|
@skip_if_lt_x_gpu(2)
|
||||||
|
@skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
|
||||||
@parametrize("gather_dim", [0, 1])
|
@parametrize("gather_dim", [0, 1])
|
||||||
@parametrize(
|
@parametrize(
|
||||||
"scale_mode", ["tensor-wise", "row-wise-replicated", "row-wise-sharded"]
|
"scale_mode", ["tensor-wise", "row-wise-replicated", "row-wise-sharded"]
|
||||||
|
|
@ -545,6 +547,7 @@ class AsyncTPTest(MultiProcContinuousTest):
|
||||||
|
|
||||||
@skip_if_rocm_multiprocess # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
|
@skip_if_rocm_multiprocess # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
|
||||||
@skip_if_lt_x_gpu(2)
|
@skip_if_lt_x_gpu(2)
|
||||||
|
@skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
|
||||||
@parametrize("scatter_dim", [0, 1])
|
@parametrize("scatter_dim", [0, 1])
|
||||||
@parametrize("rowwise", [True, False])
|
@parametrize("rowwise", [True, False])
|
||||||
def test_fused_scaled_matmul_reduce_scatter(
|
def test_fused_scaled_matmul_reduce_scatter(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user