Commit Graph

12 Commits

Author SHA1 Message Date
Kumar Ashutosh
405a0040cf Adds tool to visualize sharding (#114307)
This pull request adds a tool to visualize sharding. It uses the device_mesh and placement details to construct a visualization of the split of a torch dtensor.

Things to fix:

- [x] This implementation only uses the first element of the placement tuple, when can there be more than one elements?
- [x] The calculation of the split is happening here but maybe it is already done somewhere internally in Shard class and can we directly call that here?

Fixes #108746

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114307
Approved by: https://github.com/wanchaol
2023-12-12 06:18:03 +00:00
Wanchao Liang
28925902fa [TP] fully rewrite Tensor Parallel APIs (#114732)
This PR rewrites Tensor Parallel implementation. Tensor Parallel APIs
supposed to be a very thin-wrapper to DTensor APIs, but the current
implementation got too messy and buggy. It's really hard to debug what
went wrong when using it. It's crucially important for advanced users or
developers to understand the API and its implementation easily without
going through all different types of functions and utils, so that
they could trust what happen under the hood.

In particular this PR:

* Make ParallelStyle to be a real contract API for parallelize_module to
  take, each concrete ParallelStyle only needs to implement `apply` to
apply the sharding to nn.Module, remove all non-necessary fields. This
also enable easier ParallelStyle authoring going forward.
* Keep the ColwiseParallel and RowwiseParallel public interface, but
  refactor them in a way that makes the parameter sharding, inputs and
outputs handling lives within the style itself, so that it's easy to
understand how Linear/Embedding layers are sharded and how the inputs/outputs
transformations are performed.
* remove all those private _prepare_input/_prepare_output_fn fields for
  both ColwiseParallel/RowwiseParallel. Since we throw deprecation
messages in nightly for a while and TP is on prototype release, the
fields are also private, it should be safe to remove them
* Refactor the recently landed PrepareModuleInput/Output style, change
  output_layouts to desired_input/output_layouts, group
  the function inside the style itself, no default arguments for these
two styles and user need to specify them to think about the sharding
layouts. Fixed bugs about not handling
`use_local_output` flag.
* Make default arguments be None instead of Placement object, this is
  standard python practice to not have custom object instance as default
argument
* Remove all dead APIs (i.e. PairwiseParallel and SequenceParallel
  style, all prepare input/output functions) as we throw deprecation
 msgs for a while, and in the progress of removing all of them from the tests.
* throw deprecation warning for `tp_mesh_dim` as we recomemnd use device
  mesh slice/indexing instead of manually specify mesh dim
* Rewrite all documentations for every ParallelStyle and make the
  documentation more clear about what each style is doing

TODOs:
* Rewrite TP tests to adjust for the changes we have in this PR
* add more tests to guard the bug fixes

Differential Revision: [D51761183](https://our.internmc.facebook.com/intern/diff/D51761183)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114732
Approved by: https://github.com/wz337, https://github.com/fduwjj
2023-12-02 08:18:12 +00:00
KingsleyLiu-NV
cd2798943d [dtensor] support convolution ops (#113123)
This PR creates a prototype of training convolutional neural networks based on DTensor.

- Register required ops and implement operator dispatch
- Add unit tests and example

Basically, we shard the activations and replicate the model weights in this prototype. We can scale out to multiple GPUs and reduce the per-GPU memory footprint with this approach, and achieve weak scaling in terms of training performance (i.e., time per iteration).

Reference log (on 2xA100 GPU):

Unit Test
```bash
root@luna-prod-78-80gb:/pytorch# python3 test/distributed/_tensor/test_convolution_ops.py
/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (Triggered internally at /opt/conda/conda-bld/pytorch_1699257304556/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2170.)
  return F.conv2d(input, weight, bias, self.stride,
/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (Triggered internally at /opt/conda/conda-bld/pytorch_1699257304556/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2170.)
  return F.conv2d(input, weight, bias, self.stride,
..
----------------------------------------------------------------------
Ran 2 tests in 30.354s

OK
root@luna-prod-78-80gb:/pytorch# python3 test/distributed/_tensor/test_other_ops.py
[rank0]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank0]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank1]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank1]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
...
----------------------------------------------------------------------
Ran 3 tests in 16.343s

OK
```
ConvNeXt Example
```bash
root@luna-prod-78-80gb:/pytorch# python3 torch/distributed/_tensor/examples/convnext_example.py
rank 3, 20 iterations, latency     584.80 ms, forward     102.84 ms, backward     297.80 ms, max reserved    16.34 GiB, max allocated    14.75 GiB
rank 1, 20 iterations, latency     584.64 ms, forward     104.85 ms, backward     297.60 ms, max reserved    16.40 GiB, max allocated    14.74 GiB
rank 0, 20 iterations, latency     584.48 ms, forward     104.64 ms, backward     297.90 ms, max reserved    16.39 GiB, max allocated    14.75 GiB
rank 2, 20 iterations, latency     584.96 ms, forward      93.21 ms, backward     297.95 ms, max reserved    16.40 GiB, max allocated    14.74 GiB
```

@wanchaol @fduwjj FYI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113123
Approved by: https://github.com/wanchaol
2023-11-20 21:01:28 +00:00
Justin Chu
232b96b6e2 [BE] Enable ruff's UP rules and autoformat distributed/ (#105433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105433
Approved by: https://github.com/albanD
2023-07-19 14:27:11 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to It's dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  -
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
Shen Li
02179827cb [Easy] Include SPMD and DTensor files in UFMT checks (#98148)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98148
Approved by: https://github.com/fegin
2023-04-02 15:34:49 +00:00
Wanchao Liang
7a772bfff9 [dtensor] add submesh example to checkpoint_example (#95655)
This PR adds a submesh example for checkpoing purposes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95655
Approved by: https://github.com/XilunWu
2023-03-01 08:19:27 +00:00
Wanchao Liang
ee0e7f0529 [dtensor] add checkpointing example (#94743)
This PR adds some DTensor sharding example on a simple MLP model
for checkpointing reference purposes

Note that checkpointing itself is not implemented yet.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94743
Approved by: https://github.com/wz337
2023-02-16 22:04:09 +00:00