This word appears often in class descriptions and is not consistently spelled. Update comments and some function names to use the correct spelling consistently. Facilitates searching the codebase.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155944
Approved by: https://github.com/Skylion007
Adds create_graph support if you don't compile or compile only with torch.compile(backend="eager").
Using a backend that uses AOTDispatch produces a post-dispatch AOT backward, where its double backward will be silently incorrect if the forward trace involved any ops that are not composite implicit.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153222
Approved by: https://github.com/jansel
ghstack dependencies: #153193
Enables clang-tidy rule [`misc-use-internal-linkage`](https://clang.llvm.org/extra/clang-tidy/checks/misc/use-internal-linkage.html). This new check was introduced in Clang-Tidy 18 and is available due to recent update of Clang-Tidy 19.
The check marks functions and variables used only in the translation unit as static. Therefore undesired symbols are not leaked into other units, more link time optimisations are possible and the resulting binaries may be smaller.
The detected violations were mostly fixed by using static. In other cases, the symbols were indeed consumed by others files, then their declaring headers were included. Still some declarations were wrong and have been fixed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148948
Approved by: https://github.com/Skylion007
## Before
Previously, CA will always unpack all saved variables stored in the autograd graph before executing it. This meant that we can't capture unpack hooks as part of the CA graph, and they would fire out of order wrt to other backward hooks. For memory saving APIs built on top of saved tensor hooks like non-reentrant checkpointing and offloading, we couldn't achieve any savings because all activations would be recomputed/loaded and active at the same time, resulting in no-op.
## After
We add unpack hooks into the CA graph so that they can be executed progressively. The python hook and hook input themselves are wrapped by non-traceable code, so CA polyfills the wrapping as:
```python
# pseudocode
class SavedVariable:
def unpack(self):
if self.hook:
return self.hook(self.packed_data)
else:
return self.packed_data
# This approach won't directly work when we add support for Forward AD or double-backward.
```
Directly executing the CA graph (without torch.compiling it) under checkpointing/offloading, memory profile is expected to stay the same as when using the eager autograd engine. If AOT backward is in the autograd graph, memory profile is expected to be better than the eager autograd engine, since we can now delay saved activations unpacking into the AOT backward's execution.
All tests pass when running the CA graph directly, the remaining issues are in Dynamo.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147242
Approved by: https://github.com/jansel
## Before
Previously, CA will always unpack all saved variables stored in the autograd graph before executing it. This meant that we can't capture unpack hooks as part of the CA graph, and they would fire out of order wrt to other backward hooks. For memory saving APIs built on top of saved tensor hooks like non-reentrant checkpointing and offloading, we couldn't achieve any savings because all activations would be recomputed/loaded and active at the same time, resulting in no-op.
## After
We add unpack hooks into the CA graph so that they can be executed progressively. The python hook and hook input themselves are wrapped by non-traceable code, so CA polyfills the wrapping as:
```python
# pseudocode
class SavedVariable:
def unpack(self):
if self.hook:
return self.hook(self.packed_data)
else:
return self.packed_data
# This approach won't directly work when we add support for Forward AD or double-backward.
```
Directly executing the CA graph (without torch.compiling it) under checkpointing/offloading, memory profile is expected to stay the same as when using the eager autograd engine. If AOT backward is in the autograd graph, memory profile is expected to be better than the eager autograd engine, since we can now delay saved activations unpacking into the AOT backward's execution.
All tests pass when running the CA graph directly, the remaining issues are in Dynamo.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147242
Approved by: https://github.com/jansel
This PR is on the way to getting compiled autograd's initial capture to
stop specializing on Tensor metadata.
This PR changes compiled autograd's initial capture to proxy an opaque
(w.r.t. Dynamo) function into the graph for all built-in codegen'ed
autograd nodes and validate_outputs.
We changed each codegen'ed apply_with_saved (e.g.
MulBackward0::apply_with_saved) to call into Python to proxy a function
(compiled_autograd.ops.MulBackward0) into the graph. Then, we use the
node's InputMetadata to "guess" at the properties of the output Tensors
to create some new FakeTensors.
Some details:
- MulBackward0::apply_with_saved lives in libtorch_cpu, but needs to be
call to Python via libtorch_python. There is an indirection
(PyCompilerInterface) to do this.
- MulBackward0::apply_with_saved passes a C++ function to Python. To make
our lives easier, every codegen'ed apply_with_saved passes a C++
function with the same signature
`(variable_list, ivalue_list) -> variable_list`.
- We define how to pack arbitrary C++ types into IValue via a helper
IValuePacker struct and codegen functional variants of each builtin
C++ autograd node (e.g. MulBackward0_apply_functional_ivalue).
MulBackward0 before this PR:
https://gist.github.com/zou3519/a80381d5fa38e970e413fcd91b0530de
MulBackward0 after this PR:
https://gist.github.com/zou3519/0c2eee8b3d8d96232b51ef430b53c5b0
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143296
Approved by: https://github.com/jansel
Static variables in C++11 is guaranteed to be initialised exactly once, as mentioned [here](https://en.cppreference.com/w/cpp/language/storage_duration)
```
If multiple threads attempt to initialize the same static local variable concurrently,
the initialization occurs exactly once
(similar behavior can be obtained for arbitrary functions with std::call_once.
Usual implementations of this feature use variants
of the double-checked locking pattern,
which reduces runtime overhead for already-initialized local statics
to a single non-atomic boolean comparison.
```
Given that static c10::once_flag is used before, why not just use the associated function to initialised the related static variables? That is the motivation behind this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143255
Approved by: https://github.com/albanD
GraphTask holds metadata needed for a single execution of backward(), it is 1:1 with backward calls, at least for compiled autograd. It is used for certain torch._C global autograd state APIs.
In SAC, we use torch._C._current_graph_task_id() as a dict key to store information during unpack hook execution: a5fb07af27/torch/utils/checkpoint.py (L1128)
If we don't set an active task, it will randomize the key, and will do its logic as if each unpacked tensor was from a different graph task
a5fb07af27/torch/utils/checkpoint.py (L1112-L1115)
The sketchy part of this PR is that in eager autograd, GraphTask is mutated during execution. But inspecting the struct, the mutation seems to only be used to communicate between autograd threads (created when multiple devices are involved) or for deprecated uses. We shouldn't run into the mutation case at all in compiled autograd. Also, only the graph task id is accessible from python hooks.
FIXES https://github.com/pytorch/pytorch/issues/142862
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143108
Approved by: https://github.com/jansel, https://github.com/albanD
Thanks @awgu for raising this issue and the small repro
From offline discussion with @albanD, in the case where a forward returns multiple outputs with different devices, we'd want to select the ready queue based on the device of the first one. Even though this is somewhat arbitrary, we prefer this over deciding which ready queue to push based on whichever input buffer's we happen to compute last, which can vary depending on more factors and thus be harder to reason about. This is in theory bc-breaking, but it seems unlikely that someone would depend on this behavior.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135633
Approved by: https://github.com/albanD
Summary:
`-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it.
This:
```
try {
...
} catch (exception& e) {
// no use of e
}
```
should instead be written as
```
} catch (exception&) {
```
If the code compiles, this is safe to land.
Test Plan: Sandcastle
Reviewed By: palmje
Differential Revision: D55548497
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123056
Approved by: https://github.com/Skylion007
The main thread and the autograd one are latency critical threads. They launch CPU/GPU/Accelerator kernels and if for some reason they get preempted, the rank can become a straggler in a distributed training application. By naming these threads we can debug performance issues that impact the latency sensitive threads.
I used Kineto traces to verify if the thread names were propagated:
<img width="851" alt="Screenshot 2024-03-04 at 3 07 43 PM" src="https://github.com/pytorch/pytorch/assets/23515689/68b4a09c-b8e5-4f14-a5c0-6593f866c03f">
Also:
```
nvidia-smi
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 3065920 C ...me#python#py_version_3_10 1968MiB |
| 1 N/A N/A 3065926 C ...me#python#py_version_3_10 1978MiB |
| 2 N/A N/A 3065930 C ...me#python#py_version_3_10 2084MiB |
| 3 N/A N/A 3065936 C ...me#python#py_version_3_10 2016MiB |
| 4 N/A N/A 3065939 C ...me#python#py_version_3_10 1998MiB |
| 5 N/A N/A 3065943 C ...me#python#py_version_3_10 2070MiB |
| 6 N/A N/A 3065948 C ...me#python#py_version_3_10 2026MiB |
| 7 N/A N/A 3065952 C ...me#python#py_version_3_10 2070MiB |
+-----------------------------------------------------------------------------+
[me@myhost ~]$ ps -T -p 3065920
PID SPID TTY TIME CMD
3065920 3065920 pts/14 00:01:04 pt_main_thread
...
3065920 3092181 pts/14 00:00:40 pt_autograd_d0
3065920 3092182 pts/14 00:00:00 pt_autograd_d1
3065920 3092183 pts/14 00:00:00 pt_autograd_d2
3065920 3092184 pts/14 00:00:00 pt_autograd_d3
3065920 3092185 pts/14 00:00:00 pt_autograd_d4
3065920 3092186 pts/14 00:00:00 pt_autograd_d5
3065920 3092187 pts/14 00:00:00 pt_autograd_d6
3065920 3092188 pts/14 00:00:00 pt_autograd_d7
...
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121170
Approved by: https://github.com/albanD