Commit Graph

2 Commits

Author SHA1 Message Date
albanD
c5e04a4479 More accurate is_bw and prompt parents cleanup for ModuleTracker utils (#125634)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125634
Approved by: https://github.com/soulitzer, https://github.com/Chillee
2024-05-07 20:57:36 +00:00
albanD
76a26a885d Add module tracker (#125352)
This does a few things that were originally a few PRs but I am on a new machine and don't have ghstack.
If it is too problematic to review, I can re-split, just let me know.
This does:
- Cleanup context manager use in test_flop_counter
- Remove need for mod argument in FlopCounterMode, warning about it
- Re-implement a Module tracker from scratch using global forward Module use and multi_grad_hook (we cannot use global backward Module hook because they don't look for nested Tensor and they're custom Function based instead of multi_grad_hook).
- Update FlopCouterMode to use the new ModuleTracker. All the existing test suite passes as-is (only changes there are new tests and refactoring mentioned above)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125352
Approved by: https://github.com/mikaylagawarecki
2024-05-04 18:33:35 +00:00