[AIInfra][DCP] All gather keys checkpoint utils bug fix (#135045)

Summary: Fixes a bug in the _all_gather_keys checkpoint utility. dist.get_world_size() must be passed the process group; without it, the call returns the size of the default group, which can differ from the size of the group used for the gather when the process group has changed. This is common in tests.

Test Plan: Unit tests

Reviewed By: Saiteja64

Differential Revision: D61578832

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135045
Approved by: https://github.com/MeetVadakkanchery, https://github.com/LucasLLC
Saurabh Mishra 2024-09-04 18:49:34 +00:00 committed by PyTorch MergeBot
parent eb0fd17bc4
commit dd7cd182ab


@@ -44,7 +44,7 @@ def _all_gather_keys(
 ) -> List[Any]:
     """Gathers all keys, and returns them sorted."""
     keys = list(local_dict.keys())
-    gathered_keys: List[List[Any]] = [None] * dist.get_world_size()  # type: ignore[list-item]
+    gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group)  # type: ignore[list-item]
     dist.all_gather_object(gathered_keys, keys, group=group)
     return sorted(set(itertools.chain.from_iterable(gathered_keys)))
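
The effect of the one-line change can be illustrated with a small single-process sketch that does not use torch at all. Everything below is a hypothetical stand-in: fake_get_world_size and fake_all_gather only model the idea that the output list needs one slot per member of the group, and the assumption that unused slots stay None is ours, not a statement about PyTorch's behavior on a mis-sized list.

```python
import itertools
from typing import Any, Dict, List, Optional, Sequence

# Hypothetical size of the default (global) process group.
DEFAULT_WORLD_SIZE = 4


def fake_get_world_size(group: Optional[Sequence[int]] = None) -> int:
    """Stand-in: default group size when no group is given, else the group's size."""
    return DEFAULT_WORLD_SIZE if group is None else len(group)


def fake_all_gather(
    output: List[Any], keys_by_rank: Dict[int, List[Any]], group: Sequence[int]
) -> None:
    """Stand-in: fills one slot per rank in the group; extra slots stay None (assumed)."""
    for slot, rank in enumerate(group):
        output[slot] = keys_by_rank[rank]


def gather_keys(
    keys_by_rank: Dict[int, List[Any]], group: Sequence[int], use_group_size: bool
) -> List[Any]:
    # use_group_size=False mirrors the old code (size of the default group);
    # use_group_size=True mirrors the fixed code (size of the group passed in).
    size = fake_get_world_size(group) if use_group_size else fake_get_world_size()
    gathered: List[Any] = [None] * size
    fake_all_gather(gathered, keys_by_rank, group)
    return sorted(set(itertools.chain.from_iterable(gathered)))


keys_by_rank = {0: ["a", "b"], 1: ["b", "c"], 2: ["d"], 3: ["e"]}
subgroup = [0, 1]  # a group smaller than the default world

fixed = gather_keys(keys_by_rank, subgroup, use_group_size=True)

try:
    gather_keys(keys_by_rank, subgroup, use_group_size=False)
    buggy_error = None
except TypeError as exc:
    # Leftover None slots cannot be chained, so the old sizing fails in this model.
    buggy_error = exc

print(fixed)
print(type(buggy_error).__name__)
```

In this model, sizing the list from the subgroup yields the merged, sorted keys of ranks 0 and 1 only, while sizing it from the default group leaves None entries that break the chaining step.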