Commit Graph

17 Commits

Author SHA1 Message Date
c8ef
a989a0b13a [NFC] Fix some minor typos. (#145599)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145599
Approved by: https://github.com/Skylion007
2025-01-24 18:58:59 +00:00
Aaron Orenstein
00ffeca1b1 PEP585 update - torch/distributed (#145164)
See #145101 for details.
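
For context, PEP 585 lets the builtin collection types be used directly as generics, replacing the `typing` aliases. A minimal before/after sketch (the function name is illustrative, not from the PR):

```
# Before (PEP 484 style): collection generics imported from typing
from typing import Dict, List, Tuple

def shard_sizes(ranks: List[int]) -> Dict[int, Tuple[int, int]]: ...

# After (PEP 585 style): the builtin types are used as generics directly
def shard_sizes(ranks: list[int]) -> dict[int, tuple[int, int]]: ...
```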

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-21 04:23:29 +00:00
PyTorch MergeBot
6374332d33 Revert "PEP585 update - torch/distributed (#145164)"
This reverts commit 6cb186e279.

Reverted https://github.com/pytorch/pytorch/pull/145164 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing an inductor test ([comment](https://github.com/pytorch/pytorch/pull/145164#issuecomment-2602875679))
2025-01-20 16:46:46 +00:00
Aaron Orenstein
6cb186e279 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-20 00:19:01 +00:00
bobrenjc93
08be9ec312 Migrate from Tuple -> tuple in torch/distributed (#144258)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144258
Approved by: https://github.com/aorenste
2025-01-10 08:34:54 +00:00
Howard Huang
2ac71a5771 [pipelining] add type checking to _backward functions (#140019)
Fixes https://github.com/pytorch/pytorch/issues/139405

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140019
Approved by: https://github.com/wconstab
2024-11-12 21:42:08 +00:00
Howard Huang
2c82f73647 [Pipelining] Clean up hooks in zero bubble (#138720)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138720
Approved by: https://github.com/wconstab
ghstack dependencies: #138119, #138504, #138735
2024-10-25 12:06:54 +00:00
Howard Huang
12755f45ff [Pipelining] small comments and variable renames (#138735)
Addresses the comments from previous PRs: updates the variable names and adds additional code comments

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138735
Approved by: https://github.com/wconstab
ghstack dependencies: #138119, #138504
2024-10-25 12:06:54 +00:00
Howard Huang
8945309c08 [Pipelining] fix extra memory usage in zero bubble (#138119)
Full debugging details here: https://docs.google.com/document/d/1Pe_E0KWAfsJ6MCvKZ5aR28rTXX-rYLg13XxwXd6AALw/edit?usp=sharing

In zero bubble, we have two methods, `stage_backward_input` and `stage_backward_weight`. During `stage_backward_input` we compute the gradients of the inputs with respect to the stage outputs and also retain the autograd graph (unlike 1F1B, where `retain_graph=False`). The output/loss was still being retained across the next schedule `step()` because we return the loss to the user and use the output in the next step. To allow autograd to free the variables in the graph, we need to detach the output/loss once we no longer need it for autograd.
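
A minimal sketch of the idea behind the fix (names and signature are illustrative assumptions, not the actual `_backward.py` code):

```
import torch

def free_graph_after_last_use(stage_outputs: list[torch.Tensor]) -> None:
    # Once the output/loss is no longer needed for autograd (i.e. after the
    # final backward that touches it), detach it in place. This drops the
    # tensor's reference to its grad_fn, so the retained graph and the
    # activations it saved can be freed, while the caller still holds the
    # values (e.g. the loss returned from step()).
    for t in stage_outputs:
        if isinstance(t, torch.Tensor) and t.grad_fn is not None:
            t.detach_()
```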

Pre-fix:
<img width="1021" alt="image" src="https://github.com/user-attachments/assets/6c8bf469-32b1-4dac-85ff-b97991f9f0e3">

Post-fix:
<img width="1039" alt="image" src="https://github.com/user-attachments/assets/a1875038-e80b-4dd4-84f2-38727d7792dc">

without AC (7B model on titan):
10% memory improvement

with AC (7B model on titan):
50% memory improvement

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138119
Approved by: https://github.com/wconstab, https://github.com/kwen2501
2024-10-24 00:44:03 +00:00
Tom Ritchford
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
Howard Huang
ad6c70b656 [PP] Remove modifications to autograd nodes in ZB (#136678)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136678
Approved by: https://github.com/wconstab, https://github.com/kwen2501
ghstack dependencies: #136507, #136584
2024-09-27 07:07:58 +00:00
Howard Huang
141cae2eb8 [pipelining] Fix more leaks and check leaks in tests (#136584)
Fix two more leaks of the same variety as #136507 (see that PR desc and attached gdoc for debug details).

This time, also add a test-time check that helped to discover new leaks and ensures we won't accidentally regress.

Adds `check_tensor_leak` util which internally asserts no tensors are being kept alive by other objects involved in py ref cycles.

Uses objgraph for a nice debug utility when a leak is found.
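
A hedged sketch of what such a checker can boil down to (the body below is an assumption based on the description; the real util lives in `torch/testing/_internal/common_utils.py`):

```
import gc
import torch

def check_tensor_leak():
    # DEBUG_SAVEALL makes gc.collect() park every collected (i.e. cyclic)
    # object in gc.garbage instead of freeing it, so it can be inspected.
    gc.set_debug(gc.DEBUG_SAVEALL)
    try:
        gc.collect()
        leaked = [o for o in gc.garbage if isinstance(o, torch.Tensor)]
        assert not leaked, (
            f"{len(leaked)} tensors were found in the garbage. "
            "Did you introduce a reference cycle?"
        )
    finally:
        gc.set_debug(0)
        gc.garbage.clear()
```

When the assert fires, objgraph (e.g. `objgraph.show_backrefs` on a leaked tensor) can render the chain of referrers keeping it alive, as in the sample output below.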

Credit to @H-Huang for pointing out objgraph and helping debug the `param_group["intermediates"]` leak.

I manually confirmed that all 3 of the leaks identified/fixed so far are caught by the unit test and checker.

Sample output, if I re-introduce a leak by commenting out `del param_group["intermediates"]` in _backward.py,
and run `python test/distributed/pipelining/test_schedule_multiproc.py -k test_schedule_with_native_zero_bubble`:

```
warnings.warn(
/data/users/whc/pytorch/torch/testing/_internal/common_utils.py:5341: UserWarning: 34 tensors were found in the garbage. Did you introduce a reference cycle?
warnings.warn(
/data/users/whc/pytorch/torch/testing/_internal/common_utils.py:5347: UserWarning: Dumping first 1 objgraphs of leaked tensors rendered to png
Graph written to /tmp/objgraph-ztz642h3.dot (19 nodes)
Graph viewer (xdot) not found, generating a png instead
Image generated as /tmp/objgraph-ztz642h3.png
```

Rendering of `/tmp/objgraph-ztz642h3.png`:
<img width="1671" alt="image" src="https://github.com/user-attachments/assets/9098ff29-224c-4533-935b-83c210ac2e22">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136584
Approved by: https://github.com/kwen2501, https://github.com/H-Huang
ghstack dependencies: #136507

Co-authored-by: Howard Huang <howardhuang@fb.com>
2024-09-26 01:10:40 +00:00
Will Constable
dbc3356655 [pipelining] fix py ref cycle in stage_backward (#136507)
TL;DR: found that forward activation tensors were being kept alive "forever" (or until GC ran), and tracked it down to a cycle involving `stage_backward.<locals>.extract_tensors_with_grads`.

The reference cycle in question is below (constructed using `gc.get_referrers` after a `gc.collect` in gc debug mode):

tensor is kept alive by
`[(<class 'cell'>, '0x7f7360234400')]`

tuple of cell objects
`(<cell at 0x7f73602343d0: function object at 0x7f734fff0ee0>, <cell at 0x7f7360234400: list object at 0x7f734e4d9a80>, <cell at 0x7f73602a4190: list object at 0x7f734eff8b00>)`
is kept alive by
`[(<class 'function'>, '0x7f734fff0ee0')]`

`<function stage_backward.<locals>.extract_tensors_with_grads at 0x7f734fff0ee0>`
is kept alive by
`[(<class 'cell'>, '0x7f73602343d0')]`

Put in plainer terms:

```

def stage_backward(...):
    ...
    stage_output_tensors = []

    # a cell object will exist that contains the variables defined in stage_backward and used by
    # both stage_backward and nested functions
    # in this case, the cell object contains 'stage_output_tensors' but
    # (as noted below) also `extract_tensors_with_grads` itself

    # this function object will hold a reference to a 'cell' that contains any vars from
    # the parent scope not explicitly passed into the function as args.
    def extract_tensors_with_grads(...):
        ...
            # extract_tensors_with_grads refers to stage_output_tensors, so stage_output_tensors
            # is in the cell
            stage_output_tensors.append(output_val)
        ...
            # but extract_tensors_with_grads ALSO refers to itself (extract_tensors_with_grads),
            # so `extract_tensors_with_grads` will be in the cell
            extract_tensors_with_grads(...)
```
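
One way to break this kind of cycle (a sketch; not necessarily the exact change made in this PR) is to drop the self-referential nested function and walk the output structure with an explicit stack:

```
import torch

def extract_tensors_with_grads(output_vals, grad_vals):
    # No nested function, no closure cell, no cycle: the traversal state
    # lives in an explicit stack instead of a recursive self-reference.
    stage_output_tensors, output_grad_tensors = [], []
    stack = [(output_vals, grad_vals)]
    while stack:
        out, grad = stack.pop()
        if isinstance(out, torch.Tensor):
            if out.requires_grad and grad is not None:
                stage_output_tensors.append(out)
                output_grad_tensors.append(grad)
        elif isinstance(out, (list, tuple)):
            stack.extend(zip(out, grad))
        elif isinstance(out, dict):
            stack.extend((out[k], grad[k]) for k in out)
    return stage_output_tensors, output_grad_tensors
```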

More debug details:
https://docs.google.com/document/d/1QPH1Lz0tnieIFPM2tyHrjVB-bjlnHuDgjx1p2am3cmE/edit?usp=sharing

In pdb:
```
gc.collect()
g = gc.garbage
g[-1]
[rank0]:(Pdb) [rank0]:<function stage_backward.<locals>.extract_tensors_with_grads at 0x7fee5c3392d0>
g[-2]
[rank0]:(Pdb) [rank0]:(<cell at 0x7fee7abbcf40: function object at 0x7fee5c3392d0>,
 <cell at 0x7fee7abbcf70: list object at 0x7fee7ab68940>,
 <cell at 0x7fee5c3210c0: list object at 0x7fee5e1d6340>)
g[-3]
[rank0]:(Pdb) [rank0]:[tensor([[[-4.1127e-06, -3.3826e-06,  2.6226e-06,
...,  6.4969e-06,
[rank0]:          -4.4405e-06, -4.7684e-06],
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136507
Approved by: https://github.com/awgu, https://github.com/kwen2501
2024-09-24 20:46:37 +00:00
Howard Huang
b3ef0c99f5 [PP] Fix zero bubble composability with DP (#134052)
Moved all the backward functions (`stage_backward_input`, `stage_backward_weight`, `stage_backward`) under the same `backward_maybe_with_nosync` function, which controls the logic of the data parallel wrappers.

FSDP was not working with zero bubble PP because there are twice as many "backward" calls and the weight gradients are updated after `autograd.grad` is called. As a result, we need to manually call the FSDP `post_backward_hook()` after the weights have the correct gradients.
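
An illustrative sketch of that control flow (parameter names and the hook plumbing are assumptions based on the description, not the exact PyTorch internals):

```
from typing import Callable, Mapping, Optional

def backward_maybe_with_nosync(
    backward_fn: Callable[..., tuple],   # one of the stage_backward* variants
    bwd_kwargs: Mapping,
    last_backward: bool,
    fsdp_post_backward: Optional[Callable[[], None]] = None,  # hypothetical hook runner
) -> tuple:
    # Zero bubble issues twice as many "backward" calls as a DP wrapper
    # expects (input grads, then weight grads), and the weight grads are
    # filled in via autograd.grad rather than .backward(), which bypasses
    # FSDP's automatically registered post-backward hook.
    grads = backward_fn(**bwd_kwargs)

    # So, once the last backward has produced correct weight gradients,
    # fire the FSDP post-backward (reduce-scatter) by hand.
    if last_backward and fsdp_post_backward is not None:
        fsdp_post_backward()
    return grads
```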

Fixes the tests:
`python test/distributed/_composable/test_composability/test_pp_composability.py ComposabilityTest.test_manual_with_data_parallel_dp_type_FSDP_ScheduleClass0_use_new_runtime_False`

`python test/distributed/_composable/test_composability/test_pp_composability.py ComposabilityTest.test_manual_with_data_parallel_dp_type_DDP_ScheduleClass0_use_new_runtime_False`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134052
Approved by: https://github.com/kwen2501
2024-09-04 23:46:29 +00:00
Howard Huang
4b1fb3b0ed [PP] pt-native input/weight grad split (#132691)
Add `stage_backward_input` and `stage_backward_weight` functions to compute the input gradients and weight gradients independently.

We still support the `self.dw_builder` argument for a custom backward, but it is now optional. It takes a separate code path and cannot be used in conjunction with the native zero bubble backward.
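
A minimal sketch of how the two halves can be computed with `torch.autograd.grad` (signatures are illustrative assumptions; the real functions differ):

```
import torch

def stage_backward_input_sketch(outputs, output_grads, inputs):
    # dL/dx first, keeping the graph alive so the weight pass can reuse it;
    # these grads unblock the previous pipeline stage immediately.
    return torch.autograd.grad(
        outputs, inputs, grad_outputs=output_grads, retain_graph=True
    )

def stage_backward_weight_sketch(outputs, output_grads, weights):
    # dL/dW later (in the bubble); the last use of the graph frees it.
    wgrads = torch.autograd.grad(
        outputs, weights, grad_outputs=output_grads, retain_graph=False
    )
    for w, g in zip(weights, wgrads):
        w.grad = g if w.grad is None else w.grad + g  # manual accumulation
```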

Added tests:
`python test/distributed/pipelining/test_schedule_multiproc.py -k test_schedule_with_native_zero_bubble`
`python test/distributed/pipelining/test_backward.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132691
Approved by: https://github.com/wconstab
2024-08-21 13:37:54 +00:00
Aaron Orenstein
7c12cc7ce4 Flip default value for mypy disallow_untyped_defs [6/11] (#127843)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127843
Approved by: https://github.com/oulgen
ghstack dependencies: #127842
2024-06-08 18:49:29 +00:00
Ke Wen
52142192d4 [pipelining] Add stage backward function (#124958)
This is a helper function which:
1. computes the gradients for the stage inputs, and
2. accumulates gradients for the stage module's parameters.

A unit test for this function is also added.
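
A hedged sketch of such a helper (the actual signature in `_backward.py` may differ):

```
import torch

def stage_backward_sketch(stage_outputs, output_grads, stage_inputs):
    # One backward call does both jobs: autograd accumulates .grad on the
    # stage module's parameters, and the received stage inputs (leaves from
    # this stage's point of view) pick up .grad as well.
    torch.autograd.backward(stage_outputs, grad_tensors=output_grads)
    # Return the input grads so they can be sent to the previous stage as
    # its output grads.
    return tuple(
        x.grad if isinstance(x, torch.Tensor) else None for x in stage_inputs
    )
```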

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124958
Approved by: https://github.com/wconstab
ghstack dependencies: #124776, #124875
2024-05-01 07:56:58 +00:00