CUDA Graph Trees
Design doc: https://docs.google.com/document/d/1ZrxLGWz7T45MSX6gPsL6Ln4t0eZCSfWewtJ_qLd_D0E/edit
Not currently implemented :
- Right now, we are using weak tensor refs from outputs to check if a tensor has dies. This doesn't work because a) aliasing, and b) aot_autograd detaches tensors (see note [Detaching saved tensors in AOTAutograd]). Would need either https://github.com/pytorch/pytorch/issues/91395 to land to use storage weak refs or manually add a deleter fn that does what I want. This is doable but theres some interactions with the caching allocator checkpointing so saving for a stacked pr.
- Reclaiming memory from the inputs during model recording. This isn't terribly difficult but deferring to another PR. You would need to write over the input memory during warmup, and therefore copy the inputs to cpu. Saving for a stacked pr.
- Warning on overwriting previous generation outputs. and handling nested torch.compile() calls in generation tracking
Differential Revision: [D43999887](https://our.internmc.facebook.com/intern/diff/D43999887)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89146
Approved by: https://github.com/ezyang
In the prior patch, I just YOLOed a mutable mapping implementation.
Many edge cases were not handled correctly. In this PR, I just
copy paste the WeakKeyDictionary from CPython and the hacked it up
to use WeakIdRef instead of weakref.ref. You can see each line
I changed with the comment CHANGED; there aren't many.
Being exactly API compatible with WeakKeyDictionary means I can also
rob all of the tests from CPython, which I also did for
test/test_weak.py
How to review? You could either try taking the delta from CPython
(recommended), or review everything from scratch (not recommended).
Can post diff representing delta on request.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90825
Approved by: https://github.com/albanD