This PR switches from fx's shape-prop TensorMetadata to DTensor's own dedicated TensorMeta. DTensor only cares about three fields: shape/stride/dtype; all other fields are unnecessary and can be inferred from the local_tensor directly. This significantly simplifies how we handle tensor metadata, since we no longer track the unused fields.
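As a rough sketch, the dedicated container carries only the three fields DTensor needs (the real TensorMeta lives in torch.distributed._tensor and may differ in detail), whereas fx's shape-prop TensorMetadata also tracks requires_grad, memory_format, quantization info, etc.:

```python
# Sketch only: a minimal TensorMeta with just the fields DTensor cares about.
from typing import NamedTuple, Tuple

import torch

class TensorMeta(NamedTuple):
    shape: torch.Size
    stride: Tuple[int, ...]
    dtype: torch.dtype

# Everything else can be read off the local tensor directly.
local_tensor = torch.randn(4, 8)
meta = TensorMeta(local_tensor.shape, local_tensor.stride(), local_tensor.dtype)
```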
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108261
Approved by: https://github.com/fduwjj
ghstack dependencies: #107306
This PR is the first change in a series of refactors to the op dispatch logic to:
1. remove redundant logic in op dispatch and simplify the error checking
2. reduce the number of tree_map/tree_flatten/unflatten calls to cut the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools directly; this not only helps TP, but also makes general DTensor operations faster (see the sketch after this list)
4. change the view ops' behavior of in-place mutating the op_schema, which is dangerous for sharding prop caching, and model view ops as one type of resharding too
5. enrich output sharding to include whether the op needs redistribute, so that we don't need an explicit op schema comparison to know it
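A minimal sketch of point 3, assuming the op schema is hashable; the actual ShardingPropagator in torch.distributed._tensor differs in detail:

```python
from functools import lru_cache

class ShardingPropagator:
    def __init__(self):
        # Cache propagation results per instance, replacing the dedicated
        # CachedShardingPropagator wrapper class.
        self.propagate_op_sharding = lru_cache(None)(self._propagate_op_sharding)

    def _propagate_op_sharding(self, op_schema):
        # run the (relatively expensive) sharding propagation rule for op_schema
        ...
```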
This should help further reduce the CPU overhead. Benchmark results:
before (without this change), aten.addmm latency: 0.476ms

after (with this change), aten.addmm latency: 0.341ms

overall, one layer of MLP time reduced from 13.535ms -> 9.665ms
Apart from the overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactoring is needed to make things cleaner, which will be done in later PRs).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
Currently, aten.expand always expands to the global dimension, which introduces additional slice and clone ops before running the compute on the expanded tensor with a local tensor.
In this commit, if we detect that the op consumes a SymInt size, it respects both the local size and the placement of the dimension from which the SymInt was extracted.
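A hedged illustration with plain tensors standing in for the sharded case (the sizes and the slice/clone pattern are illustrative assumptions, not the actual DTensor code):

```python
import torch

world_size, global_rows = 2, 8
local_rows = global_rows // world_size

bias = torch.randn(1, 4)              # replicated tensor with a size-1 leading dim
local_x = torch.randn(local_rows, 4)  # local shard of an [8, 4] tensor sharded on dim 0

# Old behavior: expand to x's *global* row count, then slice/clone back down
# to the local shard before running the local compute.
expanded_old = bias.expand(global_rows, 4)[:local_rows].clone()

# New behavior: the SymInt size was extracted from a sharded dim, so expand
# directly to the local size and skip the extra slice/clone.
expanded_new = bias.expand(local_x.size(0), 4)

assert expanded_old.shape == expanded_new.shape == (local_rows, 4)
```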
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99058
Approved by: https://github.com/wanchaol
This is a quick fix/hack to get around the issue that some "global" tensor view operations are invalid, yet they somehow get triggered by some models even though the mini-batch input itself would not hit this issue.
Since we should ultimately remove the dtensor expand and use the new expansion, this hack is only a temporary unblock.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98813
Approved by: https://github.com/yifuwang, https://github.com/mrshenli
According to profiling, the top two most expensive operations in spmd expansion are propagate_op_sharding and make_fx (for every dispatcher op node). This PR makes the following changes to speed up spmd expansion:
- We are unnecessarily doing propagate_op_sharding twice for every op. Remove one.
- When no tensor redistribution is required, we only need to update the non-tensor args of the node according to op_schema, and can avoid building a GraphModule just for that node (see the sketch below).
On a DDP use case with foreach Adam, this change speeds up spmd expansion by ~5x (~10 min -> ~2 min).
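A hedged sketch of the second bullet, using hypothetical helper and field names; the real expansion pass in torch.distributed._spmd differs in detail:

```python
import torch.fx as fx

def expand_node(node: fx.Node, op_schema, output_sharding):
    # `needs_redistribute` is a hypothetical flag standing in for whatever
    # signal sharding propagation provides.
    if not output_sharding.needs_redistribute:
        # Fast path: no redistribution, so only rewrite the non-tensor args
        # (sizes, dims, scalars) suggested by sharding propagation in place,
        # instead of tracing a per-node GraphModule with make_fx.
        node.args = tuple(
            new if not isinstance(old, fx.Node) else old
            for old, new in zip(node.args, op_schema.args_schema)
        )
        return node
    # Slow path: build and splice in a small traced graph that performs the
    # redistribution plus the local op.
    ...
```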
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98389
Approved by: https://github.com/mrshenli
My first attempt was to apply the same solution as how proxy_tensor.py
handles other inplace ops. However, foreach is different in that its
schema in `native_functions.yaml` does not return anything, whereas ops
like `addcmul_` and `addcdiv_` do return Tensors (thanks bdhirsh for
teaching me this!). As a result, the proxy output during tracing does
not wrap anything, and hence we cannot correctly connect it with
subsequent operators. Modifying `native_functions.yaml` is not a
preferred solution. After discussing with bdhirsh, the temporary
solution is to do foreach functionalization as a graph pass for now.
Later, when https://github.com/pytorch/pytorch/issues/97852
is addressed, we will switch to the default functionalization.
Edit: the latest version follows @bdhirsh's suggestion of using
`make_fx`'s `decomposition_table` instead of implementing manual
fx.Graph transforms to functionalize `_foreach_add_`.
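A hedged sketch of the `decomposition_table` approach; the actual decomposition registered by the SPMD tracer may differ in detail:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

aten = torch.ops.aten

def _foreach_add_decomp(self_list, other_list, alpha=1):
    # Re-express the in-place foreach op via its functional counterpart, then
    # copy the results back so tracing sees real tensor outputs that later
    # operators can connect to.
    updated = aten._foreach_add.List(self_list, other_list, alpha=alpha)
    for s, s_u in zip(self_list, updated):
        s.copy_(s_u)

def trace(fn, *args):
    return make_fx(
        fn,
        decomposition_table={aten._foreach_add_.List: _foreach_add_decomp},
    )(*args)
```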
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97853
Approved by: https://github.com/fegin, https://github.com/wanchaol
Mainly two fixes:
1. `make_fx` seems to trace through DeviceMesh operations. This commit removes those from the DTensor-expanded graph.
2. During DTensor expansion, autograd complains about in-place changes on leaf nodes. This commit wraps the entire DTensor expansion code in `torch.no_grad()` (see the sketch below).
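A minimal sketch of fix 2, with a hypothetical `_expand_with_dtensor` helper standing in for the actual expansion code:

```python
import torch

def dtensor_expand(gm, *args):
    # Autograd complains about in-place updates on leaf tensors during
    # expansion, so run the whole expansion under no_grad.
    with torch.no_grad():
        return _expand_with_dtensor(gm, *args)  # hypothetical helper
```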
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97787
Approved by: https://github.com/wanchaol
This commit adds an entry point for full `train_step` tracing and
expansion. Model forward, backward, and optimizer step are all included
in one graph. DTensor expansion is then applied on top to insert
collective communications. Users can also provide an `Override`
implementation to skip non-traceable submodules and directly install
submodule logic into the DTensor-expanded graph by inserting `fx.Node`s.
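A hedged sketch of what a fully traced `train_step` covers; `compile_train_step` and `MyOverride` are hypothetical names standing in for the actual entry point and `Override` implementation:

```python
import torch

def train_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer, batch: torch.Tensor):
    # Forward, backward, and optimizer step all land in one traced graph;
    # DTensor expansion then inserts the collective communications.
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss

# traced_step = compile_train_step(train_step, override=MyOverride())  # hypothetical entry point
```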
Differential Revision: [D44325177](https://our.internmc.facebook.com/intern/diff/D44325177)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97416
Approved by: https://github.com/yifuwang, https://github.com/wanchaol
This relands the troublesome part of #95009 that caused a regression.
BC: This changes the signature and semantics of DeviceMesh::all_reduce.
DeviceMesh::all_reduce now uses a functional collective under the hood, which makes it more easily traceable.
You no longer need to use CommTensor to get a trace.
all_reduce is now async-only and uses AsyncCollectiveTensor to ensure proper stream synchronization.
Signature change: removed the `async_op` param and changed the return type from `Optional[Work]` to `torch.Tensor`.
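A hedged before/after sketch (assumes an initialized process group and a 1-D mesh; the `mesh_dim` parameter name is an assumption here):

```python
import torch
from torch.distributed._tensor import DeviceMesh

def reduce_grad(mesh: DeviceMesh, grad: torch.Tensor) -> torch.Tensor:
    # Before this change:
    #   work = mesh.all_reduce(grad, mesh_dim=0, async_op=False)  # Optional[Work]
    # After this change: async-only; returns a tensor backed by
    # AsyncCollectiveTensor that syncs with the stream on first use.
    return mesh.all_reduce(grad, mesh_dim=0)
```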
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95804
Approved by: https://github.com/fegin
BC: This changes the signature and semantics of DeviceMesh::all_reduce.
DeviceMesh::all_reduce now uses a functional collective under the hood, which makes it more easily traceable.
You no longer need to use CommTensor to get a trace.
all_reduce is now async-only and uses AsyncCollectiveTensor to ensure proper stream synchronization.
Signature change: removed the `async_op` param and changed the return type from `Optional[Work]` to `torch.Tensor`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95009
Approved by: https://github.com/wanchaol