Skip symmetric memory tests calling _scaled_mm on CC < 8.9 (#164251)

This avoids them failing on e.g. A100 GPUs with
> RuntimeError: torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164251
Approved by: https://github.com/Skylion007, https://github.com/kwen2501
This commit is contained in:
Alexander Grund 2025-10-01 03:26:17 +00:00 committed by PyTorch MergeBot
parent fa90090735
commit 8bb71c07c4

View File

@@ -4,7 +4,7 @@ import itertools
 import os
 import random
 from contextlib import nullcontext
-from unittest import skip, skipIf
+from unittest import skip, skipIf, skipUnless

 import torch
 import torch.distributed as dist
@@ -25,6 +25,7 @@ from torch.distributed._symmetric_memory import (
 from torch.testing._internal.common_cuda import (
     _get_torch_cuda_version,
     SM100OrLater,
+    SM89OrLater,
     SM90OrLater,
     xfailIfSM100OrLater,
 )
@@ -430,6 +431,7 @@ class AsyncTPTest(MultiProcContinuousTest):
         not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
     )
     @skip_if_lt_x_gpu(2)
+    @skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
    @parametrize("gather_dim", [0, 1])
     @parametrize(
         "scale_mode", ["tensor-wise", "row-wise-replicated", "row-wise-sharded"]
@@ -545,6 +547,7 @@ class AsyncTPTest(MultiProcContinuousTest):
     @skip_if_rocm_multiprocess  # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
     @skip_if_lt_x_gpu(2)
+    @skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
     @parametrize("scatter_dim", [0, 1])
     @parametrize("rowwise", [True, False])
     def test_fused_scaled_matmul_reduce_scatter(