[ROCm] Enable several fsdp related UTs (#149369)

Enabling 26 UTs for ROCm in the following files:

-  distributed._shard.sharded_optim.test_sharded_optim - 2 UTs
-  distributed._shard.sharded_tensor.ops.test_binary_cmp - 4 UTs
-  distributed._shard.sharded_tensor.ops.test_init - 3 UTs
-  distributed._shard.sharded_tensor.ops.test_embedding - 2 UTs
-  distributed._shard.sharded_tensor.ops.test_embedding_bag - 2 UTs
-  distributed._composable.test_replicate_with_compiler - 4 UTs
-  distributed._composable.fsdp.test_fully_shard_grad_scaler - 1 UT
-  distributed.tensor.test_attention - 4 UTs
-  distributed.tensor.test_matrix_ops - 1 UT
-  distributed.tensor.test_tensor_ops - 1 UT
-  distributed.fsdp.test_fsdp_grad_acc - 2 UTs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149369
Approved by: https://github.com/jeffdaily
Author: Prachi Gupta
Date: 2025-03-31 16:15:57 +00:00
Committed by: PyTorch MergeBot
Parent: 7c858066ae
Commit: 47cdad2995
8 changed files with 31 additions and 42 deletions
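The tests enabled here were previously gated by ROCm skip decorators (skipIfRocm, skip_if_rocm_multiprocess) or listed in the test runner's ROCM_BLOCKLIST; the diff below removes those gates so the tests also run on ROCm. As a rough illustration of what such a decorator does (a sketch only, not the actual implementation in torch.testing._internal.common_utils):

    import unittest
    import torch

    def skip_if_rocm_sketch(fn):
        # Illustrative stand-in for skipIfRocm: skip the decorated test on ROCm/HIP
        # builds, where torch.version.hip is a version string instead of None.
        return unittest.skipIf(torch.version.hip is not None, "skipped on ROCm")(fn)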

View File

@@ -131,7 +131,6 @@ class TestFullyShardCompile(FSDPTest):
if not sm_is_or_higher_than(device, 8, 0):
self.skipTest("bf16 requires sm >= 8.0")
@skipIfRocm
def test_dynamo_trace_use_training_state(self):
torch._dynamo.reset()
# Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager.
@@ -169,7 +168,6 @@ class TestFullyShardCompile(FSDPTest):
self.assertEqual(cnt.op_count, 1)
self.assertEqual(len(cnt.graphs), 1)
@skipIfRocm
def test_trace_fsdp_copy_(self):
@torch.library.custom_op("mylib::add_one_out", mutates_args={"out"})
def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None:

View File

@@ -13,12 +13,11 @@ from torch.distributed.tensor.parallel import (
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.common_utils import run_tests
class TestFullyShardGradientScaler(FSDPTest):
@skip_if_lt_x_gpu(4)
@skipIfRocm
def test_gradient_scaler(self):
self.run_subtests(
{"has_inf": [True, False], "test_2d": [True, False]},

View File

@@ -28,11 +28,10 @@ from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import (
DistributedTestBase,
skip_if_lt_x_gpu,
skip_if_rocm_multiprocess,
sm_is_or_higher_than,
)
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils.checkpoint import checkpoint
@@ -194,7 +193,6 @@ class ReplicateTest(MultiProcessInductorTestCase):
self._test_compile(no_sync=True, device="cpu")
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
@torch._inductor.config.patch(
reorder_for_locality=False, reorder_for_peak_memory=False
@@ -203,7 +201,6 @@ class ReplicateTest(MultiProcessInductorTestCase):
self._test_compile(no_sync=False, checkpoint=False, device=device_type)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
@torch._inductor.config.patch(
reorder_for_locality=False, reorder_for_peak_memory=False
@@ -212,11 +209,13 @@ class ReplicateTest(MultiProcessInductorTestCase):
self._test_compile(no_sync=False, checkpoint=True, device=device_type)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_compile_bf16(self):
# Check device capability wrt bf16
if not sm_is_or_higher_than(torch.device(device_type), 8, 0):
if (
not sm_is_or_higher_than(torch.device(device_type), 8, 0)
and torch.version.hip is None
):
self.skipTest("bf16 requires sm >= 8.0")
def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
@@ -230,7 +229,6 @@ class ReplicateTest(MultiProcessInductorTestCase):
self._test_compile(no_sync=False, setup_func=setup, device=device_type)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_compile_fp16(self):
def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
@@ -247,7 +245,6 @@ class ReplicateTest(MultiProcessInductorTestCase):
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_compile_backward_only(self):
self._test_compile(no_sync=False, no_compile_forward=True, device=device_type)
@@ -387,7 +384,6 @@ class DDP_TP_Test(InductorTestCase):
"Temporarily disabled due to SymInt error: `unhashable type: non-nested SymInt`"
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skipIfRocm
def test_ddp_tp(self):
ref_model = Net()
compiled_replicate_model = deepcopy(ref_model)

View File

@@ -24,7 +24,6 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfRocm,
TEST_WITH_DEV_DBG_ASAN,
)
@@ -275,7 +274,6 @@ class TestGradAcc(FSDPTest):
)
@skip_if_lt_x_gpu(2)
@skipIfRocm
@parametrize("use_orig_params", [False, True])
def test_grad_acc_cpu_offload(
self,

View File

@@ -18,7 +18,8 @@ from torch.distributed.tensor import (
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.common_device_type import E4M3_MAX_POS, e4m3_type
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_unless_torch_gpu,
@@ -33,8 +34,10 @@ def scale_for_fp8(
t = t.unsqueeze(0).unsqueeze(-2)
else:
t = t.unflatten(0, (scale_shape[0], -1)).unflatten(-1, (scale_shape[1], -1))
scale = t.abs().amax(dim=[1, -1]).float() / torch.finfo(torch.float8_e4m3fn).max
t_fp8 = (t / scale[:, None, :, None]).to(torch.float8_e4m3fn)
scale = t.abs().amax(dim=[1, -1]).float() / E4M3_MAX_POS
t_fp8 = (t / scale[:, None, :, None]).to(e4m3_type)
return t_fp8.flatten(end_dim=1).flatten(start_dim=-2), scale.view(scale_shape)
@@ -205,7 +208,7 @@ class DistMatrixOpsTest(DTensorTestBase):
full_dist_res = dist_res.full_tensor()
# Fp8 matmuls are quite inaccurate, we need high tolerances
self.assertEqual(full_dist_res, full_ref_res, atol=1, rtol=7e-2)
self.assertEqual(full_dist_res, full_ref_res, atol=1.5, rtol=7e-2)
self.assertEqual(comm_mode.get_total_counts(), 0)
@@ -448,7 +451,6 @@ class DistMatrixOpsTest(DTensorTestBase):
self.assertTrue(dist_value.grad.placements[0].is_shard(dim=1))
self.assertEqual(dist_value.grad.full_tensor(), value.grad)
@skipIfRocm
@skip_unless_torch_gpu
@with_comms()
def test_dtensor_mm(self):
@@ -472,7 +474,9 @@ class DistMatrixOpsTest(DTensorTestBase):
lhs_dtensor = distribute_tensor(lhs, mesh, [Shard(dim=0), Replicate()])
rhs_dtensor = distribute_tensor(rhs, mesh, [Replicate(), Shard(dim=1)])
dtensor_result = lhs_dtensor @ rhs_dtensor
self.assertEqual(dtensor_result.full_tensor(), mm_result)
self.assertEqual(
dtensor_result.full_tensor(), mm_result, atol=1.5e-5, rtol=1e-6
)
@with_comms
@skip_unless_torch_gpu

View File

@@ -171,19 +171,10 @@ ROCM_BLOCKLIST = [
"distributed/rpc/test_tensorpipe_agent",
"distributed/rpc/test_share_memory",
"distributed/rpc/cuda/test_tensorpipe_agent",
"distributed/_shard/checkpoint/test_checkpoint"
"distributed/_shard/checkpoint/test_file_system_checkpoint"
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_tensor/ops/test_embedding",
"distributed/_shard/sharded_tensor/ops/test_embedding_bag",
"distributed/_shard/sharded_tensor/ops/test_binary_cmp",
"distributed/_shard/sharded_tensor/ops/test_init",
"distributed/_shard/sharded_optim/test_sharded_optim",
"test_determination",
"test_jit_legacy",
"test_cuda_nvml_based_avail",
"test_jit_cuda_fuser",
"distributed/tensor/test_attention",
]
S390X_BLOCKLIST = [

View File

@@ -24,7 +24,7 @@ from torch.testing._internal.common_cuda import (
SM90OrLater,
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_MX_GEMM
PLATFORM_SUPPORTS_MX_GEMM,
)
from torch.testing._internal.common_device_type import (
dtypes,
@@ -32,6 +32,10 @@ from torch.testing._internal.common_device_type import (
onlyCUDA,
tol as xtol,
toleranceOverride,
e4m3_type,
e5m2_type,
E4M3_MAX_POS,
E5M2_MAX_POS,
)
from torch.testing._internal.common_utils import (
@@ -258,17 +262,6 @@ class TestMatmulCuda(TestCase):
f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
if torch.version.hip and 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName:
e4m3_type = torch.float8_e4m3fnuz
e5m2_type = torch.float8_e5m2fnuz
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
e4m3_type = torch.float8_e4m3fn
e5m2_type = torch.float8_e5m2
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
# avoid division by zero when calculating scale
EPS = 1e-12

View File

@@ -1982,3 +1982,13 @@ flex_attention_supported_platform = unittest.skipUnless(
and torch.cuda.get_device_capability() >= (8, 0),
"Requires CUDA and Triton",
)
if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName:
e4m3_type = torch.float8_e4m3fnuz
e5m2_type = torch.float8_e5m2fnuz
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
e4m3_type = torch.float8_e4m3fn
e5m2_type = torch.float8_e5m2
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max