pytorch/test/distributed/pipelining
Howard Huang b3ef0c99f5 [PP] Fix zero bubble composability with DP (#134052)
Moved all the backward functions (`stage_backward_input`, `stage_backward_weight`, `stage_backward`) under a single `backward_maybe_with_nosync` function, which controls the logic of the data parallel wrappers.

FSDP was not working with zero bubble PP because zero bubble makes twice as many "backward" calls, and the weight gradients are only updated after `autograd.grad` is called. As a result, we need to manually call the FSDP `post_backward_hook()` once the weights have the correct gradients.
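The ordering problem can be illustrated with a minimal pure-Python sketch. It does not use PyTorch or FSDP: `ToyLinear`, `ToyDataParallel`, `run_split_backward`, and the `hook_early` flag are hypothetical names invented for this illustration, and only the name `post_backward_hook` is borrowed from the description above. The sketch assumes a scalar layer `y = w * x` whose backward pass is split into an input-gradient step and a deferred weight-gradient step, roughly as zero bubble does.

```python
class ToyLinear:
    """Hypothetical scalar layer y = w * x with a manually split backward."""

    def __init__(self, w):
        self.w = w
        self.w_grad = None  # stays None until the weight-gradient step runs

    def forward(self, x):
        self._x = x  # saved for the deferred weight-gradient step
        return self.w * x

    def backward_input(self, grad_out):
        # First "backward" call: dL/dx = w * dL/dy.
        self._grad_out = grad_out
        return self.w * grad_out

    def backward_weight(self):
        # Second "backward" call: dL/dw = x * dL/dy.
        self.w_grad = self._x * self._grad_out


class ToyDataParallel:
    """Hypothetical stand-in for a DP wrapper; records what its hook sees."""

    def __init__(self, layer):
        self.layer = layer
        self.seen = []

    def post_backward_hook(self):
        # A real wrapper would reduce gradients here; this one only records them.
        self.seen.append(self.layer.w_grad)


def run_split_backward(dp, x, grad_out, hook_early=False):
    layer = dp.layer
    layer.forward(x)
    dx = layer.backward_input(grad_out)
    if hook_early:
        dp.post_backward_hook()  # fires before the weight gradient exists
    layer.backward_weight()
    if not hook_early:
        dp.post_backward_hook()  # fired manually once the gradient is final
    return dx


early = ToyDataParallel(ToyLinear(w=3.0))
run_split_backward(early, x=2.0, grad_out=1.0, hook_early=True)

late = ToyDataParallel(ToyLinear(w=3.0))
dx = run_split_backward(late, x=2.0, grad_out=1.0)
```

In this toy setup, the hook that fires after the first backward call records `None`, while the hook fired manually after the weight step records the final gradient. This mirrors, in simplified form, why the hook has to be deferred.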

Fixes the tests:
`python test/distributed/_composable/test_composability/test_pp_composability.py ComposabilityTest.test_manual_with_data_parallel_dp_type_FSDP_ScheduleClass0_use_new_runtime_False`

`python test/distributed/_composable/test_composability/test_pp_composability.py ComposabilityTest.test_manual_with_data_parallel_dp_type_DDP_ScheduleClass0_use_new_runtime_False`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134052
Approved by: https://github.com/kwen2501
2024-09-04 23:46:29 +00:00
__init__.py [pipelining] Consolidate test models into a registry (#126114) 2024-05-14 19:11:54 +00:00
model_registry.py Make adding Buffers more like adding Parameters (#125971) 2024-07-31 10:32:40 +00:00
schedule_registry.py [PP] pt-native input/weight grad split (#132691) 2024-08-21 13:37:54 +00:00
test_backward.py [PP] Fix zero bubble composability with DP (#134052) 2024-09-04 23:46:29 +00:00
test_microbatch.py [pipelining] pipeline() taking microbatch as example input (#128163) 2024-06-07 15:51:53 +00:00
test_pipe.py Add None return type to init -- tests rest (#132376) 2024-08-01 15:44:51 +00:00
test_schedule_multiproc.py [PP] Add ZeroBubble schedule (#133467) 2024-08-22 13:32:15 +00:00
test_schedule.py [PP] Add get_schedule_class util (#132768) 2024-08-07 23:51:03 +00:00
test_stage.py [PP] pt-native input/weight grad split (#132691) 2024-08-21 13:37:54 +00:00
test_transformer.py [pipelining] pipeline() taking microbatch as example input (#128163) 2024-06-07 15:51:53 +00:00
test_unflatten.py [pipelining] pipeline() taking microbatch as example input (#128163) 2024-06-07 15:51:53 +00:00