diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index ea04cdc86b0..b73ad2b4937 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -5455,7 +5455,7 @@ def new_subgroups(
         )
         subgroups.append(subgroup)
 
-        if rank := get_rank(group=group) in ranks_in_subgroup:
+        if (rank := get_rank()) in ranks_in_subgroup:
             cur_subgroup = subgroup
             logger.info("Rank %s is assigned to subgroup %s", rank, ranks_in_subgroup)
 
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 797233a8558..71026e4b142 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -915,6 +915,40 @@ class DistributedTest:
             for subgroup in subgroups:
                 dist.destroy_process_group(subgroup)
 
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_with_group_param(self):
+            # Initialize global test environment
+            self._init_global_test()
+            # Set up GPU devices for each rank
+            init_multigpu_helper(dist.get_world_size(), BACKEND)
+            # Create two subgroups: one with ranks [0,2] and another with ranks [1,3]
+            cur_subgroup, subgroups = dist.new_subgroups_by_enumeration(
+                ranks_per_subgroup_list=[[0, 2], [1, 3]]
+            )
+
+            # Further divide the current subgroup into sub-subgroups of size 1
+            cur_sub_subgroup, sub_subgroups = dist.new_subgroups(
+                group_size=1, group=cur_subgroup
+            )
+            # Verify we have 2 sub-subgroups (one for each rank in the original subgroup)
+            self.assertEqual(len(sub_subgroups), 2)
+            # Verify the current process's sub-subgroup has size 1
+            self.assertEqual(cur_sub_subgroup.size(), 1)
+            # Verify the current process is in its assigned sub-subgroup
+            self.assertFalse(dist._rank_not_in_group(group=cur_sub_subgroup))
+
+            # Clean up by destroying all created process groups
+            for sub_subgroup in sub_subgroups:
+                dist.destroy_process_group(sub_subgroup)
+
+            for subgroup in subgroups:
+                dist.destroy_process_group(subgroup)
+
         @skip_but_pass_in_sandcastle_if(
             BACKEND not in DistTestCases.backend_feature["subgroup"],
             f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
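
For reference, a minimal standalone sketch of what the patched `group` parameter enables (my illustration, not part of the PR): it mirrors the new test above, but assumes a 4-process CPU job launched with `torchrun --nproc_per_node=4 sketch.py` and substitutes the gloo backend for the CUDA setup the test requires.

import torch.distributed as dist

def main() -> None:
    # torchrun sets RANK/WORLD_SIZE/MASTER_ADDR, so env:// initialization works as-is.
    dist.init_process_group(backend="gloo")
    assert dist.get_world_size() == 4, "sketch assumes exactly 4 processes"

    # Same split as the new test: subgroups over ranks {0, 2} and {1, 3}.
    cur_subgroup, subgroups = dist.new_subgroups_by_enumeration(
        ranks_per_subgroup_list=[[0, 2], [1, 3]]
    )

    # With the distributed_c10d.py fix above, membership is resolved against the
    # global rank, so splitting an existing subgroup assigns each rank correctly.
    cur_sub_subgroup, sub_subgroups = dist.new_subgroups(
        group_size=1, group=cur_subgroup
    )
    assert cur_sub_subgroup.size() == 1
    assert not dist._rank_not_in_group(group=cur_sub_subgroup)

    # Tear down the sub-subgroups and subgroups, then the default group.
    for g in sub_subgroups + subgroups:
        dist.destroy_process_group(g)
    dist.destroy_process_group()

if __name__ == "__main__":
    main()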