[ROCm] Enable post-merge trunk workflow on MI300 runners; skip and fix MI300 related failed tests (#143673)

This PR
* makes changes to the workflow files and scripts so that CI workflows can run on the MI300 runners
* skips and fixes several tests that failed on MI300, as observed in https://github.com/pytorch/pytorch/pull/140989

Skipped due to the unsupported Float8_e4m3fn data type on MI300 (the test code needs to be updated to use data types supported by MI300; see the dtype-mapping sketch after this list):
- distributed.tensor.parallel.test_micro_pipeline_tp.py::MicroPipelineTPTest::test_fuse_all_gather_scaled_matmul_A_dims_\*_gather_dim_\* (24 tests across inductor/distributed configs)
- distributed.tensor.parallel.test_micro_pipeline_tp.py::test_fuse_scaled_matmul_reduce_scatter_A_dims_\*_scatter_dim_\* (12 tests across inductor/distributed configs)
- inductor.test_loop_ordering::LoopOrderingTest::test_fp8_cast_and_t
- inductor.test_loop_ordering::LoopOrderingTest::test_fp8_pattern_2
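
MI300 (gfx942) exposes the fnuz FP8 variants rather than the OCP torch.float8_e4m3fn / torch.float8_e5m2 types, which is why these tests currently fail there. A minimal, illustrative sketch of the dtype substitution the tests will eventually need (the names below are hypothetical; the PR's real helper, _fix_fp8_dtype_for_rocm, appears in the inductor/test_fp8.py diff further down):

```python
import torch

# Hypothetical mapping for illustration only; the PR adds _fix_fp8_dtype_for_rocm
# in inductor/test_fp8.py to do this (and to handle tuples/lists of dtypes).
_MI300_FP8_MAP = {
    torch.float8_e4m3fn: torch.float8_e4m3fnuz,
    torch.float8_e5m2: torch.float8_e5m2fnuz,
}

def to_mi300_fp8(dtype: torch.dtype) -> torch.dtype:
    """Return the MI300-supported FP8 dtype for an OCP FP8 dtype, else the input unchanged."""
    return _MI300_FP8_MAP.get(dtype, dtype)
```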

Skipped due to AssertionError on MI300:
- inductor.test_mkldnn_pattern_matcher.py::test_qconv2d_int8_mixed_bf16
- distributed._tools.test_sac_ilp::TestSACILP::test_sac_ilp_case1

Skipped:
- test_cuda.py::TestCudaMallocAsync::test_clock_speed
- test_cuda.py::TestCudaMallocAsync::test_power_draw
- test_torch.py::TestTorchDeviceTypeCUDA::test_deterministic_cumsum_cuda

Skipped flaky tests on MI300:
- distributed.test_c10d_gloo.py::ProcessGroupGlooTest::test_gather_stress_cuda
- inductor.test_cpu_repro::CPUReproTests::test_lstm_packed_unbatched_False* (256 tests)

Fixed:
- test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_float8_basics_cuda

Features:
- inductor/test_fp8.py - declares a new function that converts FP8 data types to the FP8 data types supported by ROCm. It keeps test names identical for CUDA and ROCm and also makes it possible to enable the Inductor FP8 tests on CPU; a short usage sketch follows.
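
A rough usage sketch of the idea, assuming the gfx94x detection used by the new helper (the _is_mi300 stand-in below is illustrative, not part of the PR):

```python
import torch

def _is_mi300() -> bool:
    # Illustrative stand-in for the gcnArchName check inside _fix_fp8_dtype_for_rocm.
    return bool(
        torch.version.hip
        and torch.cuda.is_available()
        and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
    )

# Tests keep the CUDA-style @parametrize dtypes and remap them at runtime,
# so the generated test names are identical on CUDA and ROCm:
dst_dtype = torch.float8_e4m3fn
if _is_mi300():
    dst_dtype = torch.float8_e4m3fnuz  # what _fix_fp8_dtype_for_rocm returns on MI300
```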

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143673
Approved by: https://github.com/jeffdaily, https://github.com/malfet, https://github.com/pruthvistony

Co-authored-by: saienduri <saimanas.enduri@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Dmitry Nikolaev 2025-01-09 05:18:57 +00:00 committed by PyTorch MergeBot
parent 0d08084f1a
commit d4871750d9
16 changed files with 117 additions and 80 deletions

View File

@ -497,7 +497,7 @@ docker build \
--build-arg "NINJA_VERSION=${NINJA_VERSION:-}" \
--build-arg "KATEX=${KATEX:-}" \
--build-arg "ROCM_VERSION=${ROCM_VERSION:-}" \
--build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx90a}" \
--build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx90a;gfx942}" \
--build-arg "IMAGE_NAME=${IMAGE_NAME}" \
--build-arg "UCX_COMMIT=${UCX_COMMIT}" \
--build-arg "UCC_COMMIT=${UCC_COMMIT}" \

View File

@ -17,6 +17,10 @@ runs:
set -ex
diskspace_cutoff=${{ inputs.diskspace-cutoff }}
docker_root_dir=$(docker info -f '{{.DockerRootDir}}')
if [ ! -d "$docker_root_dir" ]; then
echo "Docker root directory ($docker_root_dir) does not exist. Skipping disk space check."
exit 0
fi
diskspace=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //')
msg="Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified"
if [[ "$diskspace" -ge "$diskspace_cutoff" ]] ; then

View File

@ -5,20 +5,6 @@ description: Set up ROCm host for CI
runs:
using: composite
steps:
- name: Set DOCKER_HOST
shell: bash
run: echo "DOCKER_HOST=unix:///run/user/$(id -u)/docker.sock" >> "${GITHUB_ENV}"
- name: Remove leftover Docker config file
shell: bash
continue-on-error: true
run: |
set -ex
cat ~/.docker/config.json || true
# https://stackoverflow.com/questions/64455468/error-when-logging-into-ecr-with-docker-login-error-saving-credentials-not
rm -f ~/.docker/config.json
- name: Stop all running docker containers
if: always()
shell: bash
@ -111,8 +97,10 @@ runs:
shell: bash
run: |
# All GPUs are visible to the runner; visibility, if needed, will be set by run_test.py.
# Add render group for container creation.
render_gid=`cat /etc/group | grep render | cut -d: -f3`
# The --group-add daemon and --group-add bin are needed in the Ubuntu 24.04 and Almalinux OSs respectively.
# This is due to the device files (/dev/kfd & /dev/dri) being owned by video group on bare metal.
# This video group ID maps to subgid 1 inside the docker image due to the /etc/subgid entries.
# The group name corresponding to group ID 1 can change depending on the OS, so both are necessary.
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon --group-add bin" >> "${GITHUB_ENV}"
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device /dev/dri --group-add video --group-add $render_gid --group-add daemon --group-add bin --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --network=host" >> "${GITHUB_ENV}"

View File

@ -36,12 +36,12 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 1, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 2, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 3, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 4, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 5, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 6, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
]}
secrets: inherit

View File

@ -175,9 +175,9 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2" },
{ config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.4" },
{ config: "default", shard: 1, num_shards: 2, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 2, num_shards: 2, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "distributed", shard: 1, num_shards: 1, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.4' || 'linux.rocm.gpu.4' }}" },
]}
secrets: inherit

View File

@ -19,7 +19,13 @@ from torch.distributed._tools.sac_ilp import (
sac_milp,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.testing._internal.common_utils import (
MI300_ARCH,
run_tests,
skipIfRocmArch,
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
@ -131,6 +137,7 @@ class TestSACILP(TestCase):
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
@skipIfRocmArch(MI300_ARCH)
def test_sac_ilp_case1(self):
"""
This is a case where the memory budget is either binding or too tight,

View File

@ -29,10 +29,9 @@ from torch.distributed.tensor.parallel import (
)
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
instantiate_parametrized_tests,
MI300_ARCH,
parametrize,
run_tests,
runOnRocmArch,
skipIfRocm,
TestCase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
@ -241,7 +240,7 @@ class MicroPipelineTPTest(TestCase):
self.assertNotIn("all_gather_into_tensor", code)
self.assertEqual("return_A=True" in code, return_A)
@runOnRocmArch(MI300_ARCH)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@parametrize("A_dims", [2, 3])
@parametrize("gather_dim", [0, 1, 2])
@ -344,7 +343,7 @@ class MicroPipelineTPTest(TestCase):
self.assertIn("fused_matmul_reduce_scatter", code)
self.assertNotIn("reduce_scatter_tensor", code)
@runOnRocmArch(MI300_ARCH)
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@parametrize("A_dims", [2, 3])
@parametrize("scatter_dim", [0, 1, 2])

View File

@ -49,9 +49,11 @@ from torch.testing._internal.common_distributed import (
verify_ddp_error_logged,
)
from torch.testing._internal.common_utils import (
MI300_ARCH,
retry_on_connect_failures,
run_tests,
skip_but_pass_in_sandcastle,
skipIfRocmArch,
TestCase,
)
@ -1097,6 +1099,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
self._test_gather_stress(inputs, lambda t: t.clone())
@skip_if_lt_x_gpu(2)
@skipIfRocmArch(MI300_ARCH)
@requires_gloo()
def test_gather_stress_cuda(self):
inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)]

View File

@ -36,6 +36,7 @@ from torch.testing._internal.common_utils import (
skipIfRocm,
slowTest,
TEST_MKL,
TEST_WITH_ROCM,
xfailIfS390X,
)
from torch.utils._python_dispatch import TorchDispatchMode
@ -583,6 +584,12 @@ class CPUReproTests(TestCase):
batch_size,
seq_len,
):
if (
TEST_WITH_ROCM
and not unbatched
and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
):
self.skipTest("Flaky on MI300 with unbatched=False")
self._test_lstm_packed(
unbatched,
input_size,

View File

@ -2,6 +2,7 @@
import functools
import unittest
from typing import List, Tuple, Union
import torch
from torch import Tensor
@ -87,6 +88,33 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
return x_fp8, inverse_scale
def _fix_fp8_dtype_for_rocm(
dtype: Union[torch.dtype, List[torch.dtype], Tuple[torch.dtype]], device
) -> Union[torch.dtype, List[torch.dtype], Tuple[torch.dtype]]:
# Replaces FP8 data types with the MI300-supported FP8 types when the device is a GPU:
#   e4m3fn -> e4m3fnuz
#   e5m2   -> e5m2fnuz
# Supports a single dtype as well as tuples and lists of dtypes.
# Keeps the same test names for CUDA and ROCm.
# Also allows enabling the FP8 inductor tests on CPU.
if (
torch.version.hip
and ("cuda" in device)
and ("gfx94" in torch.cuda.get_device_properties(0).gcnArchName.split(":")[0])
):
# MI300 uses different float8 dtypes
if isinstance(dtype, tuple):
return tuple(_fix_fp8_dtype_for_rocm(x, device) for x in dtype)
if isinstance(dtype, list):
return [_fix_fp8_dtype_for_rocm(x, device) for x in dtype]
if dtype == torch.float8_e4m3fn:
return torch.float8_e4m3fnuz
elif dtype == torch.float8_e5m2:
return torch.float8_e5m2fnuz
return dtype
@instantiate_parametrized_tests
class TestFP8Types(TestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@ -116,9 +144,8 @@ class TestFP8Types(TestCase):
def test_eager_fallback(self, dtype: torch.dtype):
weight_shape = (32, 16)
e4m3_type = (
torch.float8_e4m3fn if torch.version.hip is None else torch.float8_e4m3fnuz
)
e4m3_type = torch.float8_e4m3fn
e4m3_type = _fix_fp8_dtype_for_rocm(e4m3_type, device="cuda")
def fp8_matmul_unwrapped(x):
a_scale = torch.Tensor([1.0]).to(device="cuda")
@ -156,13 +183,9 @@ class TestFP8Types(TestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
@parametrize("shape", ("15,3,13", "4,2048,4096"))
@parametrize(
"dst_types",
[(torch.float8_e4m3fn, torch.float8_e5m2)]
if torch.version.hip is None
else [(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)],
)
@parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)])
def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple):
dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda")
e4m3, e5m2 = dst_types
def fp8_cast(x):
@ -204,16 +227,13 @@ class TestFP8Types(TestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
@parametrize(
"dst_dtype",
(torch.float8_e4m3fn, torch.float8_e5m2)
if torch.version.hip is None
else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
)
@parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("shape", ("16,16,16", "4,2048,4096"))
def test_to_fp8_saturated(
self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
):
dst_dtype = _fix_fp8_dtype_for_rocm(dst_dtype, device="cuda")
def fp8_saturated(x, dtype):
return _to_fp8_saturated(x, dtype)
@ -229,14 +249,10 @@ class TestFP8Types(TestCase):
@unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
@parametrize(
"float8_dtype",
(torch.float8_e4m3fn, torch.float8_e5m2)
if torch.version.hip is None
else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
)
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
shape = [int(dim) for dim in shape.split(",")]
batch_size, sequence_length, hidden_size = shape
@ -258,14 +274,10 @@ class TestFP8Types(TestCase):
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize(
"float8_dtype",
(torch.float8_e4m3fn, torch.float8_e5m2)
if torch.version.hip is None
else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
)
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
shape = [int(dim) for dim in shape.split(",")]
batch_size, sequence_length, hidden_size = shape
@ -293,17 +305,13 @@ class TestFP8Types(TestCase):
@unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
@parametrize(
"float8_dtype",
(torch.float8_e4m3fn, torch.float8_e5m2)
if torch.version.hip is None
else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
)
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("amax_keep_dim", (True, False))
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
def test_layernorm_fp8_quant(
self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: str
):
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
shape = [int(dim) for dim in shape.split(",")]
batch_size, sequence_length, hidden_size = shape
@ -339,12 +347,7 @@ class TestFP8Types(TestCase):
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize(
"float8_dtype",
(torch.float8_e4m3fn, torch.float8_e5m2)
if torch.version.hip is None
else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
)
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("shape", ("4,2048,4096",))
@parametrize("keepdim", (False, True))
def test_layernorm_fp8_quant_benchmark(
@ -353,6 +356,7 @@ class TestFP8Types(TestCase):
shape: str,
keepdim: bool,
):
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
shape = [int(dim) for dim in shape.split(",")]
batch_size, sequence_length, hidden_size = shape

View File

@ -19,6 +19,7 @@ from torch._inductor.test_operators import realize
from torch._inductor.utils import sympy_index_symbol
from torch._inductor.virtualized import ops, V
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.utils._pytree import tree_map
from torch.utils._sympy.functions import ModularIndexing
@ -398,6 +399,7 @@ class LoopOrderingTest(TestCase):
self.do_acc_test(f, x)
self.assertEqual(1, metrics.generated_kernel_count)
@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
def test_fp8_cast_and_t(self):
"""
@ -420,6 +422,7 @@ class LoopOrderingTest(TestCase):
self.do_acc_test(f, x, scale)
self.assertEqual(1, metrics.generated_kernel_count)
@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
def test_fp8_pattern_2(self):
"""

View File

@ -23,9 +23,11 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_FBCODE,
IS_LINUX,
MI300_ARCH,
parametrize,
skipIfNoXPU,
skipIfRocm,
skipIfRocmArch,
TEST_ACL,
TEST_MKL,
xfailIfACL,
@ -1032,6 +1034,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skipIfRocmArch(MI300_ARCH)
def test_qconv2d_int8_mixed_bf16(self):
r"""
This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.

View File

@ -60,6 +60,7 @@ from torch.testing._internal.common_utils import (
IS_SANDCASTLE,
IS_WINDOWS,
load_tests,
MI300_ARCH,
NO_MULTIPROCESSING_SPAWN,
parametrize,
run_tests,
@ -68,6 +69,7 @@ from torch.testing._internal.common_utils import (
skipCUDAMemoryLeakCheckIf,
skipCUDANonDefaultStreamIf,
skipIfRocm,
skipIfRocmArch,
slowTest,
subtest,
TemporaryFileName,
@ -4055,10 +4057,12 @@ class TestCudaMallocAsync(TestCase):
self.assertTrue(num_bytes // 32 <= mem_bytes <= num_bytes * 32)
@unittest.skipIf(TEST_PYNVML, "pynvml/amdsmi is not available")
@skipIfRocmArch(MI300_ARCH)
def test_power_draw(self):
self.assertTrue(torch.cuda.power_draw() >= 0)
@unittest.skipIf(TEST_PYNVML, "pynvml/amdsmi is not available")
@skipIfRocmArch(MI300_ARCH)
def test_clock_speed(self):
self.assertTrue(torch.cuda.clock_rate() >= 0)

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: linear algebra"]
import contextlib
import unittest
from itertools import product
from functools import partial
@ -351,20 +352,20 @@ class TestFP8MatmulCuda(TestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float8_basics(self, device) -> None:
self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
# hipblaslt does not yet support mixed e4m3_type input
if torch.version.hip is None:
self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
# According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
with self.assertRaises(RuntimeError):
# supported on ROCm but fails on CUDA
ctx = self.assertRaises(RuntimeError) if torch.version.hip is None else contextlib.nullcontext()
with ctx:
self._test_tautological_mm(device, e5m2_type, e5m2_type)
self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
# hipblaslt does not yet support bfloat16 output
if torch.version.hip is None:
self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
with self.assertRaises(RuntimeError):
self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
with self.assertRaises(AssertionError if torch.version.hip else RuntimeError):
self._test_tautological_mm(device, out_dtype=e5m2_type)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)

View File

@ -35,9 +35,9 @@ from torch.testing._internal.common_optimizers import (
optim_db, optims, _get_optim_inputs_including_global_cliquey_kwargs)
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, run_tests, IS_JETSON,
MI300_ARCH, TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, run_tests, IS_JETSON,
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, slowTestIf,
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfRocmArch, skipIfTorchInductor, load_tests, slowTest, slowTestIf,
skipIfCrossRef, TEST_WITH_CROSSREF, skipIfTorchDynamo, skipRocmIfTorchInductor, set_default_dtype,
skipCUDAMemoryLeakCheckIf, BytesIOContext,
skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
@ -1738,6 +1738,7 @@ else:
'embedding_bag_backward_cuda_max',
torch.device(device).type == 'cuda')
@skipIfRocmArch(MI300_ARCH)
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
@onlyCUDA
def test_deterministic_cumsum(self, device):

View File

@ -1855,6 +1855,19 @@ def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"
return dec_fn(func)
return dec_fn
def skipIfRocmArch(arch: Tuple[str, ...]):
def dec_fn(fn):
@wraps(fn)
def wrap_fn(self, *args, **kwargs):
if TEST_WITH_ROCM:
prop = torch.cuda.get_device_properties(0)
if prop.gcnArchName.split(":")[0] in arch:
reason = f"skipIfRocm: test skipped on {arch}"
raise unittest.SkipTest(reason)
return fn(self, *args, **kwargs)
return wrap_fn
return dec_fn
def runOnRocm(fn):
@wraps(fn)
def wrapper(*args, **kwargs):