[BE] fix failing test_dp_state_dict_save_load on ROCm CI where world_size=7 (#153283)

**Summary**
I saw an unrelated CI failure, `distributed/_composable/fsdp/test_fully_shard_state_dict.py::TestFullyShardStateDictMultiProcess::test_dp_state_dict_save_load`, in one of my PRs: https://hud.pytorch.org/pr/pytorch/pytorch/153225#41930032096

The failure is caused by triggering uneven sharding in FSDP2 at cbb03e6971/torch/distributed/fsdp/_fully_shard/_fsdp_param.py (L353-L361).

This didn't show up before because the CUDA CI runs with an even number of GPUs (e.g. 2/4/8), but that is not guaranteed on ROCm CI. In the failing run, the world size is 7.
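
For context, here is a minimal standalone sketch (not part of this PR) of why a world size of 7 hits that code path: a parameter whose sharded dim is 16 cannot be split evenly across 7 ranks.

```python
import torch

# With world_size=7, a dim-16 parameter cannot be sharded evenly:
# torch.chunk yields chunks of sizes [3, 3, 3, 3, 3, 1] (and one rank gets
# nothing), which is the uneven-sharding case FSDP2 does not yet support.
weight = torch.empty(16, 8)
chunks = torch.chunk(weight, 7, dim=0)
print([c.shape[0] for c in chunks])  # [3, 3, 3, 3, 3, 1]
print(16 % 7 == 0)                   # False -> the subtest should be skipped
```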

**Solution**
Skip the subtest if `mlp_dim` (i.e. 16) is not evenly divisible by `self.world_size`.

**Test**
CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153283
Approved by: https://github.com/fegin, https://github.com/weifengpy
Author: Xilun Wu
Date: 2025-05-09 16:36:00 -07:00
Committed by: PyTorch MergeBot
Commit: bc4cf1c13a (parent fc7d8c6808)


```diff
@@ -39,10 +39,17 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
             {"mlp_dim": [2, 3, 4, 5], "mesh": [fsdp_mesh]},
             self._test_dp_state_dict_save_load,
         )
-        self.run_subtests(
-            {"mlp_dim": [16], "mesh": [fsdp_mesh], "use_shard_placement_fn": [True]},
-            self._test_dp_state_dict_save_load,
-        )
+        if 16 % self.world_size == 0:
+            # TODO: remove this evenness check when FSDP2 supports uneven sharding
+            # see: https://github.com/pytorch/pytorch/blob/cbb03e69717943ddf912f9a68b3a6f935bbf21f5/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L353-L361  # noqa: B950
+            self.run_subtests(
+                {
+                    "mlp_dim": [16],
+                    "mesh": [fsdp_mesh],
+                    "use_shard_placement_fn": [True],
+                },
+                self._test_dp_state_dict_save_load,
+            )
         if self.world_size % 2 != 0:
             return
         hsdp_mesh = init_device_mesh(
```