## summary
`zip(inputs, self.input_layouts, self.desired_input_layouts)` is used in `_prepare_input_fn`, and similarly in `_prepare_output_fn`. Without an assertion, unmatched dimensions in the inputs/outputs will be silently dropped, potentially causing unexpected behaviors.
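A minimal sketch of the added guard (names follow the summary above; the exact assertion message and placement are assumptions):
```python
def _prepare_input_fn(inputs, input_layouts, desired_input_layouts):
    # Check lengths before zipping so no layout is silently dropped.
    assert len(inputs) == len(input_layouts) == len(desired_input_layouts), (
        "inputs, input_layouts and desired_input_layouts must have the same length"
    )
    for inp, layout, desired_layout in zip(inputs, input_layouts, desired_input_layouts):
        ...  # redistribute `inp` from `layout` to `desired_layout`
```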
## test plan
`python test/distributed/tensor/parallel/test_tp_style.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115957
Approved by: https://github.com/wanchaol
Some typos resulted in the note section not rendering properly; this couldn't be seen directly from the last PR, as the last PR only shows the documentation of the first commit :(
Also makes the parallelize_module doc example more concrete.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115974
Approved by: https://github.com/wz337
Summary:
Rename _device_mesh.py to device_mesh.py, update all callsites, add documentation.
We created stubs for the public classes and methods in torch.distributed.device_mesh so that torch.distributed.device_mesh can be imported whether or not distributed is available (i.e., `torch.distributed.is_available()`).
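For illustration, a hypothetical sketch of the stubbing pattern (names and messages are assumptions, not the actual device_mesh.py code):
```python
# Expose placeholder symbols when the distributed backend is not compiled in,
# so the module still imports cleanly.
import torch.distributed as dist

if not dist.is_available():

    class DeviceMesh:  # stub: importable, but unusable without distributed
        def __init__(self, *args, **kwargs):
            raise RuntimeError("DeviceMesh requires a torch build with distributed support")

    def init_device_mesh(*args, **kwargs):
        raise RuntimeError("init_device_mesh requires a torch build with distributed support")
```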
Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/115099
Prior to landing, all CI signals had passed. Shipit added the "ci/trunk" label to the PR but did NOT wait for it and went ahead with committing. More context can be found in the reverted PR above.
Test Plan: CI.
Differential Revision: D51861018
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115193
Approved by: https://github.com/fegin
Summary:
Rename _device_mesh.py to device_mesh.py, update all callsites, add documentation.
Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/114991
It was failing a public module binding test on macOS due to the change in import order in torch/distributed/fsdp/_common_utils.py. Since the original import still works, we removed the changes to this file.
Test Plan: CI.
Differential Revision: D51825114
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115099
Approved by: https://github.com/wanchaol, https://github.com/fegin
This PR rewrites the Tensor Parallel implementation. The Tensor Parallel APIs are
supposed to be a very thin wrapper around the DTensor APIs, but the current
implementation got too messy and buggy, and it's really hard to debug what
went wrong when using it. It is crucially important for advanced users and
developers to understand the API and its implementation easily, without
going through all the different types of functions and utils, so that
they can trust what happens under the hood.
In particular this PR:
* Make ParallelStyle a real contract API for parallelize_module to take:
  each concrete ParallelStyle only needs to implement `apply` to
  apply the sharding to an nn.Module; remove all unnecessary fields. This
  also enables easier ParallelStyle authoring going forward (see the sketch after this list).
* Keep the ColwiseParallel and RowwiseParallel public interfaces, but
  refactor them so that the parameter sharding and the input/output handling
  live within the style itself, making it easy to understand how
  Linear/Embedding layers are sharded and how the input/output
  transformations are performed.
* Remove all the private _prepare_input/_prepare_output_fn fields from
  both ColwiseParallel and RowwiseParallel. Since we have thrown deprecation
  messages in nightly for a while, TP is in prototype release, and the
  fields are private anyway, it should be safe to remove them.
* Refactor the recently landed PrepareModuleInput/Output styles: change
  output_layouts to desired_input/output_layouts, group
  the functions inside the style itself, and provide no default arguments for these
  two styles so users need to specify them and think about the sharding
  layouts. Fixed bugs around not handling the `use_local_output` flag.
* Make default arguments None instead of Placement objects; it is
  standard Python practice not to have custom object instances as default
  arguments.
* Remove all dead APIs (i.e. the PairwiseParallel and SequenceParallel
  styles, and all prepare input/output functions), as we have thrown deprecation
  messages for a while; we are in the process of removing all of them from the tests.
* Throw a deprecation warning for `tp_mesh_dim`, as we recommend using device
  mesh slicing/indexing instead of manually specifying the mesh dim.
* Rewrite the documentation for every ParallelStyle and make it clearer
  what each style does.
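As an example, a minimal sketch of what authoring a custom style could look like under the new contract (the hook name `_apply` and the use of `distribute_tensor` here are assumptions for illustration, not the exact code landed in this PR):
```python
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor
from torch.distributed.tensor.parallel import ParallelStyle


class ReplicateParams(ParallelStyle):
    """Toy style: turn every direct parameter of the module into a replicated DTensor."""

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        for name, param in list(module.named_parameters(recurse=False)):
            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Replicate()]))
            module.register_parameter(name, dist_param)
        return module
```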
TODOs:
* Rewrite TP tests to adjust for the changes we have in this PR
* add more tests to guard the bug fixes
Differential Revision: [D51761183](https://our.internmc.facebook.com/intern/diff/D51761183)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114732
Approved by: https://github.com/wz337, https://github.com/fduwjj
As titled: built on top of the work @wz337 enabled, this saves some
runtime CPU time spent recreating DTensor parameters with the correct
shape/stride, and avoids issues when sharding parameters unevenly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113547
Approved by: https://github.com/XilunWu
ghstack dependencies: #113323, #113324
Fixes: #113193
`pydocstyle <all_files_in_issue> --count`
- Before: 345
- After: 130
For deprecated methods, I have added a `noqa` to ignore them. I was not able to find the file `torch/distributed/tensor/parallel/multihead_attention_tp.py`, so I've ignored it for this PR.
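For illustration (not a line from this PR), suppressing a pydocstyle check on a deprecated method with an inline noqa comment looks like:
```python
class LegacyModule:
    """Container kept only for backward compatibility."""

    # Deprecated: docstring intentionally omitted, so silence D102 here.
    def old_forward(self, x):  # noqa: D102
        return x
```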
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113241
Approved by: https://github.com/kit1980
We are refactoring the parallel styles to achieve the following:
1. Further simplify the code logic and make it more readable for users.
2. Remove the tuple check so that we can work with dynamo for now. Ideally dynamo needs to support this case, and we will fix it in parallel.
3. Add tests for the newly added parallel styles in unit tests and torch.compile tests so that we can catch regressions due to code changes.
4. Move the placements early-return check into DTensor, since it is bypassed by dynamo.
5. Remove the PairwiseParallel style from unit tests in favor of the new Col/Rowwise parallel styles.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111625
Approved by: https://github.com/wanchaol
In some use cases, we found that users might want to annotate the input/output DTensor layouts for the parent module rather than for the submodule whose parameters are to be distributed. These two classes let users annotate input/output DTensor layouts, and we register pre-forward/forward hooks on the TP-ized module accordingly.
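Conceptually (an illustrative sketch with assumed helper names, not the actual implementation), the input annotation boils down to a forward pre-hook on the parent module that wraps plain tensor inputs into DTensors with the annotated layouts:
```python
import torch
from torch.distributed._tensor import DeviceMesh, DTensor


def _annotate_module_input(module: torch.nn.Module, mesh: DeviceMesh, layouts) -> None:
    # Register a pre-forward hook that converts each positional tensor input
    # into a DTensor with the user-annotated layout.
    def pre_forward(mod, inputs):
        return tuple(
            DTensor.from_local(t, mesh, [layout]) if isinstance(t, torch.Tensor) else t
            for t, layout in zip(inputs, layouts)
        )

    module.register_forward_pre_hook(pre_forward)
```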
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111166
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160
One thing we find challenging for users is that we don't want to expose the concepts of prepare_input and prepare_output, since there are so many function names for users to select from, which is quite confusing. On the other hand, colwise and rowwise parallel always need the input and output to be in certain layouts, so we can simplify the logic here and make it more usable.
So we add three public attributes to ParallelStyle, and the code logic looks like:
```python
from abc import ABC
from typing import Tuple, Union

from torch.distributed._tensor.placement_types import Placement, Replicate, Shard


class ParallelStyle(ABC):
    """
    The parallel style the user wants the module or submodule to be parallelized with.

    We can add more in the future, but this seems sufficient for immediate needs.
    Users can extend this class to build their own parallel style with customized
    input/output preparations.
    """

    input_layouts: Union[Placement, Tuple[Placement, ...]]
    output_layouts: Union[Placement, Tuple[Placement, ...]]
    use_local: bool

    def __init__(self, input_layouts, output_layouts, use_local):
        self.input_layouts = input_layouts
        self.output_layouts = output_layouts
        self.use_local = use_local


class RowwiseParallel(ParallelStyle):
    """
    Partition the rows of a module. We assume the input is a sharded DTensor
    and the output is a replicated tensor.
    """

    def __init__(self, input_layouts=Shard(-1), output_layouts=Replicate(), use_local=True):
        super().__init__(input_layouts, output_layouts, use_local)


class ColwiseParallel(ParallelStyle):
    """
    Partition the columns of a module. We assume the input is a replicated DTensor
    and the output is a sharded DTensor.
    """

    def __init__(self, input_layouts=Replicate(), output_layouts=Shard(-1), use_local=True):
        super().__init__(input_layouts, output_layouts, use_local)


# For the case of sequence parallel, users just set a different input layout,
# Shard(0) or Shard(1), instead of Replicate().


class PrepareModuleInput(ParallelStyle):
    """
    Only used to specify the input distribution spec for a module.
    """

    def __init__(self, input_layouts=Shard(0), output_layouts=Replicate(), use_local=False):
        super().__init__(input_layouts, output_layouts, use_local)


class PrepareModuleOutput(ParallelStyle):
    """
    Only used to specify the output distribution spec for a module.
    """

    def __init__(self, input_layouts=Replicate(), output_layouts=Shard(0), use_local=True):
        super().__init__(input_layouts, output_layouts, use_local)


# `block` is the user's transformer block and `mesh` the user's DeviceMesh;
# parallelize_module is the (proposed) TP entry point.
parallelize_plan = {
    "embedding": ColwiseParallel(output_layouts=Replicate()),
    "attn": PrepareModuleInput(),
    "attn.w1": ColwiseParallel(),
    "attn.w2": ColwiseParallel(),
    "attn.w3": ColwiseParallel(),
    "attn.wo": RowwiseParallel(),
}
parallelize_module(
    module=block,  # this can be a submodule or module
    device_mesh=mesh["tp"],
    parallelize_plan=parallelize_plan,
)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111160
Approved by: https://github.com/wanchaol
Currently, we only support intra-node TP when composing TP with other parallelisms. This PR adds an additional check to validate the TP mesh dim during TP initialization when a parent mesh exists.
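A hedged sketch of that kind of check (illustrative; the name and exact condition are assumptions):
```python
def _check_tp_mesh_dim(tp_mesh_dim: int, parent_mesh_ndim: int) -> None:
    # With a parent mesh, require TP to use the innermost mesh dim so that
    # tensor parallelism stays intra-node.
    innermost = parent_mesh_ndim - 1
    if tp_mesh_dim not in (-1, innermost):
        raise RuntimeError(
            f"Tensor Parallel expects the TP mesh dim to be the innermost dim "
            f"({innermost}) of the parent mesh, but got {tp_mesh_dim}."
        )
```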
cc. @fegin, @fduwjj
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111001
Approved by: https://github.com/fduwjj
This PR adds an all_gather_dtensor() method to fsdp/_fsdp_extensions.py, with the actual implementation in tensor/parallel/fsdp.py. This enables FSDP to load a 2D DTensor state_dict into the model when calling `model.load_state_dict()`.
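Conceptually (a rough sketch with an assumed name and signature, not the actual extension API), the all-gather redistributes the DTensor to Replicate and returns the local full tensor that FSDP can then load:
```python
import torch
from torch.distributed._tensor import DTensor, Replicate


def _all_gather_dtensor_sketch(tensor: DTensor) -> torch.Tensor:
    # Redistribute every mesh dim to Replicate, then take the local view,
    # which now holds the full (unsharded) tensor.
    replicated = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=[Replicate()] * tensor.device_mesh.ndim,
    )
    return replicated.to_local()
```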
cc. @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110925
Approved by: https://github.com/fegin
ghstack dependencies: #110831, #110846
Replacing https://github.com/pytorch/pytorch/pull/109553, which was reverted.
This PR enables training with the new 2D flow and adds the associated test. In addition, this PR moves the FSDP-specific parts of tensor/parallel/_data_parallel_utils.py back to tensor/parallel/fsdp.py to avoid a circular dependency for ddp.py and test/distributed/tensor/parallel/test_ddp_2d_parallel.py.
state_dict related changes will come in later PRs.
cc. @fegin, @fduwjj, @wanchaol, @awgu
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110034
Approved by: https://github.com/fduwjj
This PR is the first of a series of refactors to the op dispatch logic to:
1. remove redundant logic in op dispatch and simplify the error checking
2. reduce the number of tree_map/tree_flatten/unflatten calls needed, to cut
   the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
   directly; this not only helps TP, but also makes general DTensor
   operations faster (see the sketch after this list)
4. change the view ops' behavior of in-place modifying the op_schema, which
   is dangerous for sharding prop caching, and model view ops as one type
   of resharding too
5. enrich output sharding to include whether the op needs a redistribute,
   so that we don't need an explicit op schema comparison to know it
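A minimal sketch of the caching idea in item 3 (illustrative only; the real propagator keys on DTensor op schemas rather than plain strings):
```python
from functools import lru_cache
from typing import Tuple


@lru_cache(maxsize=None)
def propagate_sharding(op_name: str, input_placements: Tuple[str, ...]) -> Tuple[str, ...]:
    # Stand-in for the expensive sharding propagation rule lookup.
    print(f"propagating {op_name} with {input_placements}")
    return input_placements


propagate_sharding("aten.addmm", ("Shard(0)", "Replicate()"))  # computed once
propagate_sharding("aten.addmm", ("Shard(0)", "Replicate()"))  # served from the cache
```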
This should help further reduce the CPU overhead. Benchmark results:
- before (without this change), aten.addmm latency: 0.476 ms
- after (with this change), aten.addmm latency: 0.341 ms
- overall, one layer of MLP time reduced from 13.535 ms to 9.665 ms
Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactoring is needed to make things cleaner, which will be done in later PRs).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
Both were reverted due to a conflict with the internal source repo.
Mostly fixes for PEP 484 violations (i.e., a default arg is set to None but the type is not annotated as Optional).
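For illustration (a generic example of the pattern, not a specific line from this diff):
```python
from typing import Optional


def before(timeout: int = None):  # PEP 484 violation: implicit Optional
    ...


def after(timeout: Optional[int] = None):  # explicit Optional, passes mypy 1.4.1
    ...
```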
Plus a few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export.deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Unrelated changes, to bypass CI failures due to the gcc-9 dependency update in Ubuntu 18.04:
- Add a hack to `.ci/docker/install_conda.sh` that squashes the older libstdc++ from the conda environment in favor of the one from the OS
- Update bazel CUDA builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link cupti statically, but I could not find where that is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007