Commit Graph

43 Commits

Author SHA1 Message Date
Wanchao Liang
09f3e08bcc [dtensor][3/n] use dedicated TensorMeta instead of the fx one (#108261)
This PR switches the usage of fx's shape-prop TensorMetadata to
dtensor's own dedicated TensorMeta. DTensor only cares about three
fields: shape/stride/dtype; all other fields are unnecessary and can be
inferred from local_tensor directly. This significantly simplifies how
we deal with tensor metadata by ignoring the other fields.
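For illustration, a minimal sketch of such a dedicated metadata container (the exact field layout in DTensor may differ; `meta_from_local` is a hypothetical helper):

```python
from dataclasses import dataclass
from typing import Tuple

import torch


@dataclass(frozen=True)
class TensorMeta:
    # The only three fields DTensor cares about; everything else (device,
    # memory format, ...) can be inferred from the local tensor directly.
    shape: torch.Size
    stride: Tuple[int, ...]
    dtype: torch.dtype


def meta_from_local(t: torch.Tensor) -> TensorMeta:
    # Hypothetical helper: derive the metadata straight from a local tensor.
    return TensorMeta(shape=t.shape, stride=t.stride(), dtype=t.dtype)
```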
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108261
Approved by: https://github.com/fduwjj
ghstack dependencies: #107306
2023-09-13 04:08:02 +00:00
Wanchao Liang
d8f2ef10a6 [dtensor][1/n] refactor op dispatch logic to reduce overhead (#107305)
This PR is the first in a series of refactors of the op dispatch logic to:
1. remove redundant logic in the op dispatch and simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten calls, cutting
the overhead those operations incur
3. remove the CachedShardingPropagator by using lru_cache from functools
directly; this helps not only TP, but general DTensor operations as
well (see the sketch after this list)
4. change the view-op behavior: in-place mutation of the op_schema is
dangerous for sharding-prop caching, so view ops are now modeled as one
type of resharding too
5. enrich the output sharding to include whether the op needs a redistribute,
so that we don't need an explicit op_schema comparison to know it
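To illustrate point 3 above, a minimal sketch of replacing a dedicated cache class with `functools.lru_cache` (the class and method names here are assumptions, not the actual DTensor internals):

```python
from functools import lru_cache


class ShardingPropagator:
    # Hypothetical stand-in for DTensor's sharding propagator.
    def propagate_op_sharding(self, op_schema):
        # ... expensive sharding-rule computation keyed by the op schema ...
        return op_schema

    # Instead of a hand-rolled CachedShardingPropagator wrapper, cache the
    # propagation result directly; op_schema must be hashable for this to work.
    @lru_cache(maxsize=None)
    def propagate_op_sharding_cached(self, op_schema):
        return self.propagate_op_sharding(op_schema)
```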

This should help further reduce the CPU overhead. Benchmark results:
before (without this change), aten.addmm latency: 0.476ms
![Screenshot 2023-08-16 at 10 46 26 AM](https://github.com/pytorch/pytorch/assets/9443650/7692e6c1-1936-4c7f-bf9c-6c8c9b8f6c76)

after (with this change), aten.addmm latency: 0.341ms
![Screenshot 2023-08-16 at 11 05 49 AM](https://github.com/pytorch/pytorch/assets/9443650/15a53f0b-7a95-444e-ab2f-3ee0ad2fa47f)

Overall, the time for one MLP layer dropped from 13.535ms to 9.665ms.

Apart from the overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactoring is needed to make things cleaner, which will be done in later PRs).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
2023-08-18 18:30:46 +00:00
Chien-Chin Huang
49c8a0cad0 [SPMD][BE] Remove the legacy tracing code (#100858)
Remove the legacy tracing code, as it causes several test and benchmark issues.

Differential Revision: [D45649123](https://our.internmc.facebook.com/intern/diff/D45649123/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100858
Approved by: https://github.com/wanchaol
2023-05-11 23:08:27 +00:00
Chien-Chin Huang
33fba6ef07 [SPMD] Add arange and zeros to default factory ops (#100037)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100037
Approved by: https://github.com/mrshenli, https://github.com/wanchaol
2023-04-26 16:32:10 +00:00
Aaron Gokaslan
e2a3817dfd [BE] Enable C419 rule for any all shortcircuiting (#99890)
Apparently https://github.com/pytorch/pytorch/pull/78142 made torch.jit allow simple generator expressions, which lets us enable rules that replace unnecessary list comprehensions with generators in any/all. This was originally part of #99280, but I split it off into this PR so that it can be easily reverted should anything break.
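A representative before/after of what the C419 rule flags:

```python
values = [-3, 0, 7, 2]

# Before: the list comprehension materializes every element eagerly.
any([x > 0 for x in values])

# After: the generator expression lets any() short-circuit at the first match.
any(x > 0 for x in values)
```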

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99890
Approved by: https://github.com/justinchuby, https://github.com/kit1980, https://github.com/malfet
2023-04-25 15:02:13 +00:00
Edward Z. Yang
abdd1f4a38 Reuse tracing context and fake tensors from backwards in forwards (#99619)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99619
Approved by: https://github.com/wanchaol
2023-04-20 22:39:48 +00:00
Shen Li
ca89e7942a [SPMD][Easy] switch to tree_map_only to simplify code (#99547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99547
Approved by: https://github.com/fegin
2023-04-19 20:40:09 +00:00
Shen Li
e605b5df74 [SPMD] Add sym_stride to DSymInt (#99504)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99504
Approved by: https://github.com/fegin
2023-04-19 14:55:40 +00:00
Shen Li
2cb8a8d4cc [SPMD] Support DSymInt for slice_backward in SPMD expansion (#99501)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99501
Approved by: https://github.com/fegin
2023-04-19 14:55:40 +00:00
Shen Li
292296141a [SPMD] Support SymInt with non-op call_function nodes (#99420)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99420
Approved by: https://github.com/fegin
2023-04-19 14:55:37 +00:00
Shen Li
301be37091 Avoid import * from experimental_ops (#99363)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99363
Approved by: https://github.com/fegin
2023-04-19 14:55:30 +00:00
Shen Li
62a6d81143 [SPMD][Easy] Making typing consistent by replacing object with Any (#99332)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99332
Approved by: https://github.com/dracifer
2023-04-17 19:33:45 +00:00
Shen Li
19c2804614 [SPMD][EASY] Remove unnecessary torch.ops prefix (#99331)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99331
Approved by: https://github.com/dracifer
2023-04-17 19:33:45 +00:00
Shen Li
c69d54885a [SPMD][BE] Generalize factory ops support in SPMD expansion (#99233)
Differential Revision: [D45028740](https://our.internmc.facebook.com/intern/diff/D45028740)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99233
Approved by: https://github.com/yifuwang
2023-04-16 00:07:27 +00:00
Shen Li
6bb20822f5 [SPMD][BE] Remove deprecated aten.sym_numel branch (#99232)
Differential Revision: [D45028732](https://our.internmc.facebook.com/intern/diff/D45028732)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99232
Approved by: https://github.com/yifuwang
2023-04-16 00:07:27 +00:00
Shen Li
39be994913 [SPMD][BE] Consolidate DSymInt Branches (#99231)
Differential Revision: [D45028726](https://our.internmc.facebook.com/intern/diff/D45028726)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99231
Approved by: https://github.com/yifuwang
2023-04-16 00:07:24 +00:00
Shen Li
544cd8e134 [SPMD] Refactor DSize to DSymInt to enable sym_numel (#99206)
This commit uses `aten.arange.default` and `aten.arange.start` to
test `aten.sym_numel`.

Differential Revision: [D45028715](https://our.internmc.facebook.com/intern/diff/D45028715)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99206
Approved by: https://github.com/yifuwang
2023-04-16 00:07:21 +00:00
Shen Li
bafb984022 [SPMD] Enable aten.full.default with SymInt on sharded dims (#99190)
Differential Revision: [D45028686](https://our.internmc.facebook.com/intern/diff/D45028686)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99190
Approved by: https://github.com/yifuwang
2023-04-16 00:07:18 +00:00
Shen Li
40aaacd4fa Respect sharded dimensions when aten expand/view consumes SymInt values (#99058)
Currently, aten.expand always expands to the global dimension. It then
introduces additional slice and clone ops before running compute on
the expanded tensor with a local tensor.

In this commit, if we detect that the op consumes a SymInt size, it respects
both the local size and the dimension placements from which the SymInt was
extracted.
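A rough, self-contained sketch of the idea (the `DSymInt` wrapper and its fields here are stand-ins for illustration, not the actual implementation):

```python
from dataclasses import dataclass


@dataclass
class DSymInt:
    # Stand-in: a symbolic size extracted from a sharded DTensor dimension.
    global_value: int  # size of the full, global dimension
    local_value: int   # size of this rank's local shard
    mesh_dim: int      # device-mesh dimension the tensor dim is sharded over


def resolve_expand_size(arg):
    # If an aten.expand/view size argument came from a sharded dim, use the
    # local size instead of expanding to the global dim and slicing it back.
    return arg.local_value if isinstance(arg, DSymInt) else arg


sizes = [DSymInt(global_value=1024, local_value=256, mesh_dim=0), 16]
local_sizes = [resolve_expand_size(s) for s in sizes]  # -> [256, 16]
```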
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99058
Approved by: https://github.com/wanchaol
2023-04-14 13:54:05 +00:00
Shen Li
02d1cf51b6 [Easy] Clean up args remap for DTensor expansion (#99040)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99040
Approved by: https://github.com/wanchaol, https://github.com/fegin
2023-04-14 00:23:00 +00:00
Wanchao Liang
15686950b7 [spmd] quick fix on batch input view issue (#98813)
This is a quick fix/hack to work around the issue that some
"global" tensor view operations are invalid, yet they somehow get
triggered by some models, even though the mini-batch input itself doesn't
have this issue.

Since we should ultimately remove the DTensor expansion and use the new
expansion, this hack is only a temporary unblock.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98813
Approved by: https://github.com/yifuwang, https://github.com/mrshenli
2023-04-11 14:27:01 +00:00
Edward Z. Yang
9a8f71f23e Convert logging f-strings to use % format (#98697)
Codemod done with
https://gist.github.com/ezyang/2e8b0463cdc6be278478495b23ff0530 with
assistance from ChatGPT.
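The shape of the codemod, in brief (the example message and variables are made up):

```python
import logging

logger = logging.getLogger(__name__)
rank, world_size = 0, 8

# Before: the f-string is formatted eagerly, even when the log level is disabled.
logger.info(f"initialized rank {rank} of {world_size}")

# After: %-style arguments are only formatted if the record is actually emitted.
logger.info("initialized rank %s of %s", rank, world_size)
```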

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98697
Approved by: https://github.com/voznesenskym
2023-04-10 12:19:31 +00:00
Yifu Wang
970c08f92f [spmd expansion] support scalar_tensor (#98390)
scalar_tensor is a pure factory function that can't be handled by a DTensor prop rule, so it currently needs to be handled in spmd expansion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98390
Approved by: https://github.com/mrshenli
2023-04-05 16:31:44 +00:00
Yifu Wang
0830808dde [spmd expansion] speed up expansion by ~5x (#98389)
According to profiling, the two most expensive operations in spmd expansion are propagate_op_sharding and make_fx (for every dispatcher op node). This PR makes the following changes to speed up spmd expansion:
- We were unnecessarily running propagate_op_sharding twice for every op. Remove one.
- When no tensor redistribution is required, we only need to update the non-tensor args of the node according to op_schema, and can avoid building a GraphModule just for the node.

On a DDP use case + foreach Adam, this change speeds up spmd expansion by ~5x (~10 min -> ~2 min).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98389
Approved by: https://github.com/mrshenli
2023-04-05 16:31:40 +00:00
Yifu Wang
161f7c0b28 [spmd expansion] support torch.ops.aten.sym_numel (#98388)
The current logic assumes non-overload ops take two arguments; however, torch.ops.aten.sym_numel takes one.
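For context, the arity difference on a plain tensor (illustration only, not the expansion code itself):

```python
import torch

t = torch.randn(4, 8)

# sym_size/sym_stride take a tensor and a dim ...
print(torch.ops.aten.sym_size.int(t, 0))  # 4
# ... while sym_numel takes only the tensor.
print(torch.ops.aten.sym_numel(t))        # 32
```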
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98388
Approved by: https://github.com/mrshenli
2023-04-05 16:31:36 +00:00
Wei Wang
1e3abda31a Revert "[spmd expansion] support torch.ops.aten.sym_numel (#98229)" (#98382)
This reverts commit 4d13fcddef.

Fixes a diff-train landing issue, as the original diff was modified after the PR was merged in OSS.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98382
Approved by: https://github.com/kit1980
2023-04-05 04:07:58 +00:00
Kazuaki Ishizaki
6514d71add Fix typos under torch/distributed directory (#98225)
This PR fixes typos in comments and messages of `.py` files under the `torch/distributed` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98225
Approved by: https://github.com/soulitzer, https://github.com/kit1980
2023-04-05 00:21:33 +00:00
Yifu Wang
4d13fcddef [spmd expansion] support torch.ops.aten.sym_numel (#98229)
The current logic assumes non-overload ops take two arguments; however, torch.ops.aten.sym_numel takes one.

Differential Revision: [D44615037](https://our.internmc.facebook.com/intern/diff/D44615037/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98229
Approved by: https://github.com/mrshenli
2023-04-03 23:57:10 +00:00
Shen Li
02179827cb [Easy] Include SPMD and DTensor files in UFMT checks (#98148)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98148
Approved by: https://github.com/fegin
2023-04-02 15:34:49 +00:00
Shen Li
347c67d4a2 [Easy] Consolidate string startswith checks (#98147)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98147
Approved by: https://github.com/fegin
2023-04-02 04:02:37 +00:00
Shen Li
e8d39606eb [SPMD] Enable fused Adam in full train step tracing (#98113)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98113
Approved by: https://github.com/yifuwang, https://github.com/fegin
2023-04-01 15:54:13 +00:00
Shen Li
9ec6fdb29b Enable adam foreach in full train step tracing (#97897)
Main changes:

1. Registered several foreach ops to both meta and DTensor
2. Skipped redundant getitem nodes when expanding foreach ops with DTensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97897
Approved by: https://github.com/wanchaol, https://github.com/fegin
2023-03-30 16:47:10 +00:00
Shen Li
379fb47654 [SPMD] Support foreach optimizers with functionalization (#97853)
My first attempt was to apply the same solution that proxy_tensor.py
uses for other inplace ops. However, foreach is different in that
its schema in `native_functions.yaml` does not return anything,
whereas ops like `addcmul_` and `addcdiv_` do return Tensors (thanks
bdhirsh for teaching me this!). As a result, the proxy output
during tracing does not wrap anything, and hence we cannot correctly
connect it with subsequent operators. Modifying `native_functions.yaml`
is not a preferred solution. After discussing with bdhirsh, the
temporary solution is to do foreach functionalization as a graph
pass for now. Later, when https://github.com/pytorch/pytorch/issues/97852
is addressed, we will switch to default functionalization.

Edit: the latest version follows @bdhirsh's suggestion to use
`make_fx`'s `decomposition_table` instead of implementing manual
fx.Graph transforms to functionalize `_foreach_add_`.
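A rough sketch of that approach: pass a decomposition that rewrites the in-place foreach op into its functional counterpart plus copies (the decomposition body below is an assumption for illustration, not the exact one used in the PR):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def _foreach_add__decomp(self_tensors, other_tensors, alpha=1):
    # Functional replacement: compute out-of-place, then copy back, so the
    # traced proxies have real outputs to connect to downstream nodes.
    outs = torch.ops.aten._foreach_add.List(self_tensors, other_tensors, alpha=alpha)
    for dst, src in zip(self_tensors, outs):
        dst.copy_(src)


decomp_table = {torch.ops.aten._foreach_add_.List: _foreach_add__decomp}


def step(params, grads):
    torch._foreach_add_(params, grads, alpha=-0.01)
    return params


params = [torch.randn(4) for _ in range(2)]
grads = [torch.randn(4) for _ in range(2)]
gm = make_fx(step, decomposition_table=decomp_table)(params, grads)
print(gm.graph)  # should show _foreach_add + copy_ rather than _foreach_add_
```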
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97853
Approved by: https://github.com/fegin, https://github.com/wanchaol
2023-03-30 11:27:10 +00:00
Shen Li
5949d86bec [Easy] Remove unnecessary graph lint (#97815)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97815
Approved by: https://github.com/fegin
2023-03-29 00:41:00 +00:00
Shen Li
c39f1c1490 Allow DTensor to trigger collectives before inplace ops (#97787)
Mainly two fixes:

1. `make_fx` seems to trace through DeviceMesh operations. This commit removes those from the DTensor-expanded graph.
2. During DTensor expansion, autograd complains about in-place changes on leaf nodes. This commit wraps the entire DTensor expansion code in `torch.no_grad()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97787
Approved by: https://github.com/wanchaol
2023-03-28 21:06:51 +00:00
Wanchao Liang
e9c4904915 [dtensor] remove custom dispatch op (#95629)
Since we removed all custom dispatch ops, we can safely delete this
table, as we won't use it for other purposes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95629
Approved by: https://github.com/XilunWu
2023-03-28 02:25:45 +00:00
Shen Li
75fb0b6c9f Enable full train_step tracing and customizable dist graph expansion (#97416)
This commit adds an entry point for full `train_step` tracing and
expansion. Model forward, backward, and optimizer step will be included
in one graph. DTensor expansion will be applied on top to insert
collective communications. Users can also provide an `Override`
implementation to skip non-traceable submodules and directly install
submodule logic into the DTensor-expanded graph by inserting `fx.Nodes`.

Differential Revision: [D44325177](https://our.internmc.facebook.com/intern/diff/D44325177)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97416
Approved by: https://github.com/yifuwang, https://github.com/wanchaol
2023-03-25 09:24:21 +00:00
Shen Li
021de486ff [Easy] Apply black to format _spmd files (#97534)
No real changes. Format code to prepare for the PR on top.

Differential Revision: [D44376380](https://our.internmc.facebook.com/intern/diff/D44376380)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97534
Approved by: https://github.com/wanchaol
2023-03-25 01:09:41 +00:00
Rodrigo Kumpera
7ebd816aab Switch DTensor to use funcol::all_reduce. (#95804)
This relands the troubling part of #95009 that caused a regression.

BC: This changes the signature and semantics of DeviceMesh::all_reduce.

DeviceMesh::all_reduce now uses a functional collective under the hood, which makes it more easily traceable.
You no longer need to use CommTensor to get a trace.

all_reduce is now async-only and uses AsyncCollectiveTensor to ensure proper stream synchronization.

Signature change: removed the async_op param and changed the return type from Optional[Work] to torch.Tensor.
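Roughly, the BC change can be sketched as follows (the signatures are approximations; parameter names other than `async_op` are assumptions, not the exact code):

```python
from typing import Optional

import torch
import torch.distributed as dist


class DeviceMeshSketch:
    # Before (sketch): optionally async, returns a Work handle when async_op=True.
    def all_reduce_old(self, tensor: torch.Tensor, op=dist.ReduceOp.SUM,
                       mesh_dim: int = 0,
                       async_op: bool = False) -> Optional["dist.Work"]:
        ...

    # After (sketch): always async under the hood via the functional collective;
    # the returned tensor (an AsyncCollectiveTensor) syncs when first used.
    def all_reduce_new(self, tensor: torch.Tensor, op: str = "sum",
                       mesh_dim: int = 0) -> torch.Tensor:
        ...
```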
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95804
Approved by: https://github.com/fegin
2023-03-02 17:55:01 +00:00
Wanchao Liang
bb9a05b116 [dtensor] use tracing for metadata prop (#95456)
This PR uses tracing for metadata prop, so that we can get correct
shape/stride metadata without calculating it manually ourselves.

A follow-up PR will adopt tracing for the sharding prop itself.

Differential Revision: [D43643578](https://our.internmc.facebook.com/intern/diff/D43643578)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95456
Approved by: https://github.com/XilunWu
2023-02-28 17:54:22 +00:00
PyTorch MergeBot
d950f45577 Revert "[Functional Collectives] Migrate DeviceMesh::all_reduce to use functional all_reduce. (#95009)"
This reverts commit 0765dbc25e.

Reverted https://github.com/pytorch/pytorch/pull/95009 on behalf of https://github.com/jeanschmidt because this PR is causing internal breakages. Check https://fburl.com/diff/me41urq8
2023-02-27 19:21:58 +00:00
Rodrigo Kumpera
0765dbc25e [Functional Collectives] Migrate DeviceMesh::all_reduce to use functional all_reduce. (#95009)
BC: This changes the signature and semantics of DeviceMesh::all_reduce.

DeviceMesh::all_reduce now uses a functional collective under the hood, which makes it more easily traceable.
You no longer need to use CommTensor to get a trace.

all_reduce is now async-only and uses AsyncCollectiveTensor to ensure proper stream synchronization.

Signature change: removed the `async_op` param and changed the return type from `Optional[Work]` to `torch.Tensor`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95009
Approved by: https://github.com/wanchaol
2023-02-24 02:10:55 +00:00
Chien-Chin Huang
250c054bdd [SPMD] Pull the minimal working distribute API and SPMD module to PyTorch (#94802)
Pull the minimal working distribute API and SPMD module into PyTorch. The original code is at https://github.com/pytorch/tau/tree/main/spmd/compiler.

Other main contributors to the original code base: @anj-s, @lessw2020, @wanchaol, @aazzolini

Differential Revision: [D43197230](https://our.internmc.facebook.com/intern/diff/D43197230/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94802
Approved by: https://github.com/anj-s, https://github.com/wanchaol
2023-02-16 00:36:16 +00:00