[BE] fix failing test_dp_state_dict_save_load on ROCm CI where world_size=7 (#153283)
**Summary**
I saw an unrelated CI failure, `distributed/_composable/fsdp/test_fully_shard_state_dict.py::TestFullyShardStateDictMultiProcess::test_dp_state_dict_save_load`, in one of my PRs: https://hud.pytorch.org/pr/pytorch/pytorch/153225#41930032096
It is caused by triggering uneven sharding in FSDP2 at cbb03e6971/torch/distributed/fsdp/_fully_shard/_fsdp_param.py (L353-L361).
The failure didn't show up before because the CUDA CI runs on an even number of GPUs (e.g. 2/4/8), which is not guaranteed on ROCm CI; in the failing job, the world size is 7.
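For intuition, here is a minimal sketch of why `mlp_dim=16` with `world_size=7` ends up unevenly sharded. `torch.chunk` is used purely as an illustration of dim-0 chunking; it is not the actual FSDP2 sharding code referenced above.

```python
import torch

# Illustration only (not the FSDP2 code path): dim-0 chunking of a
# 16-element dimension across 7 ranks, the failing ROCm CI configuration.
# torch.chunk stands in here for FSDP2's dim-0 sharding.
mlp_dim, world_size = 16, 7
shard_sizes = [c.numel() for c in torch.chunk(torch.arange(mlp_dim), world_size)]
print(shard_sizes)  # [3, 3, 3, 3, 3, 1] -> uneven shards; only 6 of 7 ranks get data
```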
**Solution**
Skip the subtest if `self.world_size` does not evenly divide `mlp_dim` (i.e. 16).
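The guard condition is just a divisibility check; the standalone sketch below (not test code) shows how it behaves for the world sizes mentioned here, with 7 being the failing ROCm case.

```python
# Standalone check of the guard condition `16 % world_size == 0`
# for the world sizes mentioned above (2/4/8 on CUDA CI, 7 on ROCm CI).
for world_size in (2, 4, 7, 8):
    runs_subtest = 16 % world_size == 0
    print(f"world_size={world_size}: run mlp_dim=16 subtest -> {runs_subtest}")
# world_size=7 prints False, so the previously failing subtest is skipped.
```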
**Test**
CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153283
Approved by: https://github.com/fegin, https://github.com/weifengpy
This commit is contained in:
parent fc7d8c6808
commit bc4cf1c13a
@@ -39,8 +39,15 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
             {"mlp_dim": [2, 3, 4, 5], "mesh": [fsdp_mesh]},
             self._test_dp_state_dict_save_load,
         )
-        self.run_subtests(
-            {"mlp_dim": [16], "mesh": [fsdp_mesh], "use_shard_placement_fn": [True]},
-            self._test_dp_state_dict_save_load,
-        )
+        if 16 % self.world_size == 0:
+            # TODO: remove this evenness check when FSDP2 supports uneven sharding
+            # see: https://github.com/pytorch/pytorch/blob/cbb03e69717943ddf912f9a68b3a6f935bbf21f5/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L353-L361 # noqa: B950
+            self.run_subtests(
+                {
+                    "mlp_dim": [16],
+                    "mesh": [fsdp_mesh],
+                    "use_shard_placement_fn": [True],
+                },
+                self._test_dp_state_dict_save_load,
+            )
         if self.world_size % 2 != 0: