_functional_collectives.py: Ensure we always wait all collectives.
derivatives.yaml: mark all_reduce as non differentiable
gen_variable_type.py: Add all_reduce to DONT_ENFORCE_TENSOR_IMPL_USE_COUNT
common_dtensor.py: replace dist.barrier with all_reduce
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95897
Approved by: https://github.com/wconstab, https://github.com/fegin
Inductor implementations of collectives/wait must match
eager impls in _functional_collectives in terms of interacting
with _register_tensor_work API. If they do, then splitting
a collective-wait pair so one half is in a compiled graph should
work fine.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95893
Approved by: https://github.com/kumpera
BC: This changes the signature and semantics of DeviceMesh::all_reduce.
DeviceMesh::all_reduce now uses a functional collective under the hood which makes it more easily traceable.
You no longer need to use CommTensor to get a trace.
all_reduce now is async only and uses AsyncCollectiveTensor to ensure proper stream synchronization.
Signature changed: removed `async_op` param and changes return type from `Optional[Work]` to `torch.Tensor`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95009
Approved by: https://github.com/wanchaol