This PR refactors `schema_suggestions` in OutputSharding to be a single
OpSchema instead of a list of schemas, since in practice there is only
ever one. For ops that need multiple reshardings we have also moved to
OpStrategy, so no remaining case requires a list.
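A minimal sketch of the shape of the change (field types only; the `OpSchema` body is elided and details are illustrative, not the exact diff):
```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class OpSchema:
    """Sketch: the operator plus its args/kwargs sharding specs."""
    ...

# Before this PR (sketch): a list that in practice held at most one entry.
#     schema_suggestions: Optional[List[OpSchema]] = None

@dataclass
class OutputSharding:
    # After this PR (sketch): a single suggested schema, or None when no
    # resharding is needed.
    schema_suggestions: Optional[OpSchema] = None
```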
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122929
Approved by: https://github.com/tianyu-l
Fixes #113191
```
pydocstyle torch/distributed/fsdp/fully_sharded_data_parallel.py --count
```
On master: 80
After my changes on this PR: 3
```
pydocstyle torch/distributed/_spmd/comm_tensor.py --count
```
On master: 5
After my changes on this PR: 3
```
pydocstyle torch/distributed/_spmd/experimental_ops.py --count
```
On master: 3
After my changes on this PR: 1
```
pydocstyle torch/distributed/_spmd/iter_graph_module.py --count
```
On master: 39
After my changes on this PR: 27
```
pydocstyle torch/distributed/_spmd/graph_utils.py --count
```
On master: 16
After my changes on this PR: 4
```
pydocstyle torch/distributed/_spmd/distribute.py --count
```
On master: 19
After my changes on this PR: 10
```
pydocstyle torch/distributed/_spmd/api.py --count
```
On master: 10
After my changes on this PR: 3
```
pydocstyle torch/distributed/_spmd/batch_dim_utils.py --count
```
On master: 14
After my changes on this PR: 3
```
pydocstyle torch/distributed/_spmd/data_parallel.py --count
```
On master: 34
After my changes on this PR: 2
```
pydocstyle torch/distributed/_spmd/graph_optimization.py --count
```
On master: 35
After my changes on this PR: 13
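For context, these counts come from mechanical docstring fixes. A small, hypothetical example of the kind of edit pydocstyle asks for (not taken from this diff):
```python
# Before: pydocstyle flags this, e.g. D200 (one-line docstring should fit
# on one line with quotes) and D400 (first line should end with a period).
def clean_graph(graph):
    """
    cleans up the graph
    """

# After: a single imperative summary line ending with a period.
def clean_graph(graph):
    """Clean up dead nodes in the given graph."""
```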
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113216
Approved by: https://github.com/ezyang
This PR switches from fx shape prop's TensorMetadata to DTensor's own
dedicated TensorMeta. DTensor only cares about three fields,
shape/stride/dtype; all other fields are unnecessary and can be
inferred from the local_tensor directly. This significantly simplifies
how we deal with tensor metadata, since we no longer track the other
fields.
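A minimal sketch of such a dedicated type, assuming a plain NamedTuple with exactly these three fields (field names illustrative):
```python
from typing import NamedTuple, Tuple

import torch

class TensorMeta(NamedTuple):
    shape: torch.Size
    stride: Tuple[int, ...]
    dtype: torch.dtype

# Everything else (device, layout, requires_grad, ...) is read straight
# off the local tensor when needed:
local_tensor = torch.randn(4, 8)
meta = TensorMeta(local_tensor.shape, local_tensor.stride(), local_tensor.dtype)
```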
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108261
Approved by: https://github.com/fduwjj
ghstack dependencies: #107306
The function schema doesn't give us anything extra, as we can also get the schema from `op._schema`. Including the op directly in op_schema makes it easier for sharding prop to do fake execution, and in principle it should also make hash comparison faster: we no longer need to hash the function schema, we just hash `id(op)`, which is constant.
This PR is just a refactor to include the op in OpSchema instead of the function schema; there are no other logic changes.
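A small illustration of the hashing argument (a sketch, not the PR's code): `OpOverload` objects obtained via `torch.ops` are cached, so their identity is a stable, cheap key, while the full schema stays one attribute away.
```python
import torch

op = torch.ops.aten.add.Tensor          # an OpOverload; torch.ops caches these
assert op is torch.ops.aten.add.Tensor  # same object on every lookup

cheap_key = id(op)   # constant per op, no string hashing involved
schema = op._schema  # the function schema is still reachable when needed
print(schema.name)   # aten::add
```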
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107306
Approved by: https://github.com/fduwjj
My first attempt was to apply the same solution as how proxy_tensor.py
handles other in-place ops. However, foreach is different in that its
schema in `native_functions.yaml` does not return anything, whereas ops
like `addcmul_` and `addcdiv_` do return Tensors (thanks to bdhirsh for
teaching me this!). As a result, the proxy output during tracing does
not wrap anything, so we cannot correctly connect it to subsequent
operators. Modifying `native_functions.yaml` is not a preferred
solution. After discussing with bdhirsh, the temporary solution is to
do foreach functionalization as a graph pass for now. Later, when
https://github.com/pytorch/pytorch/issues/97852
is addressed, we will switch to the default functionalization.
Edit: the latest version follows @bdhirsh's suggestion to use a
`make_fx` `decomposition_table` instead of implementing manual
fx.Graph transforms to functionalize `_foreach_add_`.
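A hedged sketch of the `decomposition_table` approach: register a decomposition that rewrites the in-place foreach op into its functional variant (plus copies back), then hand the table to `make_fx`. The decomposition body below is illustrative, not the exact PR code.
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

aten = torch.ops.aten

def _foreach_add__decomp(self, other, alpha=1):
    # The functional variant returns fresh tensors; copying them back lets
    # the trace record real dataflow instead of a returns-nothing mutation.
    outs = aten._foreach_add.List(self, other, alpha=alpha)
    for s, o in zip(self, outs):
        s.copy_(o)

decomp_table = {aten._foreach_add_.List: _foreach_add__decomp}

def step(params, grads):
    torch._foreach_add_(params, grads, alpha=-0.1)
    return params

params = [torch.randn(3) for _ in range(2)]
grads = [torch.randn(3) for _ in range(2)]
gm = make_fx(step, decomposition_table=decomp_table)(params, grads)
print(gm.graph)  # _foreach_add (functional) + copy_ instead of _foreach_add_
```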
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97853
Approved by: https://github.com/fegin, https://github.com/wanchaol