mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DeviceMesh] Clarifying flatten use case (#161311)
Since we are in the middle of big refactoring and simplying the bookkeeping for device mesh. We found an interesting bug inside DeviceMesh flatten implementation. Here is the finding: 1. In unit test, we assume users can call `dp_cp_mesh._flatten()` many times but no backend will be created (aka cached). 2. From the implementation of slicing, we actually throw exception erroring out doing the `_flatten` more than once. But there is bug which was partially fixed in https://github.com/pytorch/pytorch/pull/160709 but it does not fixed the check for the case when we call the `_flatten` twice. What's more important question to ask is, what behavior we want for `_flatten`? Do we allow calling `_flatten` multiple times (with same mesh_name)? I think we should, why? 1. We allow slicing for the same mesh_name or name_list multiple times, and we cache the PG behinds. Although we will return a new device mesh object everytime, when we compare them they are all the same (according to __eq__). 2. We actually cached the flattened mesh today inside `root_to_flatten_mapping` and actually do the early return but that line will never be reached if we error out before that. Also we should allow a no-op for flatten a 1D mesh into itself's mesh_dim_name, I added a unit test for it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161311 Approved by: https://github.com/fegin
This commit is contained in:
parent
b2d8f6a6af
commit
be8095b07f
|
|
@ -825,6 +825,15 @@ class TestDeviceMeshGetItem(DTensorTestBase):
|
|||
):
|
||||
mesh_3d["cp", "dp"]
|
||||
|
||||
@with_comms
|
||||
def test_flatten_mesh_1d(self):
|
||||
mesh_shape = (4,)
|
||||
mesh_dim_names = ("default",)
|
||||
mesh_1d = init_device_mesh(
|
||||
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
mesh_1d._flatten()
|
||||
|
||||
@with_comms
|
||||
def test_flatten_mesh_3d(self):
|
||||
mesh_shape = (2, 2, 2)
|
||||
|
|
@ -833,6 +842,13 @@ class TestDeviceMeshGetItem(DTensorTestBase):
|
|||
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
|
||||
# Test flatten into an existing mesh_dim_name inside the mesh
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"already exists for submesh of the DeviceMesh",
|
||||
):
|
||||
mesh_3d._flatten("dp")
|
||||
|
||||
# Test flatten contiguous dims
|
||||
dp_cp_mesh = mesh_3d["dp", "cp"]
|
||||
flattened_dp_cp_mesh = dp_cp_mesh._flatten()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import threading
|
|||
import warnings
|
||||
from collections.abc import Iterator
|
||||
from functools import reduce
|
||||
from itertools import chain, zip_longest
|
||||
from itertools import zip_longest
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
|
@ -185,12 +185,15 @@ else:
|
|||
if not mesh_dim_name:
|
||||
mesh_dim_name = "_".join(not_none(device_mesh.mesh_dim_names))
|
||||
|
||||
# Flatten a 1D device mesh into its original mesh_dim_name will return itself.
|
||||
if device_mesh.ndim == 1 and mesh_dim_name in not_none(
|
||||
device_mesh.mesh_dim_names
|
||||
):
|
||||
return device_mesh
|
||||
|
||||
# Check whether the mesh_dim_name for flattened mesh is valid.
|
||||
self.flatten_name_to_root_dims.setdefault(root_mesh, {})
|
||||
invalid_dim_names = chain(
|
||||
list(not_none(root_mesh.mesh_dim_names)),
|
||||
*self.flatten_name_to_root_dims[root_mesh].keys(),
|
||||
)
|
||||
invalid_dim_names = not_none(root_mesh.mesh_dim_names)
|
||||
if mesh_dim_name in invalid_dim_names:
|
||||
raise RuntimeError(
|
||||
f"{mesh_dim_name} already exists for submesh of the {root_mesh}. ",
|
||||
|
|
@ -199,8 +202,6 @@ else:
|
|||
)
|
||||
|
||||
# Quick return if the flatten mesh has been created before.
|
||||
# TODO: If we decide to restrict flatten initialization once, we should remove
|
||||
# this check and throw an error if the flatten mesh is already created before.
|
||||
if (
|
||||
root_mesh in self.root_to_flatten_mapping
|
||||
and mesh_dim_name in self.root_to_flatten_mapping[root_mesh]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user