diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index 5bc171a00dd..cb90dd11912 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -44,7 +44,7 @@ def _all_gather_keys( ) -> List[Any]: """Gathers all keys, and returns them sorted.""" keys = list(local_dict.keys()) - gathered_keys: List[List[Any]] = [None] * dist.get_world_size() # type: ignore[list-item] + gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item] dist.all_gather_object(gathered_keys, keys, group=group) return sorted(set(itertools.chain.from_iterable(gathered_keys)))