Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
This PR adds support for non-functional collectives under `FakeTensorMode` and `fake_pg`, which eliminates the need to patch collectives for memory and runtime estimation. It also modifies `ModTracker` so that the post-backward hook fires for modules whose inputs do not require gradients but whose parameters do. For memory tracking, the DTensor dispatcher is now enabled for custom dispatch functions such as `entropy_loss`; the dispatcher is enabled only for the memory-tracking phase and disabled as soon as it completes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147566

Approved by: https://github.com/weifengpy
| Test file |
|---|
| test_fake_collectives.py |
| test_fsdp2_mem_tracker.py |
| test_mem_tracker.py |
| test_memory_tracker.py |
| test_mod_tracker.py |
| test_runtime_estimator.py |
| test_sac_estimator.py |
| test_sac_ilp.py |