Commit Graph

8 Commits

Author SHA1 Message Date
Xuehai Pan
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
Aaron Orenstein
00ffeca1b1 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-21 04:23:29 +00:00
PyTorch MergeBot
6374332d33 Revert "PEP585 update - torch/distributed (#145164)"
This reverts commit 6cb186e279.

Reverted https://github.com/pytorch/pytorch/pull/145164 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing an inductor test ([comment](https://github.com/pytorch/pytorch/pull/145164#issuecomment-2602875679))
2025-01-20 16:46:46 +00:00
Aaron Orenstein
6cb186e279 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-20 00:19:01 +00:00
Sanket Jayant Purandare
6508f0f5d4 Improved backward tracking and attribution, fixed typing for python < 3.10 (#129400)
For #125323
* Fixes typing for python < 3.10
* Fixes #129390

For #124688
* Improved attribution by registering `register_hook` and `post_accumulate_grad_hook` on params.
* Fixed pre-mature per module bw peak state initialization for AC.
* This improves per-module stats, global `peak_mem` was already accurate and remains unaffected.

For #128508
* When AC is applied to a `mod (nn.Module)` the backward order of execution is `pre-bw -> pre-fw -> post-fw -> post-bw`. Since the `ModTracker` maintains the `parents` attribute as set, the `post-fw` during backward was prematurely removing it from parents.
* With the fix we now maintain a per-module counter and only remove a module from `parents` when its counter goes to 0.
* Added tests to ensure this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129400
Approved by: https://github.com/awgu, https://github.com/huydhn
2024-06-25 10:54:58 +00:00
Sanket Jayant Purandare
2e5366fbc0 Extended Module Tracker (#128508)
This is an extension of [ModuleTracker](https://github.com/pytorch/pytorch/blob/main/torch/utils/module_tracker.py) with added features and bug fixes.

1. Allows installing user-defined hooks to be called in pre-fw, post-fw, pre-bw and post-bw hooks of the ``ModTracker``.
2. Adds a function ``get_known_fqn`` that retrieves the fqn of the module as tracked by the ``ModTracker``.
3. Only registers the multi-grad hooks if we are in the forward pass. This is important because, a module's pre-fw and post-fw hooks get called in the backward during AC and we do not want to register multi-grad hooks in this case.
4. Sets the kwarg ``always_call=True`` for post-fw hooks, so that they are called post AC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128508
Approved by: https://github.com/wanchaol
2024-06-14 19:48:46 +00:00
PyTorch MergeBot
f75f5987aa Revert "Extended Module Tracker (#128508)"
This reverts commit 1f46284f9e.

Reverted https://github.com/pytorch/pytorch/pull/128508 on behalf of https://github.com/malfet due to Broke lint, see https://github.com/pytorch/pytorch/actions/runs/9515753429/job/26230639980 ([comment](https://github.com/pytorch/pytorch/pull/128508#issuecomment-2168405784))
2024-06-14 16:46:03 +00:00
Sanket Jayant Purandare
1f46284f9e Extended Module Tracker (#128508)
This is an extension of [ModuleTracker](https://github.com/pytorch/pytorch/blob/main/torch/utils/module_tracker.py) with added features and bug fixes.

1. Allows installing user-defined hooks to be called in pre-fw, post-fw, pre-bw and post-bw hooks of the ``ModTracker``.
2. Adds a function ``get_known_fqn`` that retrieves the fqn of the module as tracked by the ``ModTracker``.
3. Only registers the multi-grad hooks if we are in the forward pass. This is important because, a module's pre-fw and post-fw hooks get called in the backward during AC and we do not want to register multi-grad hooks in this case.
4. Sets the kwarg ``always_call=True`` for post-fw hooks, so that they are called post AC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128508
Approved by: https://github.com/wanchaol
2024-06-14 12:01:53 +00:00