mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[AIInfra][DCP] All gather keys checkpoint utils bug fix (#135045)
Summary: All gather keys checkpoint utils bug fix. Dist. get_world_size should have the process group passed in to avoid inconsistent world size in case the process group has changed. This is common in the tests. Test Plan: UTs Reviewed By: Saiteja64 Differential Revision: D61578832 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135045 Approved by: https://github.com/MeetVadakkanchery, https://github.com/LucasLLC
This commit is contained in:
parent
eb0fd17bc4
commit
dd7cd182ab
|
|
@ -44,7 +44,7 @@ def _all_gather_keys(
|
|||
) -> List[Any]:
|
||||
"""Gathers all keys, and returns them sorted."""
|
||||
keys = list(local_dict.keys())
|
||||
gathered_keys: List[List[Any]] = [None] * dist.get_world_size() # type: ignore[list-item]
|
||||
gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item]
|
||||
|
||||
dist.all_gather_object(gathered_keys, keys, group=group)
|
||||
return sorted(set(itertools.chain.from_iterable(gathered_keys)))
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user