Currently, when we have 2D composition, a global variable _extensions controls the 2D deviation we need to take in state_dict calls (See https://github.com/pytorch/pytorch/blob/release/2.1/torch/distributed/fsdp/_fsdp_extensions.py#L66-L68). This is problematic when we have both a 2D model and a plain FSDP model in the same dist environment, as the _extensions will be mistakenly turned on for the plain FSDP model, resulting in state_dict error (RuntimeError: No parent device_mesh is found for FSDP device_mesh.).
This PR binds _fsdp_extension to the FSDP instances to make sure that state_dict calls would not get interfered with each other when mixing both 2D and 1D parallelism.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113237
Approved by: https://github.com/fduwjj, https://github.com/fegin
Even after PR #111431, the `collective(...)` function still uses the underlined version `avoidRecordStreams_` inside and does not respect each collective call's preference, as the underlined `avoidRecordStreams_` is only controlled by environment variable.
As a fix, we pass `avoidRecordStreams` into the collective() function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112195
Approved by: https://github.com/awgu