Enable some sensible flake8-simplify rules. Mainly wanted to enable the SIM101, and `yield from` SIM103 checks. @kit1980 since you wanted to be tagged on this CI check.
Enabling this check also helped flag one logical bug so it's definitely beneficial (also fixed in this PR).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97984
Approved by: https://github.com/ezyang
My first attempt was to apply the same solution as how proxy_tensor.py
handles other inplace ops. However, foreach is different in the way
that it's schema is `native_functions.yaml` does not return anything,
whereas ops like `addcmul_` and `addcdiv_` do return Tensors (Thanks
bdhirsh for teaching me this!). As a result, the proxy output
during tracing does not wrap anything, and hence we cannot correctly
connect it with subsequent operators. Modifying `native_functions.yaml`
is not a preferred solution. After discussing with bdhirsh, the
temporary solution is to do foreach functionalization as a graph
pass for now. Later, when https://github.com/pytorch/pytorch/issues/97852
is addressed, we will switch to default functionalization.
Edit: the latest version follows @bdhirsh 's suggestion on using
`make_fx` `decomposition_table` instead of implementing manual
fx.Graph tranforms to functionalize `_foreach_add_`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97853
Approved by: https://github.com/fegin, https://github.com/wanchaol
Mainly two fixes:
1. `make_fx` seems trace through DeviceMesh operations. This commit removes that from the DTensor expanded graph
2. During DTensor expansion, autograd complains about inplace changes on leaf node. This commit wraps entire DTensor expansion code with `torch.no_grad()`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97787
Approved by: https://github.com/wanchaol
This commit adds an entry point for full `train_step` tracing and
expansion. Model forward, backwrd, and optimizer step will be included
in one graph. DTensor expansion will be applied on top to insert
collective communications. Users can also provide an `Override`
implementation to skip non-traceable submodules and directly install
submodule logic to the DTensor-expanded graph by inserting `fx.Nodes`.
Differential Revision: [D44325177](https://our.internmc.facebook.com/intern/diff/D44325177)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97416
Approved by: https://github.com/yifuwang, https://github.com/wanchaol