## Before
Previously, CA would always unpack all saved variables stored in the autograd graph before executing it. This meant that we couldn't capture unpack hooks as part of the CA graph, and they would fire out of order with respect to other backward hooks. For memory-saving APIs built on top of saved tensor hooks, like non-reentrant checkpointing and offloading, we couldn't achieve any savings because all activations would be recomputed/loaded and live at the same time, making the hooks effectively a no-op.
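For context, here is a minimal sketch of the kind of saved tensor hooks in question: a pack/unpack pair registered with `torch.autograd.graph.saved_tensors_hooks` that offloads activations to CPU (assuming a CUDA device is available; the hook names are illustrative). The unpack hook is the piece that used to fire eagerly for every saved variable before the CA graph ran.

```python
import torch

def pack_to_cpu(t):
    # runs during forward, when an activation is saved for backward
    return t.to("cpu")

def unpack_to_cuda(t):
    # runs during backward, when the saved activation is actually needed
    return t.to("cuda")

x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_to_cuda):
    loss = (x @ x).relu().sum()
loss.backward()  # unpack hooks fire here, one per saved activation
```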
## After
We add unpack hooks into the CA graph so that they can be executed progressively. The Python hook and its input are themselves wrapped by non-traceable code, so CA polyfills the wrapping as:
```python
# pseudocode
class SavedVariable:
    def __init__(self, packed_data, hook=None):
        self.packed_data = packed_data  # the packed value produced by the pack hook
        self.hook = hook  # the python unpack hook, if one was registered

    def unpack(self):
        if self.hook:
            return self.hook(self.packed_data)
        else:
            return self.packed_data
# This approach won't directly work when we add support for Forward AD or double-backward.
```
When the CA graph is executed directly (without torch.compile-ing it) under checkpointing/offloading, the memory profile is expected to match that of the eager autograd engine. If an AOT backward is present in the autograd graph, the memory profile is expected to be better than the eager engine's, since we can now delay unpacking of saved activations until the AOT backward executes.
All tests pass when running the CA graph directly; the remaining issues are in Dynamo.
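As a rough illustration of the memory claim above, the sketch below runs non-reentrant checkpointing under compiled autograd and reads back peak CUDA memory. It assumes a CUDA device and the internal `torch._dynamo.compiled_autograd.enable` context manager; the identity compiler_fn is only meant to approximate running the CA graph directly, without torch.compile-ing it.

```python
import torch
import torch._dynamo.compiled_autograd as compiled_autograd
from torch.utils.checkpoint import checkpoint

def block(x):
    # activations in here are recomputed during backward (non-reentrant checkpointing)
    return torch.nn.functional.gelu(x @ x)

def run(x):
    for _ in range(4):
        x = checkpoint(block, x, use_reentrant=False)
    return x.sum()

x = torch.randn(4096, 4096, device="cuda", requires_grad=True)
torch.cuda.reset_peak_memory_stats()
with compiled_autograd.enable(lambda gm: gm):  # identity compiler_fn (assumption)
    run(x).backward()
print(f"peak CUDA memory: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")
```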
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147242
Approved by: https://github.com/jansel
The saved tensor hooks interface that Compiled Autograd queries for the python hook and its packed input:

```cpp
#pragma once

#include <ATen/core/Tensor.h>
#include <c10/core/SafePyObject.h>

namespace torch::autograd {

struct TORCH_API SavedVariableHooks {
  // Called during forward when a variable is saved for backward.
  virtual void call_pack_hook(const at::Tensor& tensor) = 0;
  // Called during backward when the saved variable is needed again.
  virtual at::Tensor call_unpack_hook() = 0;
  virtual ~SavedVariableHooks() = default;
  // Lets Compiled Autograd retrieve the python unpack hook and its packed
  // input so they can be lifted into the CA graph; the default throws because
  // only python saved tensor hooks are supported.
  virtual std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
  retrieve_unpack_hook_data() const {
    throw std::runtime_error(
        "Compiled Autograd only supports python saved tensor hooks ");
  }
};

} // namespace torch::autograd
```