mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PP] Refactor test_schedule_multiproc (#158780)
This refactors the pipelining schedule tests since a lot of them have the same repeated code of: 1. Create pipelined model and reference model 2. Run reference model and pipelined model 3. compare gradients So this refactors those parts above into helper methods and reduces ~300 LOC. Also adds a better gradient check to resolve flakiness (fixes https://github.com/pytorch/pytorch/issues/154408). Pull Request resolved: https://github.com/pytorch/pytorch/pull/158780 Approved by: https://github.com/wconstab
This commit is contained in:
parent
3967dbedf4
commit
b0b3e6e48b