Commit Graph

259 Commits

Author SHA1 Message Date
Yue Dong
270ed13e87 [DTensor] Make DTensor from_local backward partial() to replicate() pass through (#115967)
Summary:
This change makes the backward pass of `DTensor.from_local()` treat the `Partial()` to `Replicate()` placement transition as a pass-through, for the following reasons:
1. When we run the backward pass of `DTensor.from_local`, if the target placement is `Partial()` (i.e. from user code that manually overwrites the placement, rather than from torch_dispatch), we keep the grad as `Replicate()`. This is because converting the gradients back to `Partial()` is meaningless.
2. The current div logic leads to wrong numerical values in the above case.
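
A minimal sketch of the scenario, run under torchrun; `_Partial`'s import path follows the placement types module at the time of this PR, and the shapes are illustrative:
```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, DeviceMesh, Shard
from torch.distributed._tensor.placement_types import _Partial

# Assumes a process group is already initialized (e.g. via torchrun).
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
local = torch.randn(4, 8, requires_grad=True)

# Manually mark the local tensor as Partial() (a pending reduction).
dt = DTensor.from_local(local, mesh, [_Partial()])

# Forward: Partial() -> Shard(0) triggers a ReduceScatter.
out = dt.redistribute(mesh, [Shard(0)]).to_local()

# Backward through from_local now treats Partial() -> Replicate() as a
# pass-through, so no div-by-world-size is applied to the gradient.
out.sum().backward()
```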

Test Plan:
**CI**:
CI Tests

**Unit test**:
`buck2 test mode/dev-nosan //caffe2/test/distributed/_tensor:redistribute`
- Passed

**With model training**:
```
# We tested the case where the input tensor is manually overwritten as Partial() and
# the output tensor is manually overwritten to Shard(), then converted to local.

# Before the change: numerical value not correct
Forward pass:
    collective: ReduceScatter
backward pass:
    collective: AllGather + div by process group size

# After the change: div is removed as expected.
Forward pass:
    collective: ReduceScatter
Backward pass:
    collective: AllGather
```

Differential Revision: D52175709

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115967
Approved by: https://github.com/wanchaol
2023-12-19 00:16:10 +00:00
Yue Dong
87ea6fb844 Make input contiguous for DTensor reduce scatter to fix the incorrect numerical values (#115847)
Summary:
This change makes the input tensor contiguous for DTensor reduce scatter in the case where no padding is needed.

No exception is thrown during training, but we ran into numerical correctness issues without this change.
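
A hedged illustration of the hazard and the fix (not the DTensor-internal code); `safe_reduce_scatter` is a hypothetical helper:
```python
import torch
import torch.distributed as dist

def safe_reduce_scatter(inp: torch.Tensor, group=None) -> torch.Tensor:
    # NCCL reduce-scatter assumes a contiguous input; a strided/transposed
    # view can silently produce wrong values instead of raising.
    if not inp.is_contiguous():
        inp = inp.contiguous()  # the fix: force a contiguous layout first
    world_size = dist.get_world_size(group)
    out = inp.new_empty((inp.shape[0] // world_size, *inp.shape[1:]))
    dist.reduce_scatter_tensor(out, inp, op=dist.ReduceOp.SUM, group=group)
    return out
```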

Test Plan:
**CI**
CI test

**WHEN model test**:
- Verified the loss for each iteration stays within the expected range.
- Verified NE is on-par with this change, using 4B training data.

Differential Revision: D52170822

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115847
Approved by: https://github.com/wanchaol
2023-12-17 01:35:09 +00:00
Kumar Ashutosh
405a0040cf Adds tool to visualize sharding (#114307)
This pull request adds a tool to visualize sharding. It uses the device_mesh and placement details to construct a visualization of how a torch DTensor is split.

Things to fix:

- [x] This implementation only uses the first element of the placement tuple; when can there be more than one element?
- [x] The calculation of the split happens here, but it may already be done internally in the `Shard` class; can we call that directly instead?

Fixes #108746
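
A hedged usage sketch; the import path below is where this PR adds the tool and may differ in later releases. Run under torchrun with 4 ranks:
```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding

mesh = DeviceMesh("cuda", list(range(4)))
dtensor = distribute_tensor(torch.randn(8, 8), mesh, placements=[Shard(0)])
visualize_sharding(dtensor)  # prints which slice of the global tensor each rank owns
```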

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114307
Approved by: https://github.com/wanchaol
2023-12-12 06:18:03 +00:00
Wanchao Liang
fbb744fd49 [dtensor] enable radam foreach optimizer (#115566)
As titled; tests both the non-foreach and foreach optimizer paths

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115566
Approved by: https://github.com/XilunWu
ghstack dependencies: #115297, #115564, #115565
2023-12-12 03:57:00 +00:00
Wanchao Liang
4bd661c472 [dtensor] enable adadelta foreach optimizer (#115564)
as titled

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115564
Approved by: https://github.com/XilunWu
ghstack dependencies: #115297
2023-12-12 03:56:55 +00:00
Wanchao Liang
8a27352d6b [dtensor] add an implicit replication flag (#115297)
This PR adds experimental implicit replication support so that DTensor can
interoperate with torch.Tensor: under this context manager, DTensor
can work together with torch.Tensor by assuming the torch.Tensor's
sharding layout is replicated.

Note that this is risky for DTensor, so we don't turn it on by default;
but for cases where a tensor is known for sure to be replicated, users can
use this to allow DTensor and Tensor computation to work together.
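
A hedged sketch of the flag; the import path follows this PR (torch.distributed._tensor.experimental) and the tensors are illustrative:
```python
import torch
from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor
from torch.distributed._tensor.experimental import implicit_replication

mesh = DeviceMesh("cuda", list(range(2)))
dt = distribute_tensor(torch.ones(4, 4), mesh, [Replicate()])
plain = torch.ones(4, 4)  # a regular torch.Tensor

with implicit_replication():
    # plain is assumed to be replicated across the mesh; only safe when
    # every rank really holds the same values.
    out = dt + plain
```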

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115297
Approved by: https://github.com/awgu
2023-12-12 03:56:48 +00:00
Wanchao Liang
0692240b90 [dtensor] account for empty list when turning to OpStrategy (#115298)
Trying to fix https://github.com/pytorch/pytorch/issues/115065

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115298
Approved by: https://github.com/XilunWu
2023-12-12 00:11:16 +00:00
Yue Dong
485ea9a70a [DTensor] Add DTensor experimental op for LayerNorm backward sharding rule propagation (#115398)
Summary: This diff is a prototype only, to unblock the TP work. The PyTorch distributed team is working on a more generic backward op for `aten.layer_norm`; we will remove this op from the experimental file once that is ready.

Test Plan:
**Local Test**:
Accuracy:
- Dtensor + Checkpoint: first run loss: P884569822 (on-par with baseline: P884213363)
- 2nd by loading saved checkpoint: P884583429 (on-par with baseline: P884271869)

Trace:
- Collective functions are inserted automatically.
- Example: https://fburl.com/perfdoctor/l567ww1x

**MAST Test**:
With: trainer = 128, batch_size=512
- NE on-par:
(see: 4441_ep_bs512_2fsdp_tp_sp_dtensor)

Differential Revision: D51490868

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115398
Approved by: https://github.com/wanchaol
2023-12-09 09:38:56 +00:00
Wanchao Liang
1215f2ffe2 [dtensor] readme typo (#115383)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115383
Approved by: https://github.com/awgu
ghstack dependencies: #115365
2023-12-08 21:40:40 +00:00
Iris Zhang (PyTorch)
23fa9621e4 [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099) (#115193)
Summary:

Rename _device_mesh.py to device_mesh.py, update all call sites, and add documentation.
We created stubs for the public class and methods in torch.distributed.device_mesh so that it can be imported regardless of whether distributed is available (i.e. `torch.distributed.is_available()`).

Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/115099
Prior to landing, all CI signals passed. Shipit added the "ci/trunk" label to the PR but DID NOT wait for it and went ahead committing. More context can be found in the reverted PR above.

Test Plan: CI.

Differential Revision: D51861018

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115193
Approved by: https://github.com/fegin
2023-12-08 08:44:32 +00:00
wz337
dacf5d6e92 [DTensor] Remove assert to allow tensor sharding dimension < Shard(x).ndim (#115114)
Consolidates changes made by @yoyoyocmu. https://www.internalfb.com/diff/D51821717
Remove the assert to allow tensor dimension < Shard(x).ndim. With the current padding, we already support this.

Follow-up: we still need to fix the size mismatch and the `full_tensor()` hang when the tensor is unevenly sharded.
Created issue here: https://github.com/pytorch/pytorch/issues/115310

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115114
Approved by: https://github.com/yoyoyocmu, https://github.com/wanchaol
2023-12-07 21:57:30 +00:00
Wanchao Liang
6a6a1e3ef7 [dtensor] update README to make all examples runnable (#115365)
As titled; also adds torchrun commands

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115365
Approved by: https://github.com/fegin
2023-12-07 20:23:37 +00:00
Tobias Ringwald
43f42bf3cb Updated docs for deprecated torch.set_default_tensor_type (#115041)
Added deprecation note for torch.set_default_tensor_type. Updated docs that referenced this method.
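
A short sketch of the deprecated call and its documented replacements:
```python
import torch

torch.set_default_tensor_type(torch.cuda.FloatTensor)  # deprecated

# Preferred: set the default dtype and device separately.
torch.set_default_dtype(torch.float32)
torch.set_default_device("cuda")
```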

Fixes #113646.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115041
Approved by: https://github.com/janeyx99
2023-12-07 16:17:36 +00:00
Xilun Wu
5e3631db31 [DTensor] force re-compute sharding when normalized_shape differs in fwd layer norm (#115250)
**Summary**:
#114174 did not test the case where `elementwise_affine=False` (i.e. `weight` and `bias` are `None`), and that test would fail due to cached sharding propagation. The difference in sharding prop between these cases is that when `weight` and `bias` are None, the forward layer norm op is recognized as a "static shape op", so `propagate_op_sharding` is applied rather than `propagate_op_sharding_non_cached`. The fix is to force re-computing the sharding when `normalized_shape` changes, by setting the op schema's `RuntimeSchemaInfo.static_argnum` to include `normalized_shape` (i.e. 1).
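
A hedged sketch of the mechanism, using DTensor-internal names as they exist at the time of this PR (treat exact paths and signatures as assumptions):
```python
import torch
from torch.distributed._tensor.op_schema import RuntimeSchemaInfo
from torch.distributed._tensor.ops.utils import register_op_strategy

aten = torch.ops.aten

@register_op_strategy(
    aten.native_layer_norm.default,
    # static_argnum=1 pulls args[1:] (including normalized_shape) into the
    # sharding-prop cache key, forcing a re-compute when it changes.
    schema_info=RuntimeSchemaInfo(1),
)
def layer_norm_strategy(mesh, op_schema):
    ...  # elided: the actual strategy body
```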

**Test**:
pytest test/distributed/_tensor/test_math_ops.py -s -k layer_norm

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115250
Approved by: https://github.com/wanchaol
2023-12-07 07:44:06 +00:00
Joel Schlosser
22704426c3 Expand dynamic dims support for traceable subclasses (#114311)
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).

Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
    * Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
    * Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
    * Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
    * Signatures now:
    ```python
    # attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
    # ctx is anything useful for rebuilding the class we want to guard on
    attrs, ctx = x.__tensor_flatten__()
    ...
    # inner_tensors is a dict of {attr -> tensor}
    # ctx is taken unmodified from flattening and (eventually) guarded on
    # outer_size is the expected size of the output; possibly symbolic
    # outer_stride is the expected strides of the output; possibly symbolic
    y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)

    # at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
    # the assert simplifies symbols when there are relationships between outer and inner symbols
    ```
    * Size info needed for `NestedTensor` at least, stride info needed for `DTensor` at least
    * Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
    * Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors
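
A minimal, hedged sketch of a wrapper subclass implementing the updated signatures above (`__torch_dispatch__` handling elided; the subclass itself is hypothetical):
```python
import torch

class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, strides=inner.stride(),
            dtype=inner.dtype, device=inner.device,
        )

    def __init__(self, inner: torch.Tensor):
        self.inner = inner

    def __tensor_flatten__(self):
        return ["inner"], None  # (attrs, ctx)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        # outer_size / outer_stride may carry SymInts under PT2; the returned
        # tensor's shape and strides must compare equal to them.
        return WrapperTensor(inner_tensors["inner"])
```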

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
2023-12-05 21:09:25 +00:00
Xilun Wu
4c5fe66880 [DTensor][BE] fix bug in OpStrategy for Tuple output (#115161)
**Summary**:
DTensor sharding propagation returns an `OpStrategy` object in the case of a
Tuple of multiple DTensors with the same `placements`, and this object is later
expanded to a tuple of `DTensorSpec`s. However, the expansion copied the
object's reference instead of copying/creating new objects, which led to
incorrect overwriting in the tensor meta propagation logic.

**Test**:
pytest test/distributed/_tensor/test_math_ops.py
pytest test/distributed/_tensor/test_dtensor_ops.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115161
Approved by: https://github.com/wanchaol
2023-12-05 18:28:40 +00:00
Nikita Shulga
a827ac71f2 Revert "[DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099)"
This reverts commit eaa64339d6.
2023-12-05 08:59:36 -08:00
Iris Zhang (PyTorch)
eaa64339d6 [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099)
Summary:
Rename _device_mesh.py to device_mesh.py, update all call sites, and add documentation.

Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/114991
It was failing a public module binding test on macOS, due to the change in import order for torch/distributed/fsdp/_common_utils.py. Since the original import still works, we removed the changes to that file.

Test Plan: CI.

Differential Revision: D51825114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115099
Approved by: https://github.com/wanchaol, https://github.com/fegin
2023-12-05 05:44:52 +00:00
Wanchao Liang
288b1acaa9 [dtensor] fix empty shape init for dtensor constructors (#115091)
As titled, this PR fixes the empty-shape init case: if we pass in
something like `torch.dtensor.zeros([])`, it should call `torch.zeros([])`
under the hood, not `torch.empty(0)`. This aligns the DTensor constructors
with the torch constructors.
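
A hedged illustration of the aligned behavior, with `zeros` imported from the DTensor package:
```python
import torch
from torch.distributed._tensor import DeviceMesh, zeros

mesh = DeviceMesh("cuda", list(range(2)))
dt = zeros([], device_mesh=mesh)  # now a 0-dim DTensor, like torch.zeros([])
assert dt.shape == torch.zeros([]).shape  # torch.Size([])
```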

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115091
Approved by: https://github.com/XilunWu
2023-12-05 00:51:29 +00:00
PyTorch MergeBot
3a2e2044cd Revert "[DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#114710) (#114991)"
This reverts commit 729ac7317a.

Reverted https://github.com/pytorch/pytorch/pull/114991 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/114991#issuecomment-1837214567))
2023-12-02 17:55:51 +00:00
Wanchao Liang
28925902fa [TP] fully rewrite Tensor Parallel APIs (#114732)
This PR rewrites the Tensor Parallel implementation. Tensor Parallel APIs
are supposed to be a very thin wrapper over the DTensor APIs, but the current
implementation got too messy and buggy, and it's really hard to debug what
went wrong when using it. It's crucially important for advanced users and
developers to understand the API and its implementation easily, without
going through all the different types of functions and utils, so that
they can trust what happens under the hood.

In particular this PR:

* Make ParallelStyle a real contract API for parallelize_module to
  take; each concrete ParallelStyle only needs to implement `apply` to
apply the sharding to an nn.Module. Remove all unnecessary fields. This
also enables easier ParallelStyle authoring going forward.
* Keep the ColwiseParallel and RowwiseParallel public interfaces, but
  refactor them so that the parameter sharding and the input/output
handling live within the style itself, making it easy to
understand how Linear/Embedding layers are sharded and how the input/output
transformations are performed (see the usage sketch after this list).
* Remove the private _prepare_input/_prepare_output_fn fields from
  both ColwiseParallel/RowwiseParallel. Since we have thrown deprecation
messages in nightlies for a while, TP is a prototype release, and the
fields are private, it should be safe to remove them.
* Refactor the recently landed PrepareModuleInput/Output styles: change
  output_layouts to desired_input/output_layouts, group
  the functions inside the style itself, and drop the default arguments for
these two styles so users must specify them and think about the sharding
layouts. Fixed bugs around not handling the
`use_local_output` flag.
* Make default arguments None instead of Placement objects; it is
  standard Python practice not to use a custom object instance as a default
argument.
* Remove all dead APIs (i.e. the PairwiseParallel and SequenceParallel
  styles and all prepare input/output functions), as we have thrown
deprecation messages for a while; we are in the process of removing all of them from the tests.
* Throw a deprecation warning for `tp_mesh_dim`, as we recommend using device
  mesh slicing/indexing instead of manually specifying the mesh dim.
* Rewrite the documentation for every ParallelStyle and make it
  clearer what each style is doing.
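
A hedged end-to-end sketch of the rewritten API; the module names ("0", "2") key into the nn.Sequential below, and the mesh shape is illustrative:
```python
import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

mesh = DeviceMesh("cuda", list(range(8)))
mlp = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

# Colwise-shard the first Linear and rowwise-shard the second; the styles
# now own their parameter sharding and input/output layout handling.
mlp = parallelize_module(mlp, mesh, {"0": ColwiseParallel(), "2": RowwiseParallel()})
```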

TODOs:
* Rewrite TP tests to adjust for the changes in this PR
* Add more tests to guard the bug fixes

Differential Revision: [D51761183](https://our.internmc.facebook.com/intern/diff/D51761183)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114732
Approved by: https://github.com/wz337, https://github.com/fduwjj
2023-12-02 08:18:12 +00:00
Iris Zhang (PyTorch)
729ac7317a [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#114710) (#114991)
Summary:

Same content of changes as https://github.com/pytorch/pytorch/pull/114710

Rename _device_mesh.py to device_mesh.py, update all callsites, adds documentation.
ghstack-source-id: 208980207
exported-using-ghexport

Test Plan: CI.

Reviewed By: wanchaol

Differential Revision: D51629761

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114991
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/fegin
2023-12-02 04:39:41 +00:00
wz337
7b3e45be59 [DeviceMesh] Rename get_dim_groups to get_group (#114708)
Rename get_dim_groups to get_group and update all callsites.
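
A short usage sketch of the rename (mesh shape illustrative):
```python
from torch.distributed._tensor import DeviceMesh

mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])  # a 2x2 mesh
tp_group = mesh.get_group(0)  # previously: mesh.get_dim_groups(0)
```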

Differential Revision: [D51629801](https://our.internmc.facebook.com/intern/diff/D51629801/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114708
Approved by: https://github.com/XilunWu, https://github.com/wanchaol, https://github.com/fegin
2023-11-30 23:40:14 +00:00
wz337
febbc48f43 [DeviceMesh] Make our mesh_dim kwarg naming consistent (#114707)
Changing `size(self, dim: Optional[int] = None)` to `size(self, mesh_dim: Optional[int] = None)` so it is consistent with the rest of our APIs.

We also update usages of this API in both PT and internal codebases (pyper, APS).
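
A short sketch of the renamed kwarg (mesh shape illustrative):
```python
from torch.distributed._tensor import DeviceMesh

mesh = DeviceMesh("cuda", [[0, 1, 2, 3], [4, 5, 6, 7]])  # a 2x4 mesh
assert mesh.size() == 8            # total number of ranks
assert mesh.size(mesh_dim=0) == 2  # was: mesh.size(dim=0)
```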

Differential Revision: [D51602986](https://our.internmc.facebook.com/intern/diff/D51602986/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114707
Approved by: https://github.com/XilunWu, https://github.com/wanchaol, https://github.com/fegin
2023-11-29 19:43:23 +00:00
Wanchao Liang
624f202522 [dtensor] add CommDebugMode for debugging (#113592)
This PR adds a CommDebugMode debugging tool to record the number of
distributed collectives, utilizing TorchDispatchMode. The idea borrows
from FlopCounterMode, and we can expand this later to make it more
feature-complete, like FlopCounterMode.

This is useful for debugging and testing with DTensor. In general it
fits any complex distributed algorithm where it's non-trivial to
understand what happened under the hood; we can later cover c10d
collectives directly.

Not sure if it would be a good general distributed debug tool yet,
so adding it to the dtensor package first.
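
A hedged usage sketch; the accessor names follow this PR and are assumptions:
```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from torch.distributed._tensor.debug import CommDebugMode

mesh = DeviceMesh("cuda", list(range(2)))
dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

comm_mode = CommDebugMode()
with comm_mode:
    out = torch.mm(dt, dt)  # sharded matmul triggers a collective (e.g. an all-gather)

print(comm_mode.get_total_counts())  # number of collectives recorded
print(comm_mode.get_comm_counts())   # per-collective breakdown
```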

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113592
Approved by: https://github.com/wconstab
2023-11-27 02:40:28 +00:00
Andrew Gu
34326e43eb [DTensor] Made DTensorSpec hash recomputation lazy (#114379)
If we assign `spec.tensor_meta = ...`, we do not have to recompute the hash eagerly. We just need to clear the existing hash so that the next call to `__hash__` recomputes it.
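
A hedged sketch of the pattern (not the actual DTensorSpec code): clear the cached hash on attribute assignment and recompute it on demand; a matching `__eq__` over the same fields is elided:
```python
class CachedHashSpec:
    def __init__(self, mesh, placements, tensor_meta=None):
        self.mesh = mesh
        self.placements = tuple(placements)
        self.tensor_meta = tensor_meta  # e.g. reassigned during sharding prop

    def __setattr__(self, name, value):
        object.__setattr__(self, name, value)
        if name != "_hash":
            object.__setattr__(self, "_hash", None)  # just clear; don't recompute

    def __hash__(self):
        if self._hash is None:  # recompute lazily, only when actually hashed
            object.__setattr__(
                self, "_hash", hash((self.mesh, self.placements, self.tensor_meta))
            )
        return self._hash
```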

We found that the breakage of the DTensor + `torch.compile` tests comes from https://github.com/pytorch/pytorch/pull/114236 and is not directly related to the `DTensorSpec` hashing changes. We fix that temporarily in the following PR by passing `dynamic=False`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114379
Approved by: https://github.com/wanchaol
2023-11-23 05:45:18 +00:00
Andrew Gu
e7326ec295 [DTensor] Computed DTensorSpec hash lazily (#114322)
This is a forward fix for https://github.com/pytorch/pytorch/issues/113781.

We lazily compute the hash so that we do not try to compute the hash on `SymInt`s (for the stride) during Dynamo tracing.

Tested via:
```
python test/distributed/_tensor/test_dtensor_compile.py -k test_2d_fsdp_tp_ac_compile
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114322
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925, #113930, #114141, #113915, #114140
2023-11-22 04:13:11 +00:00
Andrew Gu
7694b05416 [DTensor] Reduced to one isinstance call in is_shard (#114140)
This is a nit change to save one `isinstance` call for when `dim` is not `None` but the placement is not `Shard`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114140
Approved by: https://github.com/Skylion007, https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925, #113930, #114141, #113915
2023-11-21 17:31:02 +00:00
Wanchao Liang
bbc39b7bb4 [dtensor] enable RMSprop optimizer foreach support (#114152)
as titled

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114152
Approved by: https://github.com/XilunWu
ghstack dependencies: #114149, #114150, #114151
2023-11-21 03:23:40 +00:00
Wanchao Liang
bcd310a7ad [dtensor] enable adagrad foreach support (#114151)
This PR enables the adagrad foreach mode support

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114151
Approved by: https://github.com/XilunWu
ghstack dependencies: #114149, #114150
2023-11-21 03:23:40 +00:00
Andrew Gu
3e49621f3b [DTensor] Cached hash for DTensorSpec (#113915)
**Overview**
Generally, I think we can try to freeze as many of these classes used in DTensor sharding propagation as possible so that we can cache hashes. This PR targets hashing `DTensorSpec`, which turns out to be relatively expensive.

**Details**
It looks like `tensor_meta` is only updated in `_wrap_output_spec_tensor_meta`, which only runs if the propagation was not cached:
ae94c7e491/torch/distributed/_tensor/sharding_prop.py (L137)
ae94c7e491/torch/distributed/_tensor/sharding_prop.py (L153)
In that case, I think we can cache the hash for the `DTensorSpec` and only update it when one of the hashed attributes changes, which we only really expect to happen for `tensor_meta`.

To ensure correctness, we need that all hashed attributes are immutable.
- `DeviceMesh` caches its hash: a9134fa99a/torch/distributed/_device_mesh.py (L181)
- This PR makes each `Placement` a frozen `dataclass`, making them immutable (relying on the fact that they do not have references to any mutable objects).
- `TensorMeta` is a `NamedTuple` of `torch.Size`, `Tuple[int, ...]`, and `torch.dtype`, so it is immutable: 9916d8a9ea/torch/distributed/_tensor/placement_types.py (L369-L375)

**Example**
For some simple small GPT model:
Before: 0.125 ms
<img width="509" alt="Screenshot 2023-11-16 at 10 08 05 PM" src="https://github.com/pytorch/pytorch/assets/31054793/10e59401-f635-431f-80b5-1b48df3a706e">

After: 0.048 ms
<img width="294" alt="Screenshot 2023-11-16 at 10 08 47 PM" src="https://github.com/pytorch/pytorch/assets/31054793/09a3b0b9-f68c-4afc-bca1-c29a4b01c2fb">

The overall Adam CPU step time decreases from 7.647 ms to 6.451 ms.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113915
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925, #113930, #114141
2023-11-21 01:24:21 +00:00
Andrew Gu
fb25fd6f86 [DTensor] Replaced neg dim normalization with assert in helper (#114141)
This is a replacement for https://github.com/pytorch/pytorch/pull/113922. I think we can still leave the check for negative shard dimension in `compute_local_shape_and_global_offset` and replace the normalization logic with an assert. This should provide us a stack trace to see which user-facing API did not normalize the dim as expected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114141
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925, #113930
2023-11-21 01:24:21 +00:00
Andrew Gu
c39c69953f [DTensor] Used new placements for neg dim in distribute_tensor (#113930)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113930
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925
2023-11-20 22:32:58 +00:00
Andrew Gu
e2095a04ae [DTensor] Ensured grad_placements was tuple (#113925)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113925
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134
2023-11-20 22:32:58 +00:00
Andrew Gu
f4ffd46c08 [DTensor] Used new placements for neg dim in from_local (#114134)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114134
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924
2023-11-20 22:32:51 +00:00
Andrew Gu
b41ad7d695 [DTensor] Used new placements for neg dim in redistribute (#113924)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113924
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919
2023-11-20 22:30:16 +00:00
Andrew Gu
77e058f055 [DTensor] Made _Partial, Replicate frozen dataclasses (#113919)
This is part of the larger stack to work toward being able to cache hashes for `DTensorSpec`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113919
Approved by: https://github.com/wanchaol
2023-11-20 22:28:47 +00:00
KingsleyLiu-NV
cd2798943d [dtensor] support convolution ops (#113123)
This PR creates a prototype of training convolutional neural networks based on DTensor.

- Register required ops and implement operator dispatch
- Add unit tests and example

Basically, we shard the activations and replicate the model weights in this prototype. We can scale out to multiple GPUs and reduce the per-GPU memory footprint with this approach, and achieve weak scaling in terms of training performance (i.e., time per iteration).

Reference log (on 2xA100 GPU):

Unit Test
```bash
root@luna-prod-78-80gb:/pytorch# python3 test/distributed/_tensor/test_convolution_ops.py
/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (Triggered internally at /opt/conda/conda-bld/pytorch_1699257304556/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2170.)
  return F.conv2d(input, weight, bias, self.stride,
/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (Triggered internally at /opt/conda/conda-bld/pytorch_1699257304556/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2170.)
  return F.conv2d(input, weight, bias, self.stride,
..
----------------------------------------------------------------------
Ran 2 tests in 30.354s

OK
root@luna-prod-78-80gb:/pytorch# python3 test/distributed/_tensor/test_other_ops.py
[rank0]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank0]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank1]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank1]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
...
----------------------------------------------------------------------
Ran 3 tests in 16.343s

OK
```
ConvNeXt Example
```bash
root@luna-prod-78-80gb:/pytorch# python3 torch/distributed/_tensor/examples/convnext_example.py
rank 3, 20 iterations, latency     584.80 ms, forward     102.84 ms, backward     297.80 ms, max reserved    16.34 GiB, max allocated    14.75 GiB
rank 1, 20 iterations, latency     584.64 ms, forward     104.85 ms, backward     297.60 ms, max reserved    16.40 GiB, max allocated    14.74 GiB
rank 0, 20 iterations, latency     584.48 ms, forward     104.64 ms, backward     297.90 ms, max reserved    16.39 GiB, max allocated    14.75 GiB
rank 2, 20 iterations, latency     584.96 ms, forward      93.21 ms, backward     297.95 ms, max reserved    16.40 GiB, max allocated    14.74 GiB
```

@wanchaol @fduwjj FYI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113123
Approved by: https://github.com/wanchaol
2023-11-20 21:01:28 +00:00
Andrew Gu
99b89db174 [DTensor] Added op_call in no-mesh dispatch assert message (#113903)
This helps debug, e.g. when there is an unsupported op.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113903
Approved by: https://github.com/wanchaol
ghstack dependencies: #113654
2023-11-17 02:44:54 +00:00
Wanchao Liang
ae94c7e491 [dtensor] add foreach_zero_ support (#113897)
This PR adds foreach_zero_ op support, to fix the case where
optim.zero_grad(set_to_none=False) hits this op and errors out with the
"device mesh not found" issue.
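
A hedged repro sketch of the case this fixes; `tp_model` and `inp` are hypothetical:
```python
import torch

optim = torch.optim.SGD(tp_model.parameters(), lr=0.1, foreach=True)
tp_model(inp).sum().backward()  # grads are now DTensors
optim.step()
# Zeroing in place dispatches aten._foreach_zero_ on the DTensor grads;
# this previously errored with the "device mesh not found" issue.
optim.zero_grad(set_to_none=False)
```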

Also move the test to use zero_grad as the last step, as that's when we
are going to have DTensors as grads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113897
Approved by: https://github.com/awgu
2023-11-17 02:11:19 +00:00
Wanchao Liang
2ac33ad98a [dtensor] group dispatch unwrapping to a method (#113846)
This PR groups the dispatch unwrapping logic into a method, so that even
custom handlers can reuse many parts of the dispatch logic to do custom
things.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113846
Approved by: https://github.com/wz337
2023-11-16 23:54:18 +00:00
PyTorch MergeBot
277474f1a0 Revert "[2d] pass shape/stride during tensor unflatten (#113547)"
This reverts commit 93372455a7.

Reverted https://github.com/pytorch/pytorch/pull/113547 on behalf of https://github.com/wanchaol due to broken compile test ([comment](https://github.com/pytorch/pytorch/pull/113547#issuecomment-1813048318))
2023-11-15 18:32:54 +00:00
Wanchao Liang
93372455a7 [2d] pass shape/stride during tensor unflatten (#113547)
As titled; built on top of the work @wz337 enabled, this could save some
runtime CPU time recreating DTensor parameters with the correct
shape/stride, and avoid issues with unevenly sharded parameters.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113547
Approved by: https://github.com/XilunWu
ghstack dependencies: #113323, #113324
2023-11-14 09:28:09 +00:00
Wanchao Liang
b16e3b5373 [funcol] add two APIs: wait() and numpy() (#113323)
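A hedged usage sketch of the two APIs on the AsyncCollectiveTensor returned by functional collectives (assumes an initialized gloo group, since `.numpy()` needs CPU tensors):
```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

t = torch.ones(4)
out = funcol.all_reduce(t, "sum", dist.group.WORLD)  # AsyncCollectiveTensor
out.wait()         # new: explicitly wait for the in-flight collective
arr = out.numpy()  # new: view the completed result as a NumPy array
```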
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113323
Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/wconstab
2023-11-14 09:27:45 +00:00
Wanchao Liang
6ed20af10e [dtensor] refactor op dispatch and fix is_same_size/equal (#112927)
torch.equal/is_same_size currently skip sharding prop and directly do
the local tensor compute, which is wrong. For these two ops:

- torch.equal: should not skip sharding prop; the two DTensors need to
have the SAME sharding before comparing local shard values
- torch.is_same_size: needs to completely skip both sharding prop and
local compute

This PR refactors the existing op dispatch to make it a class instance
so that we can do custom op handling, then fixes both torch.equal and
torch.is_same_size.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112927
Approved by: https://github.com/fduwjj, https://github.com/XilunWu
2023-11-13 22:46:31 +00:00
PyTorch MergeBot
23e0923c74 Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"
This reverts commit eeeb40b327.

Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to Reverting this pr as the one under it in the stack is causing regressions in torchrec ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1806044435))
2023-11-10 16:30:36 +00:00
Xuehai Pan
eeeb40b327 [pytree] reorganize submodule structure for C++ and Python pytree (#112278)
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this makes it easier to implement the abstract `PyTreeAPI` class with two implementations, and much easier for the user to switch between the two implementations.

Before:

```text
torch
├── utils
│   ├── _pytree.py
│   ├── _cxx_pytree.py
│   ...
...
```

After:

```text
torch
├── utils
│   ├── _pytree
│   │   ├── __init__.py
│   │   └── api
│   │       ├── __init__.py
│   │       ├── cxx.py
│   │       └── python.py
│   ...
...
```

The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.
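
A quick sanity check that the public entry point behaves as before (sketch):
```python
import torch.utils._pytree as pytree

leaves, spec = pytree.tree_flatten({"a": 1, "b": [2, 3]})
assert leaves == [1, 2, 3]
assert pytree.tree_unflatten(leaves, spec) == {"a": 1, "b": [2, 3]}
```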

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
2023-11-10 05:41:32 +00:00
PyTorch MergeBot
bf452dcde6 Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"
This reverts commit fa895da968.

Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to in the bottom diff in the stack changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1804870560))
2023-11-10 00:12:52 +00:00
NVS Abhilash
44c0521e8c fix: docstring error in torch/distributed module (#113241)
Fixes: #113193

`pydocstyle <all_files_in_issue> --count`

- Before: 345
- After: 130

For deprecated methods, I have added a `noqa` to ignore them. I was not able to find the file `torch/distributed/tensor/parallel/multihead_attention_tp.py`, so I've ignored it for this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113241
Approved by: https://github.com/kit1980
2023-11-09 19:10:20 +00:00
Wanchao Liang
9834fb7fd0 [dtensor] full_tensor to return synchronously (#113322)
The full_tensor API should return synchronously instead of returning an
AsyncCollectiveTensor; if the return is one, we do the wait
directly. This makes the full_tensor API more precise.
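
A hedged before/after sketch:
```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

mesh = DeviceMesh("cuda", list(range(2)))
dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

# Previously this could return an AsyncCollectiveTensor; now the wait on the
# underlying collective happens internally and a plain torch.Tensor comes back.
t = dt.full_tensor()
```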
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113322
Approved by: https://github.com/wz337
2023-11-09 18:02:40 +00:00