[ROCm] Enabling several UTs (#161715)
All these UTs are working as is, just removing the skip:

- test_p2p_ipc
- test_repros.py: working, added fp8 support
- test_activation_checkpointing.py
- test_content_store.py
- test_cuda_multigpu.py
- test_compute_comm_reordering.py
- test_segment_reductions.py
- test_dataloader.py
- test_math_ops.py
- test_loop_ordering.py
- test_control_flow.py
- distributed_test.py
- test_mem_tracker.py
- test_fsdp_optim_state.py
- test_fully_shard_mixed_precision.py: skipped for < ROCm 7.0
- test_aot_inductor_custom_ops.py
- test_c10d_ops_nccl.py
- test_eager_transforms.py
- test_sparse_csr.py
- test_inductor_collectives.py
- test_fake_tensor.py
- test_cupy_as_tensor.py
- test_cuda.py: enable UTs that are working
- test_matmul_cuda.py: enable UTs that are working

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161715
Approved by: https://github.com/msaroufim

Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
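The diff below applies one recurring pattern: tests that now pass on ROCm lose their blanket @skipIfRocm decorator, and tests that only pass on newer ROCm are gated with @skipIfRocmVersionLessThan((7, 0)) instead. A minimal sketch of that gating, using the same torch.testing._internal.common_utils helpers the diff touches; the class and test names here are hypothetical placeholders, not code from this PR:

```python
# Minimal sketch of the skip-gating pattern used throughout this PR.
# The test body and name are placeholders; the decorators are the real
# helpers imported from torch.testing._internal.common_utils.
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfRocmVersionLessThan,
    TestCase,
)


class ExampleRocmGatingTest(TestCase):
    # Previously: @skipIfRocm  # regressed in ROCm 6.4, but ROCm 6.5 fixes it
    # Now: only skip when the ROCm version is older than 7.0.
    @skipIfRocmVersionLessThan((7, 0))
    def test_runs_on_recent_rocm(self):
        self.assertTrue(True)


if __name__ == "__main__":
    run_tests()
```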
parent 3ea6868049
commit c0142f5c06
@@ -28,7 +28,11 @@ from torch.testing._internal.common_fsdp import (
     patch_reduce_scatter,
     reduce_scatter_with_assert,
 )
-from torch.testing._internal.common_utils import run_tests, skipIfRocm, TEST_HPU
+from torch.testing._internal.common_utils import (
+    run_tests,
+    skipIfRocmVersionLessThan,
+    TEST_HPU,
+)


 device_type = torch.device(get_devtype())

@@ -86,7 +90,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
             use_shard_placement_fn_vals.append(True)
         return use_shard_placement_fn_vals

-    @skipIfRocm  # regressed in ROCm 6.4, but ROCm 6.5 fixes it
+    @skipIfRocmVersionLessThan((7, 0))
     @skip_if_lt_x_gpu(2)
     @requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
     def test_compute_dtype(self):

@@ -166,7 +170,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
             self.assertEqual(fsdp_loss, ref_loss)
             check_sharded_parity(self, ref_model, model)

-    @skipIfRocm  # regressed in ROCm 6.4, but ROCm 6.5 fixes it
+    @skipIfRocmVersionLessThan((7, 0))
     @skip_if_lt_x_gpu(2)
     @requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
     def test_reduce_dtype(self):

@@ -7,7 +7,6 @@ import torch.nn as nn
 from torch.distributed._tools.mem_tracker import MemTracker
 from torch.testing._internal.common_utils import (
     run_tests,
-    skipIfRocm,
     skipIfTorchDynamo,
     TEST_CUDA,
     TEST_XPU,

@@ -34,7 +33,6 @@ class TestMemTracker(TestCase):
     @unittest.skipIf(
         not TEST_CUDA and not TEST_XPU, "Neither CUDA or XPU is not available"
     )
-    @skipIfRocm()
     def test_accelerator_tracker_equivalence(
         self,
     ):

@@ -116,7 +116,6 @@ class DistributedUtilTest(TestCase):
                 timeout=1,
             )

-    @skipIfRocm
     def test_create_store_timeout_on_worker(self):
         with self.assertRaises(DistNetworkError):
             # use any available port (port 0) since timeout is expected

@@ -38,7 +38,6 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
-    skipIfRocm,
     TEST_WITH_DEV_DBG_ASAN,
 )


@@ -514,7 +513,6 @@ class TestFSDPOptimState(FSDPTest):
                 continue
             self.assertEqual(full_osd_value, ref_osd_pg[name])

-    @skipIfRocm
     @skip_if_lt_x_gpu(2)
     @parametrize("state_dict_type", STATE_DICT_TYPES)
     @parametrize("use_multiple_param_groups", [False, True])

@@ -24,7 +24,7 @@ from torch.distributed.tensor.parallel import (
     RowwiseParallel,
     SequenceParallel,
 )
-from torch.testing._internal.common_utils import run_tests, skipIfRocm
+from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     skip_unless_torch_gpu,

@@ -695,7 +695,6 @@ class DistMathOpsTest(DTensorTestBase):
         self.assertEqual(grad1_norm.device_mesh, mesh_y)

     @with_comms
-    @skipIfRocm
     def test_foreach_add_different_mesh(self):
         mesh_shape = (2, self.world_size // 2)
         mesh_2d = init_device_mesh(

@@ -33,7 +33,6 @@ from torch.testing._internal.common_distributed import (
 from torch.testing._internal.common_utils import (
     run_tests,
     skip_but_pass_in_sandcastle_if,
-    skipIfRocm,
     TEST_WITH_DEV_DBG_ASAN,
 )


@@ -319,7 +318,6 @@ class ProcessGroupNCCLOpTest(MultiProcContinuousTest):

     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
-    @skipIfRocm()
     def test_nccl_watchdog_cudagraph(self):
         # test that the watchdog does not crash graphs with disallowed event query
         pg = self.pg

@@ -29,7 +29,6 @@ from torch.testing._internal.common_distributed import (
     requires_accelerator_dist_backend,
 )
 from torch.testing._internal.common_fsdp import get_devtype
-from torch.testing._internal.common_utils import skipIfRocm
 from torch.testing._internal.inductor_utils import HAS_GPU


@@ -368,7 +367,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         self.assertTrue(same(out, correct))

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @skipIfRocm
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
     @patch.object(

@@ -8,11 +8,7 @@ from dataclasses import dataclass
 import torch
 from torch.multiprocessing.reductions import reduce_tensor
 from torch.testing._internal.common_distributed import MultiProcContinuousTest
-from torch.testing._internal.common_utils import (
-    requires_cuda_p2p_access,
-    run_tests,
-    skipIfRocm,
-)
+from torch.testing._internal.common_utils import requires_cuda_p2p_access, run_tests


 # So that tests are written in device-agnostic way

@@ -63,7 +59,6 @@ class CupyAsTensorTest(MultiProcContinuousTest):
     def device(self) -> torch.device:
         return torch.device(device_type, self.rank)

-    @skipIfRocm
     def test_cupy_as_tensor(self) -> None:
         """
         Test that torch.as_tensor works for cupy array interface

@@ -45,6 +45,8 @@ from torch.testing._internal.common_utils import (
     parametrize,
     requires_cuda,
     skipIfRocm,
+    TEST_XPU,
+    xfailIf,
 )
 from torch.testing._internal.inductor_utils import HAS_GPU
 from torch.utils._python_dispatch import TorchDispatchMode

@@ -266,6 +268,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_lt_x_gpu(2)
+    @xfailIf(TEST_XPU)  # https://github.com/intel/torch-xpu-ops/issues/1728
     @skipIfRocm
     def test_eager_async_allreduce_inductor_wait(self):
         import torch.distributed as dist

@@ -7,11 +7,7 @@
 import torch
 from torch.multiprocessing.reductions import reduce_tensor
 from torch.testing._internal.common_distributed import MultiProcContinuousTest
-from torch.testing._internal.common_utils import (
-    requires_cuda_p2p_access,
-    run_tests,
-    skipIfRocm,
-)
+from torch.testing._internal.common_utils import requires_cuda_p2p_access, run_tests


 # So that tests are written in device-agnostic way

@@ -34,7 +30,6 @@ class P2PIpcTest(MultiProcContinuousTest):
     def device(self) -> torch.device:
         return torch.device(device_type, self.rank)

-    @skipIfRocm
     def test_p2p_ipc(self) -> None:
         """
         Test that cross-process P2P access works, by reducing a tensor,

@@ -644,7 +644,7 @@ class SymmMemEmptySetDeviceTest(MultiProcessTestCase):

         symm_mem_hdl.barrier()

-    @runOnRocmArch(MI300_ARCH)
+    @skipIfRocm
     @skip_if_lt_x_gpu(2)
     @parametrize("set_device", [True, False])
     def test_empty_strided_p2p(self, set_device: bool) -> None:

@@ -18,7 +18,7 @@ from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.testing import CompileCounterWithBackend
 from torch._higher_order_ops.wrap import tag_activation_checkpoint
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm
+from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
 from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
 from torch.testing._internal.triton_utils import requires_cuda_and_triton
 from torch.testing._internal.two_tensor import TwoTensor

@@ -1364,7 +1364,6 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
         self.assertEqual(out, out_compiled)
         self.assertEqual(input.grad, input_compiled.grad)

-    @skipIfRocm
     @requires_cuda_and_triton
     def test_autocast_flash_attention(self, device):
         def fn(primals_1, primals_2, primals_3):

@@ -60,13 +60,16 @@ from torch.testing._internal.common_cuda import (
     SM70OrLater,
     TEST_CUDA,
 )
-from torch.testing._internal.common_device_type import instantiate_device_type_tests
+from torch.testing._internal.common_device_type import (
+    E4M3_MAX_POS,
+    e4m3_type,
+    instantiate_device_type_tests,
+)
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
     serialTest,
     skipIfHpu,
-    skipIfRocm,
     skipIfWindows,
     TEST_WITH_ROCM,
 )

@@ -7500,7 +7503,6 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
         out = f_compiled(x, s0, s1, s2)
         self.assertEqual(out_ref, out)

-    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "requires gpu with fp8 support")
     @requires_cuda
     def test_partitioner_saves_weights_for_bw(self):
@@ -7512,9 +7514,9 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
             return a

         def scale(t, amax_t):
-            max_v = torch.finfo(torch.float8_e4m3fn).max
+            max_v = E4M3_MAX_POS
             scale_t = torch.clamp(amax_t.float(), min=1e-12) / max_v
-            t_fp8 = mul_tiled(t, scale_t.reciprocal()).to(torch.float8_e4m3fn)
+            t_fp8 = mul_tiled(t, scale_t.reciprocal()).to(e4m3_type)
             return t_fp8, scale_t

         def matmul(first, amax_first, second_t, amax_second_t, bias):
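For reference, the scale() helper changed above follows the usual dynamic fp8 scaling recipe: divide the observed amax by the dtype's maximum representable value, then cast the rescaled tensor. The switch to E4M3_MAX_POS / e4m3_type (the helpers imported earlier in this diff) presumably lets ROCm builds substitute their fnuz e4m3 variant; that reading is an assumption. A standalone sketch of the same math against public APIs only, with torch.float8_e4m3fn and a plain elementwise multiply standing in for the test's helpers:

```python
import torch

# Sketch of the scale() helper in the hunk above, using public APIs only:
# torch.float8_e4m3fn and a plain multiply replace the test's e4m3_type,
# E4M3_MAX_POS, and mul_tiled helpers.
def scale_to_fp8(t: torch.Tensor, amax_t: torch.Tensor):
    max_v = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    scale_t = torch.clamp(amax_t.float(), min=1e-12) / max_v
    t_fp8 = (t * scale_t.reciprocal()).to(torch.float8_e4m3fn)
    return t_fp8, scale_t


x = torch.randn(8, 8)
x_fp8, s = scale_to_fp8(x, x.abs().amax())
print(x_fp8.dtype, s.item())  # torch.float8_e4m3fn, amax / 448
```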
@@ -33,7 +33,6 @@ from torch.testing._internal.common_utils import (
     requires_cuda,
     run_tests,
     skipIfCrossRef,
-    skipIfRocm,
     skipIfTorchDynamo,
     TEST_WITH_CROSSREF,
     TEST_WITH_TORCHDYNAMO,

@@ -1862,7 +1861,6 @@ def forward(self, pred_1, x_1):
         )
         self.assertEqual(grads, expected_grads)

-    @skipIfRocm(msg="Unsupported on ROCM yet")
     @unittest.skipIf(not SM70OrLater, "triton")
     @requires_cuda
     @parametrize("reverse", [False, True])

@@ -2007,7 +2005,6 @@ def forward(self, pred_1, x_1):
     # TODO: Does not work because of the usage of vmap within associative_scan
     # The paT206899919 rameterization is commented out for the moment and the test is marked with expected fail
     # Fails with: AssertionError: scan is not an OpOverload
-    @skipIfRocm(msg="Unsupported on ROCM yet")
     @unittest.skipIf(not SM70OrLater, "triton")
     @requires_cuda
     @unittest.expectedFailure

@@ -3775,7 +3772,6 @@ class AssociativeScanTests(TestCase):
             and (
                 params["device"] == torch.device("cpu")
                 or params["compile_mode"] == "compile_dynamic_shape"
-                or torch.version.hip
             )
         ),
     )
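The "or torch.version.hip" clauses removed in this and the following hunks sit inside predicates used to conditionally skip or xfail parametrizations, so they acted as a blanket ROCm escape. A short illustration of why that clause disables a case on every ROCm build:

```python
import torch

# torch.version.hip is None on CUDA/CPU builds and a version string (for
# example "6.x....") on ROCm builds, so it is truthy exactly on ROCm.
# Inside the predicates above, "or torch.version.hip" therefore disabled the
# parametrized case on every ROCm build; removing it re-enables those cases.
print("ROCm build:", torch.version.hip is not None, torch.version.hip)
```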
@@ -3859,7 +3855,6 @@ class AssociativeScanTests(TestCase):
             and (
                 params["device"] == torch.device("cpu")
                 or params["compile_mode"] == "compile_dynamic_shape"
-                or torch.version.hip
             )
         ),
     )

@@ -3921,7 +3916,6 @@ class AssociativeScanTests(TestCase):
             inputs=x,
         )

-    @skipIfRocm(msg="Unsupported on ROCM yet")
     @unittest.skipIf(not SM70OrLater, "triton")
     @requires_cuda
     @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])

@@ -3940,7 +3934,6 @@ class AssociativeScanTests(TestCase):
             and (
                 params["device"] == torch.device("cpu")
                 or params["compile_mode"] == "compile_dynamic_shape"
-                or torch.version.hip
             )
         ),
     )

@@ -4044,7 +4037,6 @@ class AssociativeScanTests(TestCase):
             and (
                 params["device"] == torch.device("cpu")
                 or params["compile_mode"] == "compile_dynamic_shape"
-                or torch.version.hip
             )
         ),
     )

@@ -4230,7 +4222,6 @@ class GraphModule(torch.nn.Module):
             and (
                 params["device"] == torch.device("cpu")
                 or params["compile_mode"] == "compile_dynamic_shape"
-                or torch.version.hip
             )
         ),
     )

@@ -4279,7 +4270,6 @@ class GraphModule(torch.nn.Module):
             and (
                 params["device"] == torch.device("cpu")
                 or params["compile_mode"] == "compile_dynamic_shape"
-                or torch.version.hip
             )
         ),
     )

@@ -4330,7 +4320,6 @@ class GraphModule(torch.nn.Module):
             and (
                 params["device"] == torch.device("cpu")
                 or params["compile_mode"] == "compile_dynamic_shape"
-                or torch.version.hip
             )
         ),
     )

@@ -4526,7 +4515,6 @@ class GraphModule(torch.nn.Module):
         lambda params: (
             params["device"] == torch.device("cpu")
             or params["compile_mode"] == "compile_dynamic_shape"
-            or torch.version.hip
         ),
     )
     def test_associative_scan_cond_in_combine_fn(

@@ -4653,7 +4641,6 @@ class GraphModule(torch.nn.Module):
             autograd_param=None if not autograd else (x,),
         )

-    @skipIfRocm(msg="Unsupported on ROCM yet")
     @unittest.skipIf(not SM70OrLater, "triton")
     @requires_cuda
     @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])

@@ -4672,7 +4659,6 @@ class GraphModule(torch.nn.Module):
             and (
                 params["device"] == torch.device("cpu")
                 or params["compile_mode"] == "compile_dynamic_shape"
-                or torch.version.hip
             )
         ),
     )

@@ -4702,7 +4688,6 @@ class GraphModule(torch.nn.Module):
             autograd_param=None if not autograd else elements,
         )

-    @skipIfRocm(msg="Unsupported on ROCM yet")
     @unittest.skipIf(not SM70OrLater, "triton")
     @requires_cuda
     @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])

@@ -71,7 +71,6 @@ from torch.testing._internal.common_utils import (
     markDynamoStrictTest,
     parametrize,
     run_tests,
-    skipIfRocm,
     skipIfTorchDynamo,
     subtest,
     TEST_CUDA_MEM_LEAK_CHECK,

@@ -5163,7 +5162,6 @@ def traceable(f):

 @markDynamoStrictTest
 class TestCompileTransforms(TestCase):
-    @skipIfRocm(msg="test leaks memory on ROCm")
     # torch.compile is not supported on Windows CUDA.
     # Triton only supports GPU with SM70 or later.
     @expectedFailureIf((IS_WINDOWS and TEST_CUDA) or (TEST_CUDA and not SM70OrLater))

@@ -20,7 +20,6 @@ from torch.testing._internal.common_utils import (
     IS_MACOS,
     IS_SANDCASTLE,
     IS_WINDOWS,
-    skipIfRocm,
     skipIfXpu,
 )
 from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test

@@ -415,7 +414,6 @@ class AOTInductorTestsTemplate:
         self.assertTrue(sentinel_seen)

     @skipIfXpu
-    @skipIfRocm
     @unittest.skipIf(IS_FBCODE, "unable to find library -laoti_custom_ops")
     def test_custom_op_square(self) -> None:
         class Model(torch.nn.Module):

@@ -26,7 +26,6 @@ from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
-    skipIfRocm,
 )
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
 from torch.utils._ordered_set import OrderedSet

@@ -415,7 +414,6 @@ class LoopOrderingTest(TestCase):
         self.do_acc_test(f, x)
         self.assertEqual(1, metrics.generated_kernel_count)

-    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
     def test_fp8_cast_and_t(self):
         """

@@ -438,7 +436,6 @@
         self.do_acc_test(f, x, scale)
         self.assertEqual(1, metrics.generated_kernel_count)

-    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
     def test_fp8_pattern_2(self):
         """

@@ -7,7 +7,6 @@ from torch.multiprocessing.reductions import StorageWeakRef
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_utils import (
     run_tests,
-    skipIfRocm,
     TemporaryDirectoryName,
     TestCase,
 )

@@ -70,7 +69,6 @@ class TestContentStore(TestCase):
         for _ in range(4):
             hash_storage(torch.tensor(2, device=device).untyped_storage())

-    @skipIfRocm
     def test_load_tensor(self, device):
         with TemporaryDirectoryName() as loc:
             writer = ContentStoreWriter(loc)
@@ -3314,10 +3314,10 @@ exit(2)
     @parametrize(
         "with_amp,cache_enabled,allow_unused_input",
         [
-            subtest((False, False, True), decorators=[skipIfRocm]),
-            subtest((True, False, True), decorators=[skipIfRocm]),
+            subtest((False, False, True)),
+            subtest((True, False, True)),
             subtest((True, True, True), decorators=[unittest.expectedFailure]),
-            subtest((False, False, False), decorators=[skipIfRocm]),
+            subtest((False, False, False)),
         ],
         name_fn=lambda x, y, z: "{}{}{}".format(
             {True: "with_amp", False: "without_amp"}[x],
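The hunk above drops the per-case skipIfRocm decorators from the subtest(...) entries while leaving the expectedFailure case untouched. A minimal, self-contained sketch of how per-case decorators attach to parametrized subtests; the class and test names are hypothetical, only the common_utils helpers are real:

```python
import unittest

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TestCase,
)


class ExampleSubtestDecorators(TestCase):
    # Each subtest can carry its own decorators; dropping decorators=[...]
    # (as the hunk above does for skipIfRocm) simply runs that case everywhere.
    @parametrize(
        "flag",
        [
            subtest(True),
            subtest(False, decorators=[unittest.expectedFailure]),
        ],
    )
    def test_flag(self, flag):
        self.assertTrue(flag)


instantiate_parametrized_tests(ExampleSubtestDecorators)

if __name__ == "__main__":
    run_tests()
```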
@@ -31,7 +31,6 @@ from torch.testing._internal.common_utils import (
     run_tests,
     serialTest,
     skipCUDANonDefaultStreamIf,
-    skipIfRocm,
     TEST_CUDA,
     TestCase,
 )

@@ -777,8 +776,6 @@ class TestCudaMultiGPU(TestCase):
             p2c.get()
             c2p.put(sync_func(self, TestCudaMultiGPU.FIFTY_MIL_CYCLES))

-    # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
-    @skipIfRocm
     @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
     def test_stream_event_nogil(self):
         for sync_func in [

@@ -819,7 +816,6 @@ class TestCudaMultiGPU(TestCase):
         self.assertGreater(parent_time + child_time, total_time * 1.3)

     # This test is flaky for ROCm, see issue #62602
-    @skipIfRocm
     @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
     def test_events_wait(self):
         d0 = torch.device("cuda:0")

@@ -888,7 +884,6 @@
         self.assertTrue(e1.query())

     @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
-    @skipIfRocm
     def test_events_multi_gpu_elapsed_time(self):
         d0 = torch.device("cuda:0")
         d1 = torch.device("cuda:1")

@@ -32,13 +32,11 @@ from torch.testing._internal.common_utils import (
     parametrize,
     run_tests,
     skipIfNoDill,
-    skipIfRocm,
     skipIfXpu,
     slowTest,
     TEST_CUDA,
     TEST_NUMPY,
     TEST_WITH_ASAN,
-    TEST_WITH_ROCM,
     TEST_WITH_TSAN,
     TestCase,
     xfailIfLinux,

@@ -96,7 +94,7 @@ TEST_CUDA_IPC = (
     and sys.platform != "darwin"
     and sys.platform != "win32"
     and not IS_JETSON
-    and not TEST_WITH_ROCM
+    # and not TEST_WITH_ROCM
 )  # https://github.com/pytorch/pytorch/issues/90940

 TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1

@@ -1865,7 +1863,6 @@ except RuntimeError as e:
         list(iter(ChainDataset([dataset1, self.dataset])))

     @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
-    @skipIfRocm  # https://github.com/pytorch/pytorch/issues/90940
     def test_multiprocessing_contexts(self):
         reference = [
             torch.arange(3),

@@ -2490,7 +2487,6 @@
         self.assertFalse(pin_memory_thread.is_alive())

     # Takes 2.5min to finish, see https://github.com/pytorch/pytorch/issues/46065
-    @skipIfRocm
     @unittest.skipIf(not HAS_PSUTIL, "psutil not found")
     @slowTest
     def test_proper_exit(self):

@@ -3134,7 +3130,6 @@ class TestDictDataLoader(TestCase):
         self.assertTrue(sample["another_dict"]["a_number"].is_pinned())

     @skipIfXpu
-    @skipIfRocm
     @unittest.skipIf(TEST_CUDA, "Test for when CUDA is not available")
     def test_pin_memory_no_cuda(self):
         loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)

@@ -1464,7 +1464,6 @@ class FakeTensorOperatorInvariants(TestCase):

         self.assertEqual(ref.size(), meta_out.size())

-    @skipIfRocm
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION,
         "Does not support SDPA or pre-SM80 hardware",

@@ -1526,7 +1525,6 @@ class FakeTensorOperatorInvariants(TestCase):
             torch.tensor(3.14, device=GPU_TYPE)
             torch.tensor([[3.14, 2], [1, 2]], device=GPU_TYPE)

-    @skipIfRocm
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_conv_c1_backward(self):
         class Repro(torch.nn.Module):

@@ -211,7 +211,6 @@ class TestMatmulCuda(TestCase):
         self.cublas_addmm(size, dtype, False, True)

     @onlyCUDA
-    @skipIfRocm
     def test_cublas_and_lt_reduced_precision_fp16_accumulate(self):
         orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
         torch.backends.cuda.matmul.allow_fp16_accumulation = True

@@ -739,7 +738,6 @@


     @onlyCUDA
-    @skipIfRocm
     @parametrize("batch_size", [1, 32])
     @parametrize("backend", ["cublas", "cublaslt"])
     def test_fp16_accum_and_fp32_out_failure(self, batch_size, backend):

@@ -14,7 +14,6 @@ from torch.testing._internal.common_utils import (
     run_tests,
     gradcheck,
     parametrize,
-    skipIfRocm,
 )


@@ -231,7 +230,6 @@ class TestSegmentReductions(TestCase):
             length_type,
         )

-    @skipIfRocm
     @dtypes(
         *product(
             (torch.half, torch.bfloat16, torch.float, torch.double),

@@ -12,7 +12,7 @@ from torch.testing import make_tensor, FileCheck
 from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
 from torch.testing._internal.common_utils import \
     (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase,
-     run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm,
+     run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo,
      skipIfRocmVersionLessThan, IS_FBCODE, IS_REMOTE_GPU, suppress_warnings)
 from torch.testing._internal.common_device_type import \
     (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,

@@ -3725,7 +3725,6 @@ class TestSparseCompressedTritonKernels(TestCase):

     @parametrize("block_size", [16, 32, 64])
     @onlyCUDA
-    @skipIfRocm
     @dtypes(torch.half, torch.bfloat16, torch.float)
     @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")