[ROCm][CI] Add support for gfx1100 in rocm workflow + test skips (#148355)

This PR adds infrastructure support for gfx1100 to the ROCm workflow; runner nodes have been allocated for this effort.
@dnikolaev-amd contributed all the test skips.
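
For context, the skips below rely on the skipIfRocmArch decorator and the NAVI_ARCH architecture tuple from torch.testing._internal.common_utils. A minimal sketch of how such a skip is attached to a test (the test class and test name are invented for illustration; NAVI_ARCH is assumed to list Navi gfx targets, gfx1100 among them):

    # Illustrative sketch only: how a Navi-specific skip is applied.
    from torch.testing._internal.common_utils import (
        NAVI_ARCH,
        TestCase,
        run_tests,
        skipIfRocmArch,
    )

    class ExampleTest(TestCase):
        # Raises unittest.SkipTest when the test runs under ROCm on a Navi GPU.
        @skipIfRocmArch(NAVI_ARCH)
        def test_not_yet_supported_on_navi(self):
            ...

    if __name__ == "__main__":
        run_tests()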

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148355
Approved by: https://github.com/jeffdaily

Co-authored-by: Dmitry Nikolaev <dmitry.nikolaev@amd.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Commit: 955f21dc2c (parent 9f5e1beaf3)
Author: amdfaa
Date: 2025-10-07 22:36:25 +00:00
Committed by: PyTorch MergeBot
6 changed files with 60 additions and 16 deletions

View File

@@ -344,7 +344,7 @@ docker build \
   --build-arg "NINJA_VERSION=${NINJA_VERSION:-}" \
   --build-arg "KATEX=${KATEX:-}" \
   --build-arg "ROCM_VERSION=${ROCM_VERSION:-}" \
-  --build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx90a;gfx942}" \
+  --build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx90a;gfx942;gfx1100}" \
   --build-arg "IMAGE_NAME=${IMAGE_NAME}" \
   --build-arg "UCX_COMMIT=${UCX_COMMIT}" \
   --build-arg "UCC_COMMIT=${UCC_COMMIT}" \

View File

@@ -59,3 +59,29 @@ jobs:
       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
     secrets: inherit
+
+  linux-jammy-rocm-py3_10-gfx1100-test:
+    if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
+    permissions:
+      id-token: write
+      contents: read
+    name: linux-jammy-rocm-py3_10-gfx1100
+    uses: ./.github/workflows/_rocm-test.yml
+    needs:
+      - linux-jammy-rocm-py3_10-build
+      - target-determination
+    with:
+      build-environment: linux-jammy-rocm-py3.10
+      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
+      test-matrix: |
+        { include: [
+          { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
+          { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
+        ]}
+      tests-to-include: >
+        test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
+        test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
+        inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
+        inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
+        inductor/test_flex_attention inductor/test_max_autotune
+    secrets: inherit

View File

@@ -49,7 +49,9 @@ from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     IS_WINDOWS,
+    NAVI_ARCH,
     parametrize,
+    skipIfRocmArch,
     TEST_WITH_ROCM,
 )
 from torch.testing._internal.logging_utils import multiple_logs_to_string
@@ -1284,6 +1286,7 @@ class TestMaxAutotune(TestCase):
         self.assertIn("NoValidChoicesError", str(context.exception))
+    @skipIfRocmArch(NAVI_ARCH)
     def test_non_contiguous_input_mm(self):
         """
         Make sure the triton template can work with non-contiguous inputs without crash.
@@ -1302,6 +1305,7 @@ class TestMaxAutotune(TestCase):
         act = f(x, y)
         torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)
+    @skipIfRocmArch(NAVI_ARCH)
     def test_non_contiguous_input_addmm(self):
         b = torch.randn((768), dtype=torch.bfloat16, device=GPU_TYPE)
         x = rand_strided(
@@ -1317,6 +1321,7 @@ class TestMaxAutotune(TestCase):
         act = f(x, y)
         torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2)
+    @skipIfRocmArch(NAVI_ARCH)
     def test_non_contiguous_input_bmm(self):
         x = rand_strided(
             (1, 50257, 2048), (0, 1, 50304), dtype=torch.bfloat16, device=GPU_TYPE

View File

@@ -23,8 +23,8 @@ from torch.testing._internal.common_utils import \
     TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
     make_fullrank_matrices_with_distinct_singular_values,
     freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
-    setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest,
-    runOnRocmArch, MI300_ARCH, TEST_CUDA)
+    skipIfRocmArch, setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest,
+    runOnRocmArch, MI300_ARCH, NAVI_ARCH, TEST_CUDA)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver,
      onlyCPU, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
@@ -9706,6 +9706,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         self.assertEqual(out_ref, out2.cpu())
     @onlyCUDA
+    @skipIfRocmArch(NAVI_ARCH)
     @skipCUDAIfNotRocm
     @unittest.skipIf(not blaslt_supported_device(), "blasLt not supported on current device")
     @setBlasBackendsToDefaultFinally

View File

@@ -33,6 +33,9 @@ from torch.testing._internal.common_device_type import (
 from torch.testing._internal.common_utils import (
     IS_JETSON,
     IS_WINDOWS,
+    NAVI_ARCH,
+    getRocmVersion,
+    isRocmArchAnyOf,
     parametrize,
     run_tests,
     skipIfRocm,
@@ -152,6 +155,9 @@ class TestMatmulCuda(InductorTestCase):
     @parametrize("backend", ["cublas", "cublaslt"])
     def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
         with blas_library_context(backend):
+            if (TEST_WITH_ROCM and backend == "cublas" and isRocmArchAnyOf(NAVI_ARCH) and
+                    getRocmVersion() < (6, 4) and dtype == torch.float16 and size >= 10000):
+                self.skipTest(f"failed on Navi for ROCm6.3 due to hipblas backend, dtype={dtype} and size={size}")
             self.cublas_addmm(size, dtype, False)
     @onlyCUDA

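The guard added to test_cublas_addmm above is done at runtime rather than with a decorator because it also depends on the backend, dtype, and size parameters. Roughly the same pattern in isolation, using the helpers this PR adds to common_utils (the test class, shapes, and tolerances are invented for illustration):

    # Illustrative sketch only: runtime skip combining arch and ROCm-version checks.
    import unittest

    import torch
    from torch.testing._internal.common_utils import (
        NAVI_ARCH,
        TEST_WITH_ROCM,
        TestCase,
        getRocmVersion,
        isRocmArchAnyOf,
        run_tests,
    )

    class ExampleTest(TestCase):
        @unittest.skipIf(not torch.cuda.is_available(), "requires a GPU")
        def test_fp16_matmul(self):
            # Skip only on Navi GPUs running a ROCm release older than 6.4.
            if TEST_WITH_ROCM and isRocmArchAnyOf(NAVI_ARCH) and getRocmVersion() < (6, 4):
                self.skipTest("known hipBLAS issue on Navi with ROCm < 6.4")
            a = torch.randn(64, 64, dtype=torch.float16, device="cuda")
            ref = (a.float() @ a.float()).half()
            torch.testing.assert_close(a @ a, ref, atol=2e-2, rtol=1e-2)

    if __name__ == "__main__":
        run_tests()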
View File

@@ -1975,15 +1975,20 @@ def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"
         return dec_fn(func)
     return dec_fn
 
+def getRocmArchName(device_index: int = 0):
+    return torch.cuda.get_device_properties(device_index).gcnArchName
+
+def isRocmArchAnyOf(arch: tuple[str, ...]):
+    rocmArch = getRocmArchName()
+    return any(x in rocmArch for x in arch)
+
 def skipIfRocmArch(arch: tuple[str, ...]):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
-            if TEST_WITH_ROCM:
-                prop = torch.cuda.get_device_properties(0)
-                if prop.gcnArchName.split(":")[0] in arch:
-                    reason = f"skipIfRocm: test skipped on {arch}"
-                    raise unittest.SkipTest(reason)
+            if TEST_WITH_ROCM and isRocmArchAnyOf(arch):
+                reason = f"skipIfRocm: test skipped on {arch}"
+                raise unittest.SkipTest(reason)
             return fn(self, *args, **kwargs)
         return wrap_fn
     return dec_fn
@@ -2001,11 +2006,9 @@ def runOnRocmArch(arch: tuple[str, ...]):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
-            if TEST_WITH_ROCM:
-                prop = torch.cuda.get_device_properties(0)
-                if prop.gcnArchName.split(":")[0] not in arch:
-                    reason = f"skipIfRocm: test only runs on {arch}"
-                    raise unittest.SkipTest(reason)
+            if TEST_WITH_ROCM and not isRocmArchAnyOf(arch):
+                reason = f"skipIfRocm: test only runs on {arch}"
+                raise unittest.SkipTest(reason)
             return fn(self, *args, **kwargs)
         return wrap_fn
     return dec_fn
@@ -2055,15 +2058,18 @@ def skipIfHpu(fn):
         fn(*args, **kwargs)
     return wrapper
 
+def getRocmVersion() -> tuple[int, int]:
+    from torch.testing._internal.common_cuda import _get_torch_rocm_version
+    rocm_version = _get_torch_rocm_version()
+    return (rocm_version[0], rocm_version[1])
+
 # Skips a test on CUDA if ROCm is available and its version is lower than requested.
 def skipIfRocmVersionLessThan(version=None):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
             if TEST_WITH_ROCM:
-                rocm_version = str(torch.version.hip)
-                rocm_version = rocm_version.split("-", maxsplit=1)[0]  # ignore git sha
-                rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
+                rocm_version_tuple = getRocmVersion()
                 if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
                     reason = f"ROCm {rocm_version_tuple} is available but {version} required"
                     raise unittest.SkipTest(reason)
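
One note on the refactor above: getRocmVersion() returns a plain (major, minor) tuple, so skipIfRocmVersionLessThan now relies on Python's lexicographic tuple ordering for the comparison. A quick illustration (version numbers are arbitrary):

    # Tuple comparison is element-wise and numeric, which is exactly the
    # ordering needed for (major, minor) ROCm versions.
    assert (6, 3) < (6, 4)        # ROCm 6.3 is older than 6.4
    assert (6, 10) > (6, 4)       # numeric, not string: "6.10" < "6.4" as strings
    assert not ((6, 4) < (6, 4))  # an equal version does not trigger the skip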