4 Commits

Author SHA1 Message Date
Howard Huang
b3ef0c99f5 [PP] Fix zero bubble composability with DP (#134052)
Moved all the backward functions (`stage_backward_input`, `stage_backward_weight`, `stage_backward`) under the same `backward_maybe_with_nosync` function which controls the logic of the data parallel wrappers.

FSDP did not work with zero bubble PP because there are twice as many "backward" calls, and the weight gradients are updated after `autograd.grad` is called. As a result, the FSDP `post_backward_hook()` must be called manually once the weights have the correct gradients.
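
The control flow described above can be sketched roughly as follows. This is a pure-Python toy, not the PyTorch code: `ToyDataParallel`, `post_backward`, and `backward_with_optional_sync` are hypothetical names that only mimic the idea of suppressing gradient synchronization on every backward call except the last one, and of triggering the post-backward step manually after the weight gradients are complete.

```python
from contextlib import contextmanager


class ToyDataParallel:
    """Hypothetical stand-in for a data-parallel wrapper (not DDP or FSDP)."""

    def __init__(self):
        self.sync_enabled = True
        self.sync_count = 0

    @contextmanager
    def no_sync(self):
        # Temporarily disable gradient synchronization.
        previous = self.sync_enabled
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = previous

    def post_backward(self):
        # Stands in for reducing gradients across data-parallel ranks.
        if self.sync_enabled:
            self.sync_count += 1


def backward_with_optional_sync(dp, backward_fn, last_backward):
    """Run backward_fn; allow gradient sync only on the last backward call."""
    if last_backward:
        backward_fn()
        dp.post_backward()  # called manually once the gradients are final
    else:
        with dp.no_sync():
            backward_fn()
            dp.post_backward()  # no-op while sync is disabled


dp = ToyDataParallel()
calls = []
num_microbatches = 2
for mb in range(num_microbatches):
    is_last = mb == num_microbatches - 1
    # Split backward doubles the number of calls: one for inputs, one for weights.
    # The input step never finalizes weight gradients, so it never syncs.
    backward_with_optional_sync(dp, lambda: calls.append(("input", mb)), False)
    backward_with_optional_sync(dp, lambda: calls.append(("weight", mb)), is_last)
```

In this toy run there are four backward calls for two microbatches, but only the final weight step triggers a synchronization.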

Fixes the tests:
`python test/distributed/_composable/test_composability/test_pp_composability.py ComposabilityTest.test_manual_with_data_parallel_dp_type_FSDP_ScheduleClass0_use_new_runtime_False`

`python test/distributed/_composable/test_composability/test_pp_composability.py ComposabilityTest.test_manual_with_data_parallel_dp_type_DDP_ScheduleClass0_use_new_runtime_False`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134052
Approved by: https://github.com/kwen2501
2024-09-04 23:46:29 +00:00
Howard Huang
4b1fb3b0ed [PP] pt-native input/weight grad split (#132691)
Add `stage_backward_input` and `stage_backward_weight` functions to compute the gradients for inputs and weights independently.

We still support the `self.dw_builder` argument for a custom backward, but it is now optional. It takes a separate code path and cannot be used in conjunction with the native zero backward.
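
To illustrate what splitting the backward pass into an input step and a weight step means, here is a minimal pure-Python sketch for a single linear layer `y = W x`. It is not the implementation of `stage_backward_input` or `stage_backward_weight`; the helper names `backward_input` and `backward_weight` and the sample values are assumptions made for this example.

```python
def backward_input(W, grad_y):
    """Input gradient for y = W x: dL/dx = W^T @ grad_y."""
    rows, cols = len(W), len(W[0])
    return [sum(W[i][j] * grad_y[i] for i in range(rows)) for j in range(cols)]


def backward_weight(x, grad_y):
    """Weight gradient for y = W x: dL/dW = outer(grad_y, x)."""
    return [[g * xj for xj in x] for g in grad_y]


W = [[1.0, 2.0], [3.0, 4.0]]
x = [5.0, 6.0]
grad_y = [1.0, 1.0]  # gradient arriving from the next stage

# The input gradient is what the previous pipeline stage is waiting for,
# so it is computed first.
grad_x = backward_input(W, grad_y)

# The weight gradient depends only on the saved input and grad_y, so it can
# be deferred and scheduled later to fill pipeline bubbles.
grad_W = backward_weight(x, grad_y)
```

The two steps share `grad_y` but do not depend on each other's result, which is why a schedule is free to run them at different times.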

Added tests:
`python test/distributed/pipelining/test_schedule_multiproc.py -k test_schedule_with_native_zero_bubble`
`python test/distributed/pipelining/test_backward.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132691
Approved by: https://github.com/wconstab
2024-08-21 13:37:54 +00:00
Aaron Orenstein
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
Ke Wen
52142192d4 [pipelining] Add stage backward function (#124958)
This is a helper function which:
1. computes the gradients for the stage inputs, and
2. accumulates gradients for the stage module's parameters.

A unit test for this function is also added.
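
The two responsibilities listed above can be sketched with a toy one-parameter stage. This is only an illustration in plain Python, not the helper added in this commit; `ToyStage` and its attributes are hypothetical.

```python
class ToyStage:
    """Hypothetical one-parameter stage computing y = w * x."""

    def __init__(self, w):
        self.w = w
        self.w_grad = 0.0  # accumulated across microbatches
        self.saved_x = None

    def forward(self, x):
        self.saved_x = x  # saved for use in backward
        return self.w * x

    def backward(self, grad_out):
        # 1. Compute the gradient for the stage input
        #    (this is what would be sent to the previous stage).
        grad_x = self.w * grad_out
        # 2. Accumulate the gradient for the stage's parameter.
        self.w_grad += self.saved_x * grad_out
        return grad_x


stage = ToyStage(w=3.0)
input_grads = []
for x in (1.0, 2.0):  # two microbatches
    stage.forward(x)
    input_grads.append(stage.backward(grad_out=1.0))
```

Each microbatch yields its own input gradient, while the parameter gradient is summed over microbatches.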

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124958
Approved by: https://github.com/wconstab
ghstack dependencies: #124776, #124875
2024-05-01 07:56:58 +00:00