Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
This PR adds support for non-functional collectives under `FakeTensorMode` and `fake_pg`, which eliminates the need to patch collectives for memory and runtime estimation. It also modifies `ModTracker` so that the post-backward hook fires for modules whose inputs don't require gradients but whose parameters do. For memory tracking, the DTensor dispatcher is now enabled for custom dispatch functions such as `entropy_loss`; the dispatcher is enabled only during memory tracking and is disabled as soon as tracking finishes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147566
Approved by: https://github.com/weifengpy
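As a rough illustration of what this enables, the following is a minimal sketch (not taken from the PR's own tests) of running an in-place, non-functional collective under `FakeTensorMode` with a fake process group. It assumes the `fake` backend and `FakeStore` provided by `torch.testing._internal.distributed.fake_pg`; the world size and tensor shape are arbitrary.

```python
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.distributed.fake_pg import FakeStore

# Initialize a fake process group: collectives are dispatched normally,
# but no real communication or memory traffic takes place.
dist.init_process_group(backend="fake", store=FakeStore(), rank=0, world_size=8)

with FakeTensorMode():
    t = torch.randn(1024, 1024)  # fake tensor, carries only metadata
    # A non-functional (in-place) collective now runs under FakeTensorMode,
    # so memory/runtime estimators no longer need to patch it out.
    dist.all_reduce(t)

dist.destroy_process_group()
```

The same pattern lets estimators such as the memory tracker and runtime estimator observe collective calls through the regular dispatcher instead of monkey-patching them.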
| File |
|---|
| __init__.py |
| common_utils.py |
| fake_collectives.py |
| fsdp2_mem_tracker.py |
| ilp_utils.py |
| mem_tracker.py |
| memory_tracker.py |
| mod_tracker.py |
| runtime_estimator.py |
| sac_estimator.py |
| sac_ilp.py |