From dd7cd182abfab48bad38a7beb5966dc9545e96da Mon Sep 17 00:00:00 2001 From: Saurabh Mishra Date: Wed, 4 Sep 2024 18:49:34 +0000 Subject: [PATCH] [AIInfra][DCP] All gather keys checkpoint utils bug fix (#135045) Summary: Bug fix for `_all_gather_keys` in the checkpoint utils. `dist.get_world_size` should have the process group passed in, so that the list of gathered keys is sized for the same group that `dist.all_gather_object` is called with; this avoids an inconsistent world size in case the process group has changed. This is common in the tests. Test Plan: UTs Reviewed By: Saiteja64 Differential Revision: D61578832 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135045 Approved by: https://github.com/MeetVadakkanchery, https://github.com/LucasLLC --- torch/distributed/checkpoint/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index 5bc171a00dd..cb90dd11912 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -44,7 +44,7 @@ def _all_gather_keys( ) -> List[Any]: """Gathers all keys, and returns them sorted.""" keys = list(local_dict.keys()) - gathered_keys: List[List[Any]] = [None] * dist.get_world_size() # type: ignore[list-item] + gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item] dist.all_gather_object(gathered_keys, keys, group=group) return sorted(set(itertools.chain.from_iterable(gathered_keys)))