Moving DTensor to be in the public namespace, to formally add the
documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next
PRs)
* To preserve the BC for users still using the `torch.distributed._tensor`,
I added a shim script to redirect old path calls to the new module
The BC preserving is evidented by the fact that all DTensor tests are still
working without changing the public imports. So it's safe to land the
changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
Previously, when we slice out a submesh from a mesh, we assign the mesh as the parent mesh of the submesh. In this case, when we have a 3D mesh topology, the parent mesh of a 1D mesh sliced out from the 3D mesh is different from the parent mesh of the same 1D mesh sliced out from the 2D submesh of the 3D mesh. For example:
```
mesh_3d = init_device_mesh("cuda", (2,2,2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
mesh_2d = mesh_2d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0_2"]
# This would evaluate to be True
print(_mesh_resources.get_parent_mesh(mesh_dim0) != _mesh_resources.get_parent_mesh(mesh_dim0))
```
We can always reconstruct the mesh needed from the mesh dim names, as long as two dims come from the same root. For simplicity, we do not see the necessity of building a tree structure to represent child-parent relationship. Therefore, we are replacing the parent mesh concept with a root mesh concept in `_MeshEnv` so we would have:
```
mesh_3d = init_device_mesh("cuda", (2,2,2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
mesh_2d = mesh_2d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0_2"]
# This would evaluate to be True
print(_mesh_resources.get_root_mesh(mesh_dim0) == _mesh_resources.get_root_mesh(mesh_dim0))
```
With this change, we will have two types of meshes in an environment.
1. `device_mesh != _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is created by slicing.
2. `device_mesh == _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is a root mesh not created through slicing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132339
Approved by: https://github.com/wanchaol
ghstack dependencies: #132310, #132311
Previously, using _MaskPartial when multiple embeddings have the following issues:
1. Suppose an `nn.Embedding` has shape `[vocab_size, emb_size]`. When there are more than one embeddings, sharing the same `vocab_size` but with different `emb_size`s. Then they would not share `OpStrategy` since each, when involved in computation, would have different `OpSchema`; however, there would be cache hit for redistribute (specifically `_gen_transform_infos` in `torch/distributed/_tensor/_redistribute.py` when doing `Replicate` -> `_MaskPartial`) as the `_MaskPartial` only has `vocab_size` as `logical_dim_size` but not `emb_size` as attribute. This cache hit is undesirable and would cause trouble when doing all-reduce/reduce-scatter on the new `_MaskPartial` in a separate `OpStrategy`. The error was reported in #130725. In this PR, we introduce `offset_shape` to represent the embedding's full shape to avoid cache hit from embeddings of different shapes.
2. The second issue is when we have two `nn.Embedding`s `emb1` and `emb2` with the same shape. There will be cache hit not only in `_gen_transform_infos`, but also in `OpStrategy` generation. Previously, if we sequentially do `Replicate` -> `_MaskPartial` for both `emb1` `emb2` and then sequentially do reduction on the `_MaskPartial` of `emb1`, it would destroy the `MaskBuffer` and `emb2` would hit error. This PR adds a `refcount` for the `MaskBuffer` so that it can be properly shared by multiple `nn.Embedding`s.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131264
Approved by: https://github.com/wanchaol
SequenceParallel style assumes the input torch.Tensor ALREADY sharded on
the sequence dimension if not passing in DTensor. Since it causes some
user confusion on the documentation, this PR:
1. for the case where input passed in is already a DTensor, we check the
input placements and redistribute if it's not sharded on the sequence
dimension
2. update the doc to make it more explicit about the case when user
passed in a torch.Tensor and DTensor
This would fix https://github.com/pytorch/pytorch/issues/129355
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131346
Approved by: https://github.com/awgu
as titled, given that our DTensorSpec is immutable, we can always reuse
the spec if the input/output have the same tensor metadata. this helps two fold:
1. We don't need to re-calculate the hash everytime we produce a
DTensorSpec, reduce runtime operator overhead
2. reduce the DTensor construction overhead.
Some local benchmark on a 800 parameter clip_grad_norm shows that for
foreach_norm the CPU overhead reduces from 11ms -> 7.8ms (around 30% improvement)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128112
Approved by: https://github.com/awgu
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
Resolves#126888
- #126888
This PR is split from PR #126898.
- #126898
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves#126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
`from_local` with replicate placement would run mesh_broadcast if `run_check=True`, by default `from_local` have `run_check=True`, but for FSDP state_dict case we are for sure that these are replicated on dp dimension (FSDP + TP) already, so we don't need to check/force check it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123802
Approved by: https://github.com/wanchaol
Adding wildcard support for TP's `parallelize_module` API.
Example patterns:
`layers.*.linear`: any characters
`layers.?.linear`: single character
`layers.[1-2]`: digit range, matches `layers.1` and `layers.2`
Example use case:
A model have multiple layers, and we want to parallelize the linear module `lin` inside each layer.
```
model_tp = parallelize_module(
model,
device_mesh,
{
"layers.*.lin": ColwiseParallel(),
},
)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122968
Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/wanchaol
ghstack dependencies: #122919
This PR fixed the bug of redistribute to move early return check into the
redistribute autograd function, so that even though we redistribute the
same placement, the grad_placements from the `to_local` call might be
different, the redistribute backward still need to happen
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121653
Approved by: https://github.com/awgu
async output option was only available in `full_tensor()` call, but I think it's
generally good to make this option available in the `redistribute` call directly
so that user can control it
This PR adds async_op option to redistribute call, to allow user control
whether to perform tensor redistribution asynchronously or not.
By default we set this to False, this is to follow the semantics of the c10d
collectives.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121477
Approved by: https://github.com/wz337
This PR removes the deprecated tp_mesh_dim arg to prepare for release.
As we deprecated this arg for a while (by throwing deprecating
messages), we should remove it before the release
#suppress-api-compatibility-check
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121432
Approved by: https://github.com/wz337
ghstack dependencies: #121431
Since CommDebugMode is fixed, we can check that loss parallel is working as expected.
Under loss parallel, the forward computation should invoke 3 all-reduces, and the backward computation should invoke no functional collectives.
Co-authored-by: Wanchao <wanchaol@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121366
Approved by: https://github.com/wanchaol
As titled, this PR introduces a dedicated `ParallelStyle` to shard the
nn.LayerNorm/nn.Dropout/RMSNorm layers. We were mainly using a manual
distribute_module calls before when sharding the RMSNorm layer, but I
think we should have a dedicate TP API to easily shard those layers,
instead of user manually using DTensors.
I call this SequenceParallel, which might bring some confusion that we
technically "deprecated" a SequenceParallel style months ago. But this
time the SeuqenceParallel style is significantly different with the
previous ones (which used to shard two consecutive Linear layers). I
believe making it the right name is the first priority, instead of
worrying about the issue of reusing the old name
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121295
Approved by: https://github.com/awgu, https://github.com/tianyu-l
ghstack dependencies: #121294
This is a BC breaking change to distribute_module. The underlying rationle
for this change is that sometimes in the input_fn/output_fn, user would want
to access to the current module for some attributes. This might not be
common enough, but in some cases it's worth to access to the module.
An outstanding use case we want to support is float8, if we want to make
float8 works with the TP API, the input_fn/output_fn of TP parallel
styles would need to get access to the module, where the module might
encapsulates `dynamic_linear.emulate` attribute, that is useful for
input/output casting
Since this is needed for fp8 and DTensor still under prototype release,
I feel it's worth the change and it's better we make the change as
early.
Right now making it a soft BC breaking, which means we maintain BC still
but throw deprecation messages.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120895
Approved by: https://github.com/tianyu-l
Loss parallel is the last piece of sequence parallelism to enable. It enables efficient distributed cross entropy computation when the input is sharded on the class dimension (in a classification problem with many classes). The implementation is via a context manager `loss_parallel`, after enabling which users can directly use `torch.nn.functional.cross_entropy` or `torch.nn.CrossEntropyLoss` without modifying other parts of their code.
Here are the underlying rationales why we are going through these op replacements:
1. `nn.functional.cross_entropy` is the common method that OSS user is using for things like transformer training, to avoid changing user code, we want user to still use this function for loss calculation if they are already using it.
2. `nn.functional.cross_entropy` boils down into `aten.log_softmax` and `aten.nll_loss_foward/backward`, and DTensor now supports those ops already (#117723#119255#118917#119256). They are doing computation with input *replicated* on the class dimension.
3. However when the input of this loss calculation is **sharded on the class dimension**, to run sharded computation efficiently, we need to run both `aten.log_softmax` and `aten.nll_loss_foward` with multiple all-reduce collectives **in the middle of** those aten ops. This is not possible if we are just overriding these two ops, so we need to have some way to **decompose** these two ops into smaller ops to have collectives run in the middle of these two ops.
4. We explored the existing decompositions (#118950). It seems working, except that `log_softmax_backward` and `nll_loss_backward` combined together in aten are implemented in a inefficient way, which would trigger an additional expensive collective. Recently some user also reported similar issues https://github.com/pytorch/pytorch/issues/119261.
5. Therefore, currently we are doing our own decomposition inside a context manager for sequence parallelism specifically. Once we have a better decomposition in core, we can possibly take that instead of reinventing the wheels here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119877
Approved by: https://github.com/wanchaol
Update all_gather to support HSDP + TP.
Currently, the `_all_gather_dtensor` function for dtensors only replaces the first dimension with replicate (the FSDP dimension) and does not touch the second dimension (which is assumed to be the TP dimension). With HSDP, we have two dimensions ahead of the TP dimension as opposed to 1. This PR updates to replace all other dimensions with replicate to run the all-gather.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118638
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wz337
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Currently, we create new_group for sub_group pg during mesh initialization. The PR changes this so we will:
1) re-use sub_group pg if it exsits,
2) create new sub_group pg if it does not exist.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115716
Approved by: https://github.com/wanchaol