8 Commits

Author SHA1 Message Date
Chien-Chin Huang
580b4702bc [FSDP][optim_state_dict] Consolidate the arguments and logic of optim_state_dict and optim_state_dict_to_load (#96534)
Summary:
The current `optim_state_dict()` does not require users to call `optim.state_dict()` first, while `optim_state_dict_to_load()` requires users to call `optim.load_state_dict()`. This PR makes both APIs give users the option of not having to call the extra API.

This PR also changes the argument order of `optim_state_dict_to_load`, which is a breaking change, so we should do this as soon as possible, before the API is adopted in production use cases.

Test Plan: CI

Differential Revision: D43925068

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96534
Approved by: https://github.com/rohan-varma
2023-03-23 07:56:08 +00:00
Iris
6912cf4053 [DCP] Update DCP to use the updated FSDP optim state_dict APIs (#95303)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95303
Approved by: https://github.com/fegin
2023-02-23 03:55:02 +00:00
Iris
5fa937886c [DCP][nit] Rename variables + minor documentation fix for optimizer.py (#95264)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95264
Approved by: https://github.com/rohan-varma
2023-02-22 19:07:10 +00:00
Iris
92620aface [DCP]Update optimizer.py docstring (#94379)
Update load_sharded_optimizer_state_dict() docstring.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94379
Approved by: https://github.com/fduwjj
2023-02-09 20:24:28 +00:00
Iris
56db21aec1 [Checkpoint][Test] Add test for optimizer state_dict and resharding to 2d checkpoint test (#91092)
This PR updates the 2d checkpoint model state test to include:
1. optimizer state dict test
2. simple resharding test (pg change)
3. rename test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91092
Approved by: https://github.com/fduwjj
2023-01-04 23:26:30 +00:00
joncrall
ad782ff7df Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. Just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
2022-12-29 05:32:42 +00:00
Iris
bfa223aaa6 [Checkpoint] Fix checkpoint test test_fsdp_optim_state.py (#91036)
This PR:
1. Fixes test/distributed/fsdp/test_fsdp_optim_state.py to match the change in the FSDP.flatten_sharded_optim_state_dict() API.
2. Updates the docstring accordingly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91036
Approved by: https://github.com/fegin
2022-12-17 03:02:31 +00:00
Iris
b8b7480065 [Checkpoint][2D][6/N] Add optimizer and update default_planner to core distributed (#90212)
This is the last PR for integrating 2D into core distributed.

This PR does the following:
1. Add optimizer.py: this adds the ability to load a state_dict in conjunction with FSDP sharded optimizer state.
2. Update default_planner.py to support 2D checkpoint.
3. Add test_fsdp_optim_state.py as a unit test for No. 1.
4. Fix a bug in torch/testing/_internal/distributed/checkpoint_utils.py
5. Rename the files for the APIs that should be private. Will organize and clean up further in following PRs. #90328

Docstrings and integration tests will be added in the following PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90212
Approved by: https://github.com/wanchaol
2022-12-08 02:53:29 +00:00