diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py
index 57c7175daab..6b91bce7949 100644
--- a/test/distributed/test_symmetric_memory.py
+++ b/test/distributed/test_symmetric_memory.py
@@ -4,7 +4,7 @@ import itertools
 import os
 import random
 from contextlib import nullcontext
-from unittest import skip, skipIf
+from unittest import skip, skipIf, skipUnless
 
 import torch
 import torch.distributed as dist
@@ -25,6 +25,7 @@ from torch.distributed._symmetric_memory import (
 from torch.testing._internal.common_cuda import (
     _get_torch_cuda_version,
     SM100OrLater,
+    SM89OrLater,
     SM90OrLater,
     xfailIfSM100OrLater,
 )
@@ -430,6 +431,7 @@ class AsyncTPTest(MultiProcContinuousTest):
         not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
     )
     @skip_if_lt_x_gpu(2)
+    @skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
     @parametrize("gather_dim", [0, 1])
     @parametrize(
         "scale_mode", ["tensor-wise", "row-wise-replicated", "row-wise-sharded"]
@@ -545,6 +547,7 @@ class AsyncTPTest(MultiProcContinuousTest):
     @skip_if_rocm_multiprocess
     # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
     @skip_if_lt_x_gpu(2)
+    @skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
     @parametrize("scatter_dim", [0, 1])
     @parametrize("rowwise", [True, False])
     def test_fused_scaled_matmul_reduce_scatter(
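
For context, the decorators added above combine `unittest.skipUnless` with the `SM89OrLater` flag from `torch.testing._internal.common_cuda` so the scaled-matmul tests only run on GPUs whose compute capability is at least 8.9. The sketch below shows the same gating pattern in a self-contained form; `_sm89_or_later` and the example test class are hypothetical stand-ins for illustration, not part of the patch.

    import unittest

    import torch


    # Hypothetical stand-in for SM89OrLater: True when the current CUDA device
    # reports compute capability >= 8.9, the threshold the new decorators check.
    def _sm89_or_later() -> bool:
        if not torch.cuda.is_available():
            return False
        major, minor = torch.cuda.get_device_capability()
        return (major, minor) >= (8, 9)


    class ScaledMatmulGateExample(unittest.TestCase):
        # skipUnless runs the test only when the condition is True; on older
        # GPUs the test is reported as skipped with the given reason, which is
        # the behavior the diff adds to the AsyncTP scaled-matmul tests.
        @unittest.skipUnless(_sm89_or_later(), "Requires compute capability >= 8.9")
        def test_requires_sm89(self) -> None:
            self.assertTrue(torch.cuda.is_available())


    if __name__ == "__main__":
        unittest.main()

Note that the condition is evaluated once, when the class body is executed, which is also how the `SM89OrLater`-gated decorators in the patch behave.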