Revert "[ROCm] Enabling several UTs (#161715)"

This reverts commit b9ba612f7a.

Reverted https://github.com/pytorch/pytorch/pull/161715 on behalf of https://github.com/jeanschmidt due to Need to revert in order to revert https://github.com/pytorch/pytorch/pull/159473, feel free to merge it back once conflicts are cleared ([comment](https://github.com/pytorch/pytorch/pull/161715#issuecomment-3264040604))
PyTorch MergeBot 2025-09-07 21:03:17 +00:00
parent e246a85b76
commit 8235c4f65d
25 changed files with 82 additions and 24 deletions
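Throughout the files below, the revert puts back ROCm skip decorators such as `skipIfRocm` (used bare, as `@skipIfRocm()`, or with a `msg=` argument). As a minimal sketch of how such a skip helper can be expressed with stock `unittest` machinery; this is an illustrative analogue, not the actual `torch.testing._internal.common_utils.skipIfRocm` implementation:

```python
import unittest

import torch

# torch.version.hip is a version string on ROCm builds and None on CUDA/CPU
# builds, so its truthiness identifies a ROCm runtime.
IS_ROCM = torch.version.hip is not None


def skip_if_rocm(msg="test doesn't currently work on the ROCm stack"):
    """Illustrative analogue of skipIfRocm: skip the decorated test on ROCm."""
    return unittest.skipIf(IS_ROCM, msg)


class ExampleTest(unittest.TestCase):
    @skip_if_rocm(msg="regressed on ROCm, see the upstream issue")
    def test_cuda_only_behavior(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```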


@ -28,11 +28,7 @@ from torch.testing._internal.common_fsdp import (
patch_reduce_scatter,
reduce_scatter_with_assert,
)
from torch.testing._internal.common_utils import (
run_tests,
skipIfRocmVersionLessThan,
TEST_HPU,
)
from torch.testing._internal.common_utils import run_tests, skipIfRocm, TEST_HPU
device_type = torch.device(get_devtype())
@ -90,7 +86,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
use_shard_placement_fn_vals.append(True)
return use_shard_placement_fn_vals
@skipIfRocmVersionLessThan((7, 0))
@skipIfRocm # regressed in ROCm 6.4, but ROCm 6.5 fixes it
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_compute_dtype(self):
@ -170,7 +166,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
self.assertEqual(fsdp_loss, ref_loss)
check_sharded_parity(self, ref_model, model)
@skipIfRocmVersionLessThan((7, 0))
@skipIfRocm # regressed in ROCm 6.4, but ROCm 6.5 fixes it
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_reduce_dtype(self):
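The hunks above trade the version gate `@skipIfRocmVersionLessThan((7, 0))` for an unconditional ROCm skip. For illustration, a version gate of that kind can be built by parsing `torch.version.hip`; this sketch is an analogue under that assumption, not the real helper's implementation:

```python
import unittest

import torch


def rocm_version():
    """Return (major, minor) of the ROCm runtime, or None on non-ROCm builds."""
    hip = torch.version.hip  # e.g. "6.4.43483-..." on ROCm, None otherwise
    if hip is None:
        return None
    return tuple(int(part) for part in hip.split(".")[:2])


def skip_if_rocm_version_less_than(version):
    """Illustrative analogue of skipIfRocmVersionLessThan((major, minor))."""
    ver = rocm_version()
    too_old = ver is not None and ver < tuple(version)
    return unittest.skipIf(too_old, f"requires ROCm >= {version}")


class ExampleTest(unittest.TestCase):
    @skip_if_rocm_version_less_than((7, 0))
    def test_bf16_collective_path(self):
        self.assertTrue(True)
```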


@ -7,6 +7,7 @@ import torch.nn as nn
from torch.distributed._tools.mem_tracker import MemTracker
from torch.testing._internal.common_utils import (
run_tests,
skipIfRocm,
skipIfTorchDynamo,
TEST_CUDA,
TEST_XPU,
@ -33,6 +34,7 @@ class TestMemTracker(TestCase):
@unittest.skipIf(
not TEST_CUDA and not TEST_XPU, "Neither CUDA or XPU is not available"
)
@skipIfRocm()
def test_accelerator_tracker_equivalence(
self,
):


@ -116,6 +116,7 @@ class DistributedUtilTest(TestCase):
timeout=1,
)
@skipIfRocm
def test_create_store_timeout_on_worker(self):
with self.assertRaises(DistNetworkError):
# use any available port (port 0) since timeout is expected


@ -38,6 +38,7 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfRocm,
TEST_WITH_DEV_DBG_ASAN,
)
@ -513,6 +514,7 @@ class TestFSDPOptimState(FSDPTest):
continue
self.assertEqual(full_osd_value, ref_osd_pg[name])
@skipIfRocm
@skip_if_lt_x_gpu(2)
@parametrize("state_dict_type", STATE_DICT_TYPES)
@parametrize("use_multiple_param_groups", [False, True])


@ -24,7 +24,7 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
SequenceParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_unless_torch_gpu,
@ -695,6 +695,7 @@ class DistMathOpsTest(DTensorTestBase):
self.assertEqual(grad1_norm.device_mesh, mesh_y)
@with_comms
@skipIfRocm
def test_foreach_add_different_mesh(self):
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(


@ -33,6 +33,7 @@ from torch.testing._internal.common_distributed import (
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
TEST_WITH_DEV_DBG_ASAN,
)
@ -318,6 +319,7 @@ class ProcessGroupNCCLOpTest(MultiProcContinuousTest):
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@skipIfRocm()
def test_nccl_watchdog_cudagraph(self):
# test that the watchdog does not crash graphs with disallowed event query
pg = self.pg


@ -29,6 +29,7 @@ from torch.testing._internal.common_distributed import (
requires_accelerator_dist_backend,
)
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_GPU
@ -362,6 +363,7 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skipIfRocm
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(


@ -8,7 +8,11 @@ from dataclasses import dataclass
import torch
from torch.multiprocessing.reductions import reduce_tensor
from torch.testing._internal.common_distributed import MultiProcContinuousTest
from torch.testing._internal.common_utils import requires_cuda_p2p_access, run_tests
from torch.testing._internal.common_utils import (
requires_cuda_p2p_access,
run_tests,
skipIfRocm,
)
# So that tests are written in device-agnostic way
@ -59,6 +63,7 @@ class CupyAsTensorTest(MultiProcContinuousTest):
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@skipIfRocm
def test_cupy_as_tensor(self) -> None:
"""
Test that torch.as_tensor works for cupy array interface


@ -39,6 +39,7 @@ from torch.testing._internal.common_distributed import (
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfRocm,
skipIfXpu,
TEST_XPU,
xfailIf,
@ -268,6 +269,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@skipIfRocm
@xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1728
def test_eager_async_allreduce_inductor_wait(self):
import torch.distributed as dist


@ -7,7 +7,11 @@
import torch
from torch.multiprocessing.reductions import reduce_tensor
from torch.testing._internal.common_distributed import MultiProcContinuousTest
from torch.testing._internal.common_utils import requires_cuda_p2p_access, run_tests
from torch.testing._internal.common_utils import (
requires_cuda_p2p_access,
run_tests,
skipIfRocm,
)
# So that tests are written in device-agnostic way
@ -30,6 +34,7 @@ class P2PIpcTest(MultiProcContinuousTest):
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@skipIfRocm
def test_p2p_ipc(self) -> None:
"""
Test that cross-process P2P access works, by reducing a tensor,


@ -641,7 +641,7 @@ class SymmMemEmptySetDeviceTest(MultiProcessTestCase):
symm_mem_hdl.barrier()
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(2)
@parametrize("set_device", [True, False])
def test_empty_strided_p2p(self, set_device: bool) -> None:


@ -18,7 +18,7 @@ from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
@ -1364,6 +1364,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
self.assertEqual(out, out_compiled)
self.assertEqual(input.grad, input_compiled.grad)
@skipIfRocm
@requires_cuda_and_triton
def test_autocast_flash_attention(self, device):
def fn(primals_1, primals_2, primals_3):


@ -60,16 +60,13 @@ from torch.testing._internal.common_cuda import (
SM70OrLater,
TEST_CUDA,
)
from torch.testing._internal.common_device_type import (
E4M3_MAX_POS,
e4m3_type,
instantiate_device_type_tests,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
serialTest,
skipIfHpu,
skipIfRocm,
skipIfWindows,
TEST_WITH_ROCM,
)
@ -7479,6 +7476,7 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
out = f_compiled(x, s0, s1, s2)
self.assertEqual(out_ref, out)
@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "requires gpu with fp8 support")
@requires_cuda
def test_partitioner_saves_weights_for_bw(self):
@ -7490,9 +7488,9 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
return a
def scale(t, amax_t):
max_v = E4M3_MAX_POS
max_v = torch.finfo(torch.float8_e4m3fn).max
scale_t = torch.clamp(amax_t.float(), min=1e-12) / max_v
t_fp8 = mul_tiled(t, scale_t.reciprocal()).to(e4m3_type)
t_fp8 = mul_tiled(t, scale_t.reciprocal()).to(torch.float8_e4m3fn)
return t_fp8, scale_t
def matmul(first, amax_first, second_t, amax_second_t, bias):
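The hunk above also replaces the test-helper constants `E4M3_MAX_POS` and `e4m3_type` with values read straight off the dtype. A small sketch of that scaling pattern, assuming a PyTorch build with float8 support (the helper name and shapes are illustrative):

```python
import torch

# Largest finite float8 e4m3 value (448.0), queried from finfo rather than a
# hard-coded test constant.
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def scale_to_fp8(t: torch.Tensor):
    """Scale a tensor into the e4m3 range and cast it; returns (t_fp8, scale)."""
    amax = t.abs().amax()
    scale = torch.clamp(amax.float(), min=1e-12) / E4M3_MAX
    t_fp8 = (t / scale).to(torch.float8_e4m3fn)
    return t_fp8, scale


x = torch.randn(16, 16)
x_fp8, s = scale_to_fp8(x)
print(x_fp8.dtype, s.item())
```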


@ -33,6 +33,7 @@ from torch.testing._internal.common_utils import (
requires_cuda,
run_tests,
skipIfCrossRef,
skipIfRocm,
skipIfTorchDynamo,
TEST_WITH_CROSSREF,
TEST_WITH_TORCHDYNAMO,
@ -1861,6 +1862,7 @@ def forward(self, pred_1, x_1):
)
self.assertEqual(grads, expected_grads)
@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@parametrize("reverse", [False, True])
@ -2005,6 +2007,7 @@ def forward(self, pred_1, x_1):
# TODO: Does not work because of the usage of vmap within associative_scan
# The parameterization is commented out for the moment and the test is marked with expected fail
# Fails with: AssertionError: scan is not an OpOverload
@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@unittest.expectedFailure
@ -3744,6 +3747,7 @@ class AssociativeScanTests(TestCase):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -3810,6 +3814,7 @@ class AssociativeScanTests(TestCase):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -3868,6 +3873,7 @@ class AssociativeScanTests(TestCase):
inputs=x,
)
@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@ -3885,6 +3891,7 @@ class AssociativeScanTests(TestCase):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -3977,6 +3984,7 @@ class AssociativeScanTests(TestCase):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -4160,6 +4168,7 @@ class GraphModule(torch.nn.Module):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -4206,6 +4215,7 @@ class GraphModule(torch.nn.Module):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -4254,6 +4264,7 @@ class GraphModule(torch.nn.Module):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -4433,6 +4444,7 @@ class GraphModule(torch.nn.Module):
lambda params: (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
),
)
def test_associative_scan_cond_in_combine_fn(self, compile_mode, reverse, device):
@ -4550,6 +4562,7 @@ class GraphModule(torch.nn.Module):
inputs=x,
)
@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@ -4567,6 +4580,7 @@ class GraphModule(torch.nn.Module):
and (
params["device"] == torch.device("cpu")
or params["compile_mode"] == "compile_dynamic_shape"
or torch.version.hip
)
),
)
@ -4595,6 +4609,7 @@ class GraphModule(torch.nn.Module):
inputs=elements,
)
@skipIfRocm(msg="Unsupported on ROCM yet")
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
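The repeated `or torch.version.hip` additions above widen the predicate that marks parametrized cases as expected failures so that it also matches ROCm builds. A minimal sketch of conditionally applying `unittest.expectedFailure`, in the spirit of those predicates (the decorator and test body are illustrative, not the helpers used in this file):

```python
import unittest

import torch


def xfail_if(condition):
    """Apply unittest.expectedFailure only when `condition` holds."""
    def decorator(fn):
        return unittest.expectedFailure(fn) if condition else fn
    return decorator


class ExampleScanTest(unittest.TestCase):
    # torch.version.hip is truthy only on ROCm builds, mirroring the diff's
    # `... or torch.version.hip` clause.
    @xfail_if(bool(torch.version.hip))
    def test_associative_scan_like(self):
        if torch.version.hip:
            raise RuntimeError("stand-in for the path that currently fails on ROCm")
        self.assertTrue(True)
```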


@ -71,6 +71,7 @@ from torch.testing._internal.common_utils import (
markDynamoStrictTest,
parametrize,
run_tests,
skipIfRocm,
skipIfTorchDynamo,
subtest,
TEST_CUDA_MEM_LEAK_CHECK,
@ -5162,6 +5163,7 @@ def traceable(f):
@markDynamoStrictTest
class TestCompileTransforms(TestCase):
@skipIfRocm(msg="test leaks memory on ROCm")
# torch.compile is not supported on Windows CUDA.
# Triton only supports GPU with SM70 or later.
@expectedFailureIf((IS_WINDOWS and TEST_CUDA) or (TEST_CUDA and not SM70OrLater))


@ -20,6 +20,7 @@ from torch.testing._internal.common_utils import (
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
skipIfRocm,
skipIfXpu,
)
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
@ -414,6 +415,7 @@ class AOTInductorTestsTemplate:
self.assertTrue(sentinel_seen)
@skipIfXpu
@skipIfRocm
@unittest.skipIf(IS_FBCODE, "unable to find library -laoti_custom_ops")
def test_custom_op_square(self) -> None:
class Model(torch.nn.Module):


@ -26,6 +26,7 @@ from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfRocm,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.utils._ordered_set import OrderedSet
@ -414,6 +415,7 @@ class LoopOrderingTest(TestCase):
self.do_acc_test(f, x)
self.assertEqual(1, metrics.generated_kernel_count)
@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
def test_fp8_cast_and_t(self):
"""
@ -436,6 +438,7 @@ class LoopOrderingTest(TestCase):
self.do_acc_test(f, x, scale)
self.assertEqual(1, metrics.generated_kernel_count)
@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
def test_fp8_pattern_2(self):
"""


@ -7,6 +7,7 @@ from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
run_tests,
skipIfRocm,
TemporaryDirectoryName,
TestCase,
)
@ -69,6 +70,7 @@ class TestContentStore(TestCase):
for _ in range(4):
hash_storage(torch.tensor(2, device=device).untyped_storage())
@skipIfRocm
def test_load_tensor(self, device):
with TemporaryDirectoryName() as loc:
writer = ContentStoreWriter(loc)


@ -3314,10 +3314,10 @@ exit(2)
@parametrize(
"with_amp,cache_enabled,allow_unused_input",
[
subtest((False, False, True)),
subtest((True, False, True)),
subtest((False, False, True), decorators=[skipIfRocm]),
subtest((True, False, True), decorators=[skipIfRocm]),
subtest((True, True, True), decorators=[unittest.expectedFailure]),
subtest((False, False, False)),
subtest((False, False, False), decorators=[skipIfRocm]),
],
name_fn=lambda x, y, z: "{}{}{}".format(
{True: "with_amp", False: "without_amp"}[x],
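Here `skipIfRocm` is attached to individual parametrized cases instead of the whole test, so only some combinations are skipped on ROCm. A hedged sketch of that per-case pattern using the same `common_utils` helpers the diff imports (the test class and arguments are illustrative):

```python
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfRocm,
    subtest,
    TestCase,
)


class ExampleAmpTest(TestCase):
    @parametrize(
        "with_amp,cache_enabled",
        [
            # Only this combination is skipped on ROCm; the other still runs there.
            subtest((True, True), decorators=[skipIfRocm]),
            subtest((False, False)),
        ],
    )
    def test_amp_combination(self, with_amp, cache_enabled):
        self.assertIsInstance(with_amp, bool)
        self.assertIsInstance(cache_enabled, bool)


instantiate_parametrized_tests(ExampleAmpTest)

if __name__ == "__main__":
    run_tests()
```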


@ -31,6 +31,7 @@ from torch.testing._internal.common_utils import (
run_tests,
serialTest,
skipCUDANonDefaultStreamIf,
skipIfRocm,
TEST_CUDA,
TestCase,
)
@ -776,6 +777,8 @@ class TestCudaMultiGPU(TestCase):
p2c.get()
c2p.put(sync_func(self, TestCudaMultiGPU.FIFTY_MIL_CYCLES))
# Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
@skipIfRocm
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
def test_stream_event_nogil(self):
for sync_func in [
@ -816,6 +819,7 @@ class TestCudaMultiGPU(TestCase):
self.assertGreater(parent_time + child_time, total_time * 1.3)
# This test is flaky for ROCm, see issue #62602
@skipIfRocm
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
def test_events_wait(self):
d0 = torch.device("cuda:0")
@ -884,6 +888,7 @@ class TestCudaMultiGPU(TestCase):
self.assertTrue(e1.query())
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
@skipIfRocm
def test_events_multi_gpu_elapsed_time(self):
d0 = torch.device("cuda:0")
d1 = torch.device("cuda:1")


@ -32,11 +32,13 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
skipIfNoDill,
skipIfRocm,
skipIfXpu,
slowTest,
TEST_CUDA,
TEST_NUMPY,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
TEST_WITH_TSAN,
TestCase,
xfailIfLinux,
@ -94,7 +96,7 @@ TEST_CUDA_IPC = (
and sys.platform != "darwin"
and sys.platform != "win32"
and not IS_JETSON
# and not TEST_WITH_ROCM
and not TEST_WITH_ROCM
) # https://github.com/pytorch/pytorch/issues/90940
TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
@ -1863,6 +1865,7 @@ except RuntimeError as e:
list(iter(ChainDataset([dataset1, self.dataset])))
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
@skipIfRocm # https://github.com/pytorch/pytorch/issues/90940
def test_multiprocessing_contexts(self):
reference = [
torch.arange(3),
@ -2487,6 +2490,7 @@ except RuntimeError as e:
self.assertFalse(pin_memory_thread.is_alive())
# Takes 2.5min to finish, see https://github.com/pytorch/pytorch/issues/46065
@skipIfRocm
@unittest.skipIf(not HAS_PSUTIL, "psutil not found")
@slowTest
def test_proper_exit(self):
@ -3130,6 +3134,7 @@ class TestDictDataLoader(TestCase):
self.assertTrue(sample["another_dict"]["a_number"].is_pinned())
@skipIfXpu
@skipIfRocm
@unittest.skipIf(TEST_CUDA, "Test for when CUDA is not available")
def test_pin_memory_no_cuda(self):
loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)


@ -1464,6 +1464,7 @@ class FakeTensorOperatorInvariants(TestCase):
self.assertEqual(ref.size(), meta_out.size())
@skipIfRocm
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
"Does not support SDPA or pre-SM80 hardware",
@ -1525,6 +1526,7 @@ class FakeTensorOperatorInvariants(TestCase):
torch.tensor(3.14, device=GPU_TYPE)
torch.tensor([[3.14, 2], [1, 2]], device=GPU_TYPE)
@skipIfRocm
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_conv_c1_backward(self):
class Repro(torch.nn.Module):


@ -211,6 +211,7 @@ class TestMatmulCuda(TestCase):
self.cublas_addmm(size, dtype, False, True)
@onlyCUDA
@skipIfRocm
def test_cublas_and_lt_reduced_precision_fp16_accumulate(self):
orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
torch.backends.cuda.matmul.allow_fp16_accumulation = True
@ -738,6 +739,7 @@ class TestMatmulCuda(TestCase):
@onlyCUDA
@skipIfRocm
@parametrize("batch_size", [1, 32])
@parametrize("backend", ["cublas", "cublaslt"])
def test_fp16_accum_and_fp32_out_failure(self, batch_size, backend):


@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import (
run_tests,
gradcheck,
parametrize,
skipIfRocm,
)
@ -230,6 +231,7 @@ class TestSegmentReductions(TestCase):
length_type,
)
@skipIfRocm
@dtypes(
*product(
(torch.half, torch.bfloat16, torch.float, torch.double),


@ -12,7 +12,7 @@ from torch.testing import make_tensor, FileCheck
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
from torch.testing._internal.common_utils import \
(TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase,
run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo,
run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm,
skipIfRocmVersionLessThan, IS_FBCODE, IS_REMOTE_GPU, suppress_warnings)
from torch.testing._internal.common_device_type import \
(ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
@ -3725,6 +3725,7 @@ class TestSparseCompressedTritonKernels(TestCase):
@parametrize("block_size", [16, 32, 64])
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16, torch.float)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")