[ROCm] Enabling several UTs (#161715)

All of these UTs are working on ROCm as-is, so this change just removes the skip decorators (a minimal sketch of the typical pattern follows the list):
- test_p2p_ipc
- test_repros.py: working; added fp8 support (fp8 constants now come from the E4M3_MAX_POS / e4m3_type helpers)
- test_activation_checkpointing.py
- test_content_store.py
- test_cuda_multigpu.py
- test_compute_comm_reordering.py
- test_segment_reductions.py
- test_dataloader.py
- test_math_ops.py
- test_loop_ordering.py
- test_control_flow.py
- distributed_test.py
- test_mem_tracker.py
- test_fsdp_optim_state.py
- test_fully_shard_mixed_precision.py: skipped for ROCm < 7.0
- test_aot_inductor_custom_ops.py
- test_c10d_ops_nccl.py
- test_eager_transforms.py
- test_sparse_csr.py
- test_inductor_collectives.py
- test_fake_tensor.py
- test_cupy_as_tensor.py
- test_cuda.py: enabled the UTs that are working
- test_matmul_cuda.py: enabled the UTs that are working
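
The hunks below follow one of two simple patterns: either the blanket @skipIfRocm decorator (and its import, once unused) is dropped outright, or, as in test_fully_shard_mixed_precision.py, it is replaced with a ROCm version gate. Below is a minimal sketch of the version-gated form using the decorators from torch.testing._internal.common_utils that the diff imports; the class and test names are placeholders for illustration, not code from this PR:

    from torch.testing._internal.common_utils import (
        run_tests,
        skipIfRocmVersionLessThan,
        TestCase,
    )


    class ExampleRocmGatedTest(TestCase):  # placeholder name, not part of this PR
        # Before: @skipIfRocm  # regressed in ROCm 6.4, but ROCm 6.5 fixes it
        @skipIfRocmVersionLessThan((7, 0))  # after: skip only on ROCm older than 7.0
        def test_bf16_collectives_placeholder(self):
            # Placeholder body; the real tests live in the files listed above.
            self.assertTrue(True)


    if __name__ == "__main__":
        run_tests()

Most of the other hunks simply delete the @skipIfRocm line (or drop ROCm-specific conditions such as "or torch.version.hip" from xfail predicates), as shown in the diffs that follow.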


Pull Request resolved: https://github.com/pytorch/pytorch/pull/161715
Approved by: https://github.com/msaroufim

Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
Prachi Gupta 2025-09-09 15:49:21 +00:00 committed by PyTorch MergeBot
parent 3ea6868049
commit c0142f5c06
25 changed files with 27 additions and 80 deletions

View File

@ -28,7 +28,11 @@ from torch.testing._internal.common_fsdp import (
patch_reduce_scatter,
reduce_scatter_with_assert,
)
from torch.testing._internal.common_utils import run_tests, skipIfRocm, TEST_HPU
from torch.testing._internal.common_utils import (
run_tests,
skipIfRocmVersionLessThan,
TEST_HPU,
)
device_type = torch.device(get_devtype())
@ -86,7 +90,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
use_shard_placement_fn_vals.append(True)
return use_shard_placement_fn_vals
@skipIfRocm # regressed in ROCm 6.4, but ROCm 6.5 fixes it
@skipIfRocmVersionLessThan((7, 0))
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_compute_dtype(self):
@ -166,7 +170,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
self.assertEqual(fsdp_loss, ref_loss)
check_sharded_parity(self, ref_model, model)
@skipIfRocm # regressed in ROCm 6.4, but ROCm 6.5 fixes it
@skipIfRocmVersionLessThan((7, 0))
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_reduce_dtype(self):

View File

@ -7,7 +7,6 @@ import torch.nn as nn
from torch.distributed._tools.mem_tracker import MemTracker
from torch.testing._internal.common_utils import (
run_tests,
skipIfRocm,
skipIfTorchDynamo,
TEST_CUDA,
TEST_XPU,
@ -34,7 +33,6 @@ class TestMemTracker(TestCase):
@unittest.skipIf(
not TEST_CUDA and not TEST_XPU, "Neither CUDA or XPU is not available"
)
@skipIfRocm()
def test_accelerator_tracker_equivalence(
self,
):

View File

@ -116,7 +116,6 @@ class DistributedUtilTest(TestCase):
timeout=1,
)
@skipIfRocm
def test_create_store_timeout_on_worker(self):
with self.assertRaises(DistNetworkError):
# use any available port (port 0) since timeout is expected

View File

@ -38,7 +38,6 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfRocm,
TEST_WITH_DEV_DBG_ASAN,
)
@ -514,7 +513,6 @@ class TestFSDPOptimState(FSDPTest):
continue
self.assertEqual(full_osd_value, ref_osd_pg[name])
@skipIfRocm
@skip_if_lt_x_gpu(2)
@parametrize("state_dict_type", STATE_DICT_TYPES)
@parametrize("use_multiple_param_groups", [False, True])

View File

@ -24,7 +24,7 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
SequenceParallel,
)
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_unless_torch_gpu,
@ -695,7 +695,6 @@ class DistMathOpsTest(DTensorTestBase):
self.assertEqual(grad1_norm.device_mesh, mesh_y)
@with_comms
@skipIfRocm
def test_foreach_add_different_mesh(self):
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(

View File

@ -33,7 +33,6 @@ from torch.testing._internal.common_distributed import (
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
TEST_WITH_DEV_DBG_ASAN,
)
@ -319,7 +318,6 @@ class ProcessGroupNCCLOpTest(MultiProcContinuousTest):
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@skipIfRocm()
def test_nccl_watchdog_cudagraph(self):
# test that the watchdog does not crash graphs with disallowed event query
pg = self.pg

View File

@ -29,7 +29,6 @@ from torch.testing._internal.common_distributed import (
requires_accelerator_dist_backend,
)
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_GPU
@ -368,7 +367,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skipIfRocm
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(

View File

@ -8,11 +8,7 @@ from dataclasses import dataclass
import torch
from torch.multiprocessing.reductions import reduce_tensor
from torch.testing._internal.common_distributed import MultiProcContinuousTest
from torch.testing._internal.common_utils import (
requires_cuda_p2p_access,
run_tests,
skipIfRocm,
)
from torch.testing._internal.common_utils import requires_cuda_p2p_access, run_tests
# So that tests are written in device-agnostic way
@ -63,7 +59,6 @@ class CupyAsTensorTest(MultiProcContinuousTest):
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@skipIfRocm
def test_cupy_as_tensor(self) -> None:
"""
Test that torch.as_tensor works for cupy array interface

View File

@ -45,6 +45,8 @@ from torch.testing._internal.common_utils import (
parametrize,
requires_cuda,
skipIfRocm,
TEST_XPU,
xfailIf,
)
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils._python_dispatch import TorchDispatchMode
@ -266,6 +268,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1728
@skipIfRocm
def test_eager_async_allreduce_inductor_wait(self):
import torch.distributed as dist

View File

@ -7,11 +7,7 @@
import torch
from torch.multiprocessing.reductions import reduce_tensor
from torch.testing._internal.common_distributed import MultiProcContinuousTest
from torch.testing._internal.common_utils import (
requires_cuda_p2p_access,
run_tests,
skipIfRocm,
)
from torch.testing._internal.common_utils import requires_cuda_p2p_access, run_tests
# So that tests are written in device-agnostic way
@ -34,7 +30,6 @@ class P2PIpcTest(MultiProcContinuousTest):
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@skipIfRocm
def test_p2p_ipc(self) -> None:
"""
Test that cross-process P2P access works, by reducing a tensor,

View File

@ -644,7 +644,7 @@ class SymmMemEmptySetDeviceTest(MultiProcessTestCase):
symm_mem_hdl.barrier()
@runOnRocmArch(MI300_ARCH)
@skipIfRocm
@skip_if_lt_x_gpu(2)
@parametrize("set_device", [True, False])
def test_empty_strided_p2p(self, set_device: bool) -> None:

View File

@ -18,7 +18,7 @@ from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
@ -1364,7 +1364,6 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
self.assertEqual(out, out_compiled)
self.assertEqual(input.grad, input_compiled.grad)
@skipIfRocm
@requires_cuda_and_triton
def test_autocast_flash_attention(self, device):
def fn(primals_1, primals_2, primals_3):

View File

@ -60,13 +60,16 @@ from torch.testing._internal.common_cuda import (
SM70OrLater,
TEST_CUDA,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_device_type import (
E4M3_MAX_POS,
e4m3_type,
instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
serialTest,
skipIfHpu,
skipIfRocm,
skipIfWindows,
TEST_WITH_ROCM,
)
@ -7500,7 +7503,6 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
out = f_compiled(x, s0, s1, s2)
self.assertEqual(out_ref, out)
@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "requires gpu with fp8 support")
@requires_cuda
def test_partitioner_saves_weights_for_bw(self):
@ -7512,9 +7514,9 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
return a
def scale(t, amax_t):
max_v = torch.finfo(torch.float8_e4m3fn).max
max_v = E4M3_MAX_POS
scale_t = torch.clamp(amax_t.float(), min=1e-12) / max_v
t_fp8 = mul_tiled(t, scale_t.reciprocal()).to(torch.float8_e4m3fn)
t_fp8 = mul_tiled(t, scale_t.reciprocal()).to(e4m3_type)
return t_fp8, scale_t
def matmul(first, amax_first, second_t, amax_second_t, bias):

View File

@ -33,7 +33,6 @@ from torch.testing._internal.common_utils import (
requires_cuda,
run_tests,
skipIfCrossRef,
skipIfRocm,
skipIfTorchDynamo,
TEST_WITH_CROSSREF,
TEST_WITH_TORCHDYNAMO,
@ -1862,7 +1861,6 @@ def forward(self, pred_1, x_1):
)
self.assertEqual(grads, expected_grads)
@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@parametrize("reverse", [False, True])
@ -2007,7 +2005,6 @@ def forward(self, pred_1, x_1):
# TODO: Does not work because of the usage of vmap within associative_scan
# The paT206899919 rameterization is commented out for the moment and the test is marked with expected fail
# Fails with: AssertionError: scan is not an OpOverload
@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@unittest.expectedFailure
@ -3775,7 +3772,6 @@ class AssociativeScanTests(TestCase):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -3859,7 +3855,6 @@ class AssociativeScanTests(TestCase):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -3921,7 +3916,6 @@ class AssociativeScanTests(TestCase):
inputs=x,
)
@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@ -3940,7 +3934,6 @@ class AssociativeScanTests(TestCase):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -4044,7 +4037,6 @@ class AssociativeScanTests(TestCase):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -4230,7 +4222,6 @@ class GraphModule(torch.nn.Module):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -4279,7 +4270,6 @@ class GraphModule(torch.nn.Module):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -4330,7 +4320,6 @@ class GraphModule(torch.nn.Module):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -4526,7 +4515,6 @@ class GraphModule(torch.nn.Module):
lambda params: (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
),
)
def test_associative_scan_cond_in_combine_fn(
@ -4653,7 +4641,6 @@ class GraphModule(torch.nn.Module):
autograd_param=None if not autograd else (x,),
)
@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@ -4672,7 +4659,6 @@ class GraphModule(torch.nn.Module):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -4702,7 +4688,6 @@ class GraphModule(torch.nn.Module):
autograd_param=None if not autograd else elements,
)
@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])

View File

@ -71,7 +71,6 @@ from torch.testing._internal.common_utils import (
markDynamoStrictTest,
parametrize,
run_tests,
skipIfRocm,
skipIfTorchDynamo,
subtest,
TEST_CUDA_MEM_LEAK_CHECK,
@ -5163,7 +5162,6 @@ def traceable(f):
@markDynamoStrictTest
class TestCompileTransforms(TestCase):
@skipIfRocm(msg="test leaks memory on ROCm")
# torch.compile is not supported on Windows CUDA.
# Triton only supports GPU with SM70 or later.
@expectedFailureIf((IS_WINDOWS and TEST_CUDA) or (TEST_CUDA and not SM70OrLater))

View File

@ -20,7 +20,6 @@ from torch.testing._internal.common_utils import (
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
skipIfRocm,
skipIfXpu,
)
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
@ -415,7 +414,6 @@ class AOTInductorTestsTemplate:
self.assertTrue(sentinel_seen)
@skipIfXpu
@skipIfRocm
@unittest.skipIf(IS_FBCODE, "unable to find library -laoti_custom_ops")
def test_custom_op_square(self) -> None:
class Model(torch.nn.Module):

View File

@ -26,7 +26,6 @@ from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfRocm,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.utils._ordered_set import OrderedSet
@ -415,7 +414,6 @@ class LoopOrderingTest(TestCase):
self.do_acc_test(f, x)
self.assertEqual(1, metrics.generated_kernel_count)
@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
def test_fp8_cast_and_t(self):
"""
@ -438,7 +436,6 @@ class LoopOrderingTest(TestCase):
self.do_acc_test(f, x, scale)
self.assertEqual(1, metrics.generated_kernel_count)
@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
def test_fp8_pattern_2(self):
"""

View File

@ -7,7 +7,6 @@ from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
run_tests,
skipIfRocm,
TemporaryDirectoryName,
TestCase,
)
@ -70,7 +69,6 @@ class TestContentStore(TestCase):
for _ in range(4):
hash_storage(torch.tensor(2, device=device).untyped_storage())
@skipIfRocm
def test_load_tensor(self, device):
with TemporaryDirectoryName() as loc:
writer = ContentStoreWriter(loc)

View File

@ -3314,10 +3314,10 @@ exit(2)
@parametrize(
"with_amp,cache_enabled,allow_unused_input",
[
subtest((False, False, True), decorators=[skipIfRocm]),
subtest((True, False, True), decorators=[skipIfRocm]),
subtest((False, False, True)),
subtest((True, False, True)),
subtest((True, True, True), decorators=[unittest.expectedFailure]),
subtest((False, False, False), decorators=[skipIfRocm]),
subtest((False, False, False)),
],
name_fn=lambda x, y, z: "{}{}{}".format(
{True: "with_amp", False: "without_amp"}[x],

View File

@ -31,7 +31,6 @@ from torch.testing._internal.common_utils import (
run_tests,
serialTest,
skipCUDANonDefaultStreamIf,
skipIfRocm,
TEST_CUDA,
TestCase,
)
@ -777,8 +776,6 @@ class TestCudaMultiGPU(TestCase):
p2c.get()
c2p.put(sync_func(self, TestCudaMultiGPU.FIFTY_MIL_CYCLES))
# Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
@skipIfRocm
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
def test_stream_event_nogil(self):
for sync_func in [
@ -819,7 +816,6 @@ class TestCudaMultiGPU(TestCase):
self.assertGreater(parent_time + child_time, total_time * 1.3)
# This test is flaky for ROCm, see issue #62602
@skipIfRocm
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
def test_events_wait(self):
d0 = torch.device("cuda:0")
@ -888,7 +884,6 @@ class TestCudaMultiGPU(TestCase):
self.assertTrue(e1.query())
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
@skipIfRocm
def test_events_multi_gpu_elapsed_time(self):
d0 = torch.device("cuda:0")
d1 = torch.device("cuda:1")

View File

@ -32,13 +32,11 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
skipIfNoDill,
skipIfRocm,
skipIfXpu,
slowTest,
TEST_CUDA,
TEST_NUMPY,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
TEST_WITH_TSAN,
TestCase,
xfailIfLinux,
@ -96,7 +94,7 @@ TEST_CUDA_IPC = (
and sys.platform != "darwin"
and sys.platform != "win32"
and not IS_JETSON
and not TEST_WITH_ROCM
# and not TEST_WITH_ROCM
) # https://github.com/pytorch/pytorch/issues/90940
TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
@ -1865,7 +1863,6 @@ except RuntimeError as e:
list(iter(ChainDataset([dataset1, self.dataset])))
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
@skipIfRocm # https://github.com/pytorch/pytorch/issues/90940
def test_multiprocessing_contexts(self):
reference = [
torch.arange(3),
@ -2490,7 +2487,6 @@ except RuntimeError as e:
self.assertFalse(pin_memory_thread.is_alive())
# Takes 2.5min to finish, see https://github.com/pytorch/pytorch/issues/46065
@skipIfRocm
@unittest.skipIf(not HAS_PSUTIL, "psutil not found")
@slowTest
def test_proper_exit(self):
@ -3134,7 +3130,6 @@ class TestDictDataLoader(TestCase):
self.assertTrue(sample["another_dict"]["a_number"].is_pinned())
@skipIfXpu
@skipIfRocm
@unittest.skipIf(TEST_CUDA, "Test for when CUDA is not available")
def test_pin_memory_no_cuda(self):
loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)

View File

@ -1464,7 +1464,6 @@ class FakeTensorOperatorInvariants(TestCase):
self.assertEqual(ref.size(), meta_out.size())
@skipIfRocm
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
"Does not support SDPA or pre-SM80 hardware",
@ -1526,7 +1525,6 @@ class FakeTensorOperatorInvariants(TestCase):
torch.tensor(3.14, device=GPU_TYPE)
torch.tensor([[3.14, 2], [1, 2]], device=GPU_TYPE)
@skipIfRocm
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_conv_c1_backward(self):
class Repro(torch.nn.Module):

View File

@ -211,7 +211,6 @@ class TestMatmulCuda(TestCase):
self.cublas_addmm(size, dtype, False, True)
@onlyCUDA
@skipIfRocm
def test_cublas_and_lt_reduced_precision_fp16_accumulate(self):
orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
torch.backends.cuda.matmul.allow_fp16_accumulation = True
@ -739,7 +738,6 @@ class TestMatmulCuda(TestCase):
@onlyCUDA
@skipIfRocm
@parametrize("batch_size", [1, 32])
@parametrize("backend", ["cublas", "cublaslt"])
def test_fp16_accum_and_fp32_out_failure(self, batch_size, backend):

View File

@ -14,7 +14,6 @@ from torch.testing._internal.common_utils import (
run_tests,
gradcheck,
parametrize,
skipIfRocm,
)
@ -231,7 +230,6 @@ class TestSegmentReductions(TestCase):
length_type,
)
@skipIfRocm
@dtypes(
*product(
(torch.half, torch.bfloat16, torch.float, torch.double),

View File

@ -12,7 +12,7 @@ from torch.testing import make_tensor, FileCheck
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
from torch.testing._internal.common_utils import \
(TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase,
run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm,
run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo,
skipIfRocmVersionLessThan, IS_FBCODE, IS_REMOTE_GPU, suppress_warnings)
from torch.testing._internal.common_device_type import \
(ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
@ -3725,7 +3725,6 @@ class TestSparseCompressedTritonKernels(TestCase):
@parametrize("block_size", [16, 32, 64])
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16, torch.float)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")