pytorch/torch/nn/parallel
wz337 87053132ea [DeviceMesh] Remove parent mesh concept from _MeshEnv and replace by root mesh (#132339)
Previously, when we sliced a submesh out of a mesh, we assigned the source mesh as the parent mesh of the submesh. As a result, with a 3D mesh topology, a 1D mesh sliced directly from the 3D mesh has a different parent mesh than the same 1D mesh sliced from a 2D submesh of the 3D mesh. For example:
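```
# Run with 8 ranks, e.g. torchrun --nproc_per_node=8 example.py
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh

mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]  # parent: mesh_3d

mesh_2d = mesh_3d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0"]  # parent: mesh_2d

# This evaluates to True: the same 1D mesh has two different parents.
print(_mesh_resources.get_parent_mesh(mesh_dim0) != _mesh_resources.get_parent_mesh(mesh_dim0_2))
```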
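The old parent-mesh bookkeeping can be modeled without any GPUs. The following is a minimal, hypothetical sketch (the `ToyMesh` class and `slice_mesh` helper are illustrative only and are not the `DeviceMesh` or `_MeshEnv` API). Each slice records the mesh it was sliced from, so the same logical 1D mesh ends up with a different parent depending on the slicing path:

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(eq=False)  # identity-based equality, like distinct mesh objects
class ToyMesh:
    """Hypothetical stand-in for a device mesh: only its dim names."""
    dim_names: Tuple[str, ...]
    parent: Optional["ToyMesh"] = None  # old scheme: remember the direct source


def slice_mesh(mesh: ToyMesh, *names: str) -> ToyMesh:
    """Slice `names` out of `mesh`, recording `mesh` itself as the parent."""
    missing = [n for n in names if n not in mesh.dim_names]
    if missing:
        raise KeyError(f"{missing} not in {mesh.dim_names}")
    return ToyMesh(dim_names=tuple(names), parent=mesh)


mesh_3d = ToyMesh(("dim0", "dim1", "dim2"))
mesh_dim0 = slice_mesh(mesh_3d, "dim0")        # parent is mesh_3d
mesh_2d = slice_mesh(mesh_3d, "dim0", "dim1")  # parent is mesh_3d
mesh_dim0_2 = slice_mesh(mesh_2d, "dim0")      # parent is mesh_2d

# Same logical 1D mesh, but a different parent for each slicing path.
print(mesh_dim0.parent is mesh_dim0_2.parent)  # False
```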

We can always reconstruct the mesh we need from the mesh dim names, as long as the dims come from the same root mesh. For simplicity, we see no need to build a tree structure to represent child-parent relationships. Therefore, we are replacing the parent mesh concept with a root mesh concept in `_MeshEnv`, so we would have:

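```
# Run with 8 ranks, e.g. torchrun --nproc_per_node=8 example.py
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh

mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]  # root: mesh_3d

mesh_2d = mesh_3d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0"]  # root: also mesh_3d

# This evaluates to True: both slicing paths lead back to the same root mesh.
print(_mesh_resources.get_root_mesh(mesh_dim0) == _mesh_resources.get_root_mesh(mesh_dim0_2))
```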
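Because a sliced mesh is determined by its root mesh and its dim names, its rank group can be recomputed on demand. Below is a hypothetical, pure-Python sketch (`submesh_ranks` is an illustrative helper, not part of PyTorch). It assumes ranks are laid out row-major over the root mesh shape, as in a mesh built from `torch.arange(world_size).view(mesh_shape)`, and returns the ranks of the submesh that contains a given rank:

```python
from itertools import product
from typing import List, Sequence


def submesh_ranks(
    root_shape: Sequence[int],
    root_names: Sequence[str],
    sub_names: Sequence[str],
    rank: int,
) -> List[int]:
    """Ranks of the submesh over `sub_names` that contains `rank`.

    Assumes (hypothetically) ranks 0..N-1 are laid out row-major over `root_shape`.
    """
    # Row-major coordinates of `rank` in the root mesh.
    coord = []
    rem = rank
    for size in reversed(root_shape):
        coord.append(rem % size)
        rem //= size
    coord.reverse()

    axes = [root_names.index(n) for n in sub_names]
    ranks = []
    # Vary only the selected axes; keep every other coordinate fixed.
    for idx in product(*(range(root_shape[a]) for a in axes)):
        c = list(coord)
        for a, i in zip(axes, idx):
            c[a] = i
        # Convert the coordinates back to a flat row-major rank.
        r = 0
        for size, ci in zip(root_shape, c):
            r = r * size + ci
        ranks.append(r)
    return ranks


names = ("dim0", "dim1", "dim2")
print(submesh_ranks((2, 2, 2), names, ("dim0",), 3))          # [3, 7]
print(submesh_ranks((2, 2, 2), names, ("dim0", "dim1"), 0))   # [0, 2, 4, 6]
```

Under these assumptions the "dim0" group containing rank 3 is `[3, 7]` whether that 1D mesh is viewed as sliced from the 3D mesh or from its 2D submesh, since only the root layout and the dim names enter the computation.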
With this change, there are two types of meshes in an environment:
1. `device_mesh != _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh was created by slicing.
2. `device_mesh == _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is a root mesh, i.e. it was not created through slicing.
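The root-mesh bookkeeping and the two cases above can be modeled in pure Python. This is a hypothetical sketch (`ToyMesh`, `slice_mesh`, `get_root`, and `is_sliced` are illustrative names, not PyTorch's API): a slice stores the root of its source rather than the source itself, so every slicing path leads back to the same root, and a mesh is a root exactly when it is its own root.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(eq=False)  # identity-based equality, like distinct mesh objects
class ToyMesh:
    """Hypothetical stand-in for a device mesh: only its dim names."""
    dim_names: Tuple[str, ...]
    root: Optional["ToyMesh"] = None  # None means "this mesh is a root"


def get_root(mesh: ToyMesh) -> ToyMesh:
    """A root mesh is its own root; a sliced mesh points at its root."""
    return mesh if mesh.root is None else mesh.root


def slice_mesh(mesh: ToyMesh, *names: str) -> ToyMesh:
    """Slice `names` out of `mesh`, recording the root of `mesh`, not `mesh`."""
    missing = [n for n in names if n not in mesh.dim_names]
    if missing:
        raise KeyError(f"{missing} not in {mesh.dim_names}")
    return ToyMesh(dim_names=tuple(names), root=get_root(mesh))


def is_sliced(mesh: ToyMesh) -> bool:
    """Case 1: the mesh differs from its root, so it was created by slicing."""
    return mesh is not get_root(mesh)


mesh_3d = ToyMesh(("dim0", "dim1", "dim2"))
mesh_dim0 = slice_mesh(mesh_3d, "dim0")
mesh_2d = slice_mesh(mesh_3d, "dim0", "dim1")
mesh_dim0_2 = slice_mesh(mesh_2d, "dim0")

print(get_root(mesh_dim0) is get_root(mesh_dim0_2))  # True: both are mesh_3d
print(is_sliced(mesh_3d), is_sliced(mesh_dim0_2))    # False True
```

Storing only the root keeps the bookkeeping flat: there is no chain of parents to walk, and any intermediate mesh can be rebuilt from the root plus dim names.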

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132339
Approved by: https://github.com/wanchaol
ghstack dependencies: #132310, #132311
2024-08-07 07:01:12 +00:00
__init__.py [BE][Easy] enable UFMT for torch/nn/ (#128865) 2024-07-25 02:48:42 +00:00
_functions.py [BE][Easy] enable UFMT for torch/nn/ (#128865) 2024-07-25 02:48:42 +00:00
comm.py [BE][Easy] enable UFMT for torch/nn/parallel (#128596) 2024-06-17 16:29:22 +00:00
data_parallel.py [BE][Easy] enable UFMT for torch/nn/ (#128865) 2024-07-25 02:48:42 +00:00
distributed.py [DeviceMesh] Remove parent mesh concept from _MeshEnv and replace by root mesh (#132339) 2024-08-07 07:01:12 +00:00
parallel_apply.py [BE][Easy] enable UFMT for torch/nn/ (#128865) 2024-07-25 02:48:42 +00:00
replicate.py [BE][Easy] enable UFMT for torch/nn/ (#128865) 2024-07-25 02:48:42 +00:00
scatter_gather.py [BE][Easy] enable UFMT for torch/nn/ (#128865) 2024-07-25 02:48:42 +00:00