This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch, simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly, this makes it not only helps TP, but general DTensor
operations could be faster!
4. change the view ops behavior by inplace changing the op_schema, which
is dangerous for sharding prop caching, model the view op as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute
so that we don't need explicit op schema comparison to know it.
This should help with further reducing the CPU overhead, benchmark
results:
before (without this change), aten.addmm latency: 0.476ms

after (with this change), aten.addmm latency: 0.341ms

overall one layer of mlp time reduced from 13.535 -> 9.665ms
Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactor needed to make things more clean, which will be done in later PRs)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
This is the first series of PR that adopts operator impls to use a
strategy based approach, each op utilizes OpStrategy and PlacementStrategy
to generate their own strategy. By utilizing the strategy based
approach along with the op graph, we could enable more advanced op
implementation (decomp is possible), and turn the sharding prop to be
more like a contraint satisfication problem.
This PR alone only adds some basic tensor op strategies, and it directly
works on the op graph that was used for metadata propagation. The tensor ops
added in this PR mainly follows one of the arg strategy. The next set of
PRs would add more op strategies to other ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100607
Approved by: https://github.com/XilunWu
To make TP more generic for Attention module, we come up with this new col/rowwise parallel style.
Basically, the idea behind is that:
We only do DTensor op for Col/Rowwise sharded part. For the rest of ATen ops, we will leave it to Tensor ops.
And we set this behavior as default for Colwise and Rowwise parallel style. If people want to customize it, they can always pass in different prepare_input or prepare_output
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100508
Approved by: https://github.com/wanchaol