[DeviceMesh] Clarifying flatten use case (#161311)

While refactoring DeviceMesh and simplifying its bookkeeping, we found a bug in the DeviceMesh `_flatten` implementation. The findings:
1. The unit tests assume that users can call `dp_cp_mesh._flatten()` many times without creating a new backend, i.e. the result is cached.
2. The implementation, following the slicing logic, is meant to raise an exception when `_flatten` is called more than once. A bug in that check was partially fixed in https://github.com/pytorch/pytorch/pull/160709, but that PR did not fix the check for the case where `_flatten` is called twice.
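
For context on the second finding, the check removed by this commit built its list of invalid names with `chain(..., *keys())`. The snippet below is a minimal sketch in plain Python of how unpacking string keys into `chain` behaves. It uses made-up stand-in data rather than real DeviceMesh state, and whether this is exactly the bug meant above is an assumption.

```python
from itertools import chain

# Hypothetical stand-ins for root_mesh.mesh_dim_names and for one recorded flatten.
root_dim_names = ["dp", "cp", "tp"]
flatten_name_to_root_dims = {"dp_cp": (0, 1)}

# Same shape as the removed expression: every key is unpacked as its own
# iterable, so chain() walks the characters of "dp_cp" instead of the name.
unpacked = list(chain(root_dim_names, *flatten_name_to_root_dims.keys()))

# Passing the keys as one iterable yields the flattened names themselves.
whole = list(chain(root_dim_names, flatten_name_to_root_dims.keys()))

seen_by_unpacked = "dp_cp" in unpacked
seen_by_whole = "dp_cp" in whole
```

With the unpacked form, a previously flattened name such as `dp_cp` is not found in the invalid names, so a second `_flatten` call with that name would not be rejected by this check.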

The more important question is what behavior we want for `_flatten`: should calling it multiple times with the same mesh_name be allowed? I think it should, for two reasons:
1. We already allow slicing with the same mesh_name or name_list multiple times, and we cache the underlying PG. Although a new device mesh object is returned every time, the objects all compare equal (according to `__eq__`).
2. We already cache the flattened mesh in `root_to_flatten_mapping` and return early on a cache hit, but that line is never reached if we error out before it.

Flattening a 1D mesh into its own mesh_dim_name should also be a no-op. I added a unit test for it.
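
The intended behavior described above (repeated flatten hits a cache, a 1D mesh flattened into its own name is a no-op, and a name that collides with a root dimension raises) can be sketched with a toy model. `ToyMesh` and `ToyFlattenRegistry` are hypothetical names invented for illustration. This is not the DeviceMesh implementation, and it creates no process groups.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ToyMesh:
    """Hypothetical stand-in for a device mesh: dimension names and sizes only."""
    dim_names: tuple
    shape: tuple

    @property
    def ndim(self):
        return len(self.shape)


@dataclass
class ToyFlattenRegistry:
    """Hypothetical bookkeeping, loosely modeled on root_to_flatten_mapping."""
    root_to_flatten: dict = field(default_factory=dict)
    created: int = 0  # how many flattened meshes were actually built

    def flatten(self, root, mesh, name=None):
        name = name or "_".join(mesh.dim_names)
        # A 1D mesh flattened into its own dim name returns itself.
        if mesh.ndim == 1 and name in mesh.dim_names:
            return mesh
        # The flattened name must not collide with a root dimension name.
        if name in root.dim_names:
            raise RuntimeError(f"{name} already exists for submesh of the {root}.")
        # Early return if this name was already flattened for this root.
        cache = self.root_to_flatten.setdefault(root, {})
        if name in cache:
            return cache[name]
        size = 1
        for dim_size in mesh.shape:
            size *= dim_size
        self.created += 1
        cache[name] = ToyMesh((name,), (size,))
        return cache[name]


registry = ToyFlattenRegistry()
root = ToyMesh(("dp", "cp", "tp"), (2, 2, 2))
dp_cp = ToyMesh(("dp", "cp"), (2, 2))
first = registry.flatten(root, dp_cp)
second = registry.flatten(root, dp_cp)  # served from the cache
mesh_1d = ToyMesh(("default",), (4,))
same = registry.flatten(mesh_1d, mesh_1d)  # no-op for a 1D mesh
```

In this sketch the 1D check runs before the name-collision check. If the order were reversed, flattening `mesh_1d` into `default` would raise instead of returning the mesh.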

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161311
Approved by: https://github.com/fegin
fduwjj 2025-09-09 22:00:28 -07:00 committed by PyTorch MergeBot
parent b2d8f6a6af
commit be8095b07f
2 changed files with 24 additions and 7 deletions

@@ -825,6 +825,15 @@ class TestDeviceMeshGetItem(DTensorTestBase):
         ):
             mesh_3d["cp", "dp"]
+    @with_comms
+    def test_flatten_mesh_1d(self):
+        mesh_shape = (4,)
+        mesh_dim_names = ("default",)
+        mesh_1d = init_device_mesh(
+            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
+        )
+        mesh_1d._flatten()
     @with_comms
     def test_flatten_mesh_3d(self):
         mesh_shape = (2, 2, 2)
@@ -833,6 +842,13 @@ class TestDeviceMeshGetItem(DTensorTestBase):
             self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
         )
+        # Test flatten into an existing mesh_dim_name inside the mesh
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "already exists for submesh of the DeviceMesh",
+        ):
+            mesh_3d._flatten("dp")
         # Test flatten contiguous dims
         dp_cp_mesh = mesh_3d["dp", "cp"]
         flattened_dp_cp_mesh = dp_cp_mesh._flatten()

@@ -7,7 +7,7 @@ import threading
 import warnings
 from collections.abc import Iterator
 from functools import reduce
-from itertools import chain, zip_longest
+from itertools import zip_longest
 from typing import Optional, TYPE_CHECKING, Union
 import torch
@@ -185,12 +185,15 @@ else:
             if not mesh_dim_name:
                 mesh_dim_name = "_".join(not_none(device_mesh.mesh_dim_names))
+            # Flatten a 1D device mesh into its original mesh_dim_name will return itself.
+            if device_mesh.ndim == 1 and mesh_dim_name in not_none(
+                device_mesh.mesh_dim_names
+            ):
+                return device_mesh
             # Check whether the mesh_dim_name for flattened mesh is valid.
             self.flatten_name_to_root_dims.setdefault(root_mesh, {})
-            invalid_dim_names = chain(
-                list(not_none(root_mesh.mesh_dim_names)),
-                *self.flatten_name_to_root_dims[root_mesh].keys(),
-            )
+            invalid_dim_names = not_none(root_mesh.mesh_dim_names)
             if mesh_dim_name in invalid_dim_names:
                 raise RuntimeError(
                     f"{mesh_dim_name} already exists for submesh of the {root_mesh}. ",
@@ -199,8 +202,6 @@ else:
                 )
             # Quick return if the flatten mesh has been created before.
-            # TODO: If we decide to restrict flatten initialization once, we should remove
-            # this check and throw an error if the flatten mesh is already created before.
             if (
                 root_mesh in self.root_to_flatten_mapping
                 and mesh_dim_name in self.root_to_flatten_mapping[root_mesh]