Commit Graph

6 Commits

Peter Bell
66c32d099a Use pytree.arg_tree_leaves everywhere (#112394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112394
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393
2023-10-31 15:57:06 +00:00
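
A minimal sketch of the pattern this commit applies tree-wide; it assumes `pytree.arg_tree_leaves(*args, **kwargs)` returns the flattened leaves of all positional and keyword arguments, and the example inputs are illustrative:

```python
import torch
import torch.utils._pytree as pytree

def num_tensor_args(*args, **kwargs):
    # Old pattern: flatten the (args, kwargs) tuple and discard the spec.
    leaves_old, _ = pytree.tree_flatten((args, kwargs))
    # New pattern: collect the leaves of every positional and keyword
    # argument directly.
    leaves = pytree.arg_tree_leaves(*args, **kwargs)
    assert len(leaves) == len(leaves_old)
    return sum(isinstance(leaf, torch.Tensor) for leaf in leaves)

print(num_tensor_args(torch.ones(2), scale=2.0, bias=torch.zeros(2)))  # 2
```
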
Peter Bell
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
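
A minimal before/after sketch of the replacement described above (the `nested` value is just an illustrative input):

```python
import torch
import torch.utils._pytree as pytree

nested = {"a": torch.ones(2), "b": [torch.zeros(3), 4.0]}

# Old pattern: flatten and throw away the TreeSpec.
leaves_old, _ = pytree.tree_flatten(nested)

# New pattern: ask for the leaves directly.
leaves_new = pytree.tree_leaves(nested)

# Same leaf objects in the same order, without the throwaway spec.
assert all(a is b for a, b in zip(leaves_old, leaves_new))
assert len(leaves_old) == len(leaves_new) == 3
```
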
Wanchao Liang
942cd12d55 [spmd] add option to preserve node types (#100072)
This PR adds an option to preserve node types for the entire graph,
which could allow some exploration of using those node types to do
things like activation checkpointing, etc.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100072
Approved by: https://github.com/mrshenli
2023-05-24 17:55:05 +00:00
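
The commit body is terse; as a rough, hypothetical illustration of preserving per-node type information on a captured graph (the `node_type` meta key and the classification below are illustrative, not the actual spmd implementation), one could tag each `torch.fx` node so later passes such as activation checkpointing can consume it:

```python
import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.lin(x))

gm = fx.symbolic_trace(Toy())

# Hypothetical classification: stash a "node type" in each node's meta dict
# so later passes (e.g. activation checkpointing) can key off it.
for node in gm.graph.nodes:
    if node.op == "placeholder":
        node.meta["node_type"] = "input"
    elif node.op == "get_attr":
        node.meta["node_type"] = "param"
    elif node.op == "output":
        node.meta["node_type"] = "output"
    else:
        node.meta["node_type"] = "activation"

print([(node.name, node.meta["node_type"]) for node in gm.graph.nodes])
```
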
Wanchao Liang
932ed333f7 [spmd] expose input_batch_dim to DataParallelMode (#99899)
This PR exposes the input batch dim to the DataParallelMode so that
we can have explicit control over which input dim is the batch dim.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99899
Approved by: https://github.com/awgu, https://github.com/mrshenli
2023-04-25 19:30:58 +00:00
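
A hypothetical sketch of the exposed knob (the class and field below are illustrative stand-ins, not the actual torch._spmd API):

```python
from dataclasses import dataclass

@dataclass
class DataParallelMode:
    """Illustrative stand-in, not the actual torch._spmd class."""

    # Which dimension of the module inputs is the batch dimension; the data
    # parallel expansion shards inputs along this dim across ranks.
    input_batch_dim: int = 0

# e.g. for inputs shaped (seq_len, batch, hidden):
mode = DataParallelMode(input_batch_dim=1)
```
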
Wanchao Liang
e9bf94149e [spmd] Introduce Compile Mode FSDP with DTensor (#99062)
This PR introduces compile mode Data Parallel (FSDP/DDP) using DTensor sharding.

Along with the algorithm, it also introduces a new DataParallelMode so that the `compile` API can take it
and apply data parallelism. This PR tries to preserve the DTensorExpand
approach first to avoid BC breakage; we shall discuss steps to remove
DTensorExpand.

The data parallel mode uses heuristics to determine node types in the
graph and assign the corresponding sharding. The detailed algorithm is
described in the design doc.

The benefits of this approach:
- Model parameters and optimizer states are all DTensors after `spmd.compile`, which is necessary for FSDP and also makes checkpointing much easier
- As model parameters/optimizer states are sharded in a per-parameter fashion, it can compose with sophisticated second-order optimizers (e.g. Shampoo) more easily
- We leverage model parameter/gradient information to derive the data parallel pattern, so we don't need to worry about DTensor op coverage anymore, as data parallel is just a special case of DTensor operation
- Using dtensor_expand might work for DDP but isn't going to work for FSDP, as DTensor might choose to allgather activations, which might violate the native FSDP algorithm
- The approach is general enough to support both DDP/FSDP and a mixed mode

Follow ups:
- Add the "default" data parallel mode which supports mixing of
replicate/fully shard
- Test more e2e models with different types of optimizers, etc.
- Migrate the existing stack from the DTensorExpand mode
- Build optimizations on top of this prototype

Differential Revision: [D45174400](https://our.internmc.facebook.com/intern/diff/D45174400)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99062
Approved by: https://github.com/mrshenli
2023-04-22 03:13:05 +00:00
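
A rough sketch of the per-parameter DTensor sharding this mode relies on, using the public `torch.distributed._tensor` helpers rather than the spmd internals (the function below is illustrative and assumes an already-initialized device mesh):

```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

def shard_params_per_parameter(module: torch.nn.Module, mesh: DeviceMesh) -> None:
    """Illustrative only: replace every parameter with a DTensor sharded on dim 0.

    Shard(0) gives the FSDP-like layout described above; using Replicate()
    instead would give the DDP-like layout, and mixing the two per parameter
    is the "mixed mode" mentioned in the follow-ups.
    """
    for name, param in list(module.named_parameters(recurse=False)):
        dtensor = distribute_tensor(param.detach(), mesh, [Shard(0)])
        module.register_parameter(name, torch.nn.Parameter(dtensor))
    for child in module.children():
        shard_params_per_parameter(child, mesh)
```
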
Wanchao Liang
b96bb2f1a6 [spmd] Introduce ParallelMode and add DTensorExpandMode (#98452)
This PR introduces a ParallelMode interface to define how to do
SPMD expansion and optimize the captured graph. This is beneficial
because it lets different parallelisms expand differently
and apply different optimization passes.

DTensorExpandMode is added as the first parallel mode, implementing the
existing dtensor_expand functionality.

Differential Revision: [D45174399](https://our.internmc.facebook.com/intern/diff/D45174399)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98452
Approved by: https://github.com/mrshenli
2023-04-21 17:24:54 +00:00
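
A hypothetical sketch of such an interface (method names below are illustrative, not the actual torch._spmd definitions):

```python
from abc import ABC, abstractmethod
import torch.fx as fx

class ParallelMode(ABC):
    """Illustrative stand-in for the interface this PR introduces."""

    @abstractmethod
    def expand(self, gm: fx.GraphModule) -> fx.GraphModule:
        """Perform SPMD expansion on the captured graph."""

    @abstractmethod
    def optimize(self, gm: fx.GraphModule) -> fx.GraphModule:
        """Apply parallelism-specific optimization passes to the expanded graph."""

class DTensorExpandMode(ParallelMode):
    """Wraps the pre-existing dtensor_expand behaviour behind the interface."""

    def expand(self, gm: fx.GraphModule) -> fx.GraphModule:
        # Delegate to the existing dtensor_expand pass (not reproduced here).
        raise NotImplementedError

    def optimize(self, gm: fx.GraphModule) -> fx.GraphModule:
        # No extra optimization passes in this first mode.
        return gm
```
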