# tlparse / TORCH_TRACE

tlparse / `TORCH_TRACE` are a pair of tools that produce compilation reports that look [like this](https://web.mit.edu/~ezyang/Public/bhack-20240609-tlparse/index.html).

Traces are fairly straightforward to collect. To collect a trace, run your model like so:

```bash
TORCH_TRACE="/tmp/tracedir" python foo.py
pip install tlparse
tlparse /tmp/tracedir
```

This approach works even if you are running a distributed job, providing a trace for each rank.
`tlparse` will open your browser with an HTML report similar to the example linked above.
If you are making a bug report for a complicated problem that you don’t have a standalone reproduction for,
you can still greatly assist PyTorch developers by attaching the trace log generated in `/tmp/tracedir`.
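
For example, for a multi-GPU run launched with `torchrun`, trace collection might look like the following (a minimal sketch; `foo.py` and the process count are placeholders):

```bash
# each rank writes its own trace into the shared trace directory
TORCH_TRACE="/tmp/tracedir" torchrun --nproc_per_node=2 foo.py
tlparse /tmp/tracedir
```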

```{warning}
The trace log contains all of your model code.
Do not share the trace log if the model you are working on is sensitive. The trace log does NOT contain weights.
```

```{raw} html
<style>
.red {background-color:#ff0000;}
.green {background-color:#00ff00;}
.dark-green {background-color:#027f02;}
</style>
```

```{eval-rst}
.. role:: red
.. role:: green
.. role:: dark-green
```

The output of `tlparse` is primarily aimed at PyTorch developers,
and the log format is easy to upload and share on GitHub.
However, as a non-PyTorch developer, you can still extract useful information from it.
We recommend starting with the inline help text in the report, which explains its contents.
Here are some insights you can gain from a `tlparse`:

- What model code was compiled? You can see this by looking at the stack trie.
  This is especially useful if you're not familiar with the codebase being compiled!
- How many graph breaks / distinct compilation regions are there?
  (Each distinct compile is its own color-coded block like {dark-green}`[0/0]`.)
  Frames that are potentially graph-broken are light green {green}`[2/4]`.
  If there are a lot of frames, that is suspicious: it suggests that you had some catastrophic graph breaks,
  or maybe your code isn't a good match for `torch.compile`.
  (A toy example that produces both graph breaks and recompiles is sketched after this list.)
- How many times did I recompile a particular frame? Something that recompiled a lot will look like
  {dark-green}`[10/0]` {dark-green}`[10/1]` {dark-green}`[10/2]`.
  If something is being recompiled that often, it is very suspicious and worth looking into, even if it isn't the root cause of your problem.
- Was there a compilation error? Frames that errored will look like {red}`[0/1]`.
- What intermediate compiler products did I generate for a given frame?
  For example, you can look at the high-level generated FX graph or the generated Triton code.
- Is there other relevant information for a particular frame? You can find it in `compilation_metrics`.
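
To make these markers concrete, here is a toy sketch (not taken from the report linked above) that produces both a graph break and a recompile; running it under `TORCH_TRACE` and then `tlparse` shows the resulting frames:

```python
import torch

@torch.compile
def f(x, n):
    y = x * 2
    # print() (and .item()) cannot be traced into the graph, so Dynamo
    # graph-breaks here and the function is split into multiple compiled regions.
    print("intermediate sum:", y.sum().item())
    return y + n

# Passing different Python ints fails the guard on `n` and triggers recompiles,
# which show up as additional [i/1], [i/2], ... entries for the same frame.
for n in range(4):
    f(torch.randn(8), n)
```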
## TORCH_LOGS

You can use the `TORCH_LOGS` environment variable to selectively enable logging from parts of the `torch.compile` stack.
`TORCH_LOGS` is in fact the source of logs for `tlparse`. The format of the `TORCH_LOGS` environment variable looks like this:

```bash
TORCH_LOGS="<option1>,<option2>,..." python foo.py
```
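
For example, to log graph break locations together with recompilation reasons (the script name is a placeholder):

```bash
TORCH_LOGS="graph_breaks,recompiles" python foo.py
```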

You can also programmatically set logging options using `torch._logging.set_logs`:

```python
import torch
import logging

torch._logging.set_logs(graph_breaks=True, dynamic=logging.DEBUG)
```

The most useful options are:

- `graph_breaks`: logs locations of graph breaks in user code and the reason for the graph break
- `guards`: logs guards that are generated
- `recompiles`: logs which function recompiled and the guards that failed, leading to the recompilation
- `dynamic`: logs related to dynamic shapes
- `output_code`: logs the code generated by Inductor
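
Several of these can be enabled at once; here is a sketch, assuming the remaining names are accepted as boolean keyword arguments just like `graph_breaks` in the example above:

```python
import torch

# log graph breaks, recompilation reasons, generated guards, and Inductor output code
torch._logging.set_logs(
    graph_breaks=True,
    recompiles=True,
    guards=True,
    output_code=True,
)
```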

Some more helpful `TORCH_LOGS` options include:

```{eval-rst}
.. list-table::
   :widths: 25 50
   :header-rows: 1

   * - Option
     - Description
   * - +all
     - Output debug logs from all ``torch.compile`` components
   * - +dynamo
     - Output debug logs from TorchDynamo
   * - +aot
     - Output debug logs from AOTAutograd
   * - +inductor
     - Output debug logs from TorchInductor
   * - dynamic
     - Output logs from dynamic shapes
   * - graph_code
     - Output the Python code for the FX graph that Dynamo generated
   * - graph_sizes
     - Output the tensor sizes of the FX graph that Dynamo generated
   * - trace_bytecode
     - Output the bytecode instructions that Dynamo is tracing through and the symbolic interpreter stack Dynamo is keeping track of
   * - trace_source
     - Output the line of code in the original source that Dynamo is currently tracing through
   * - bytecode
     - Output Dynamo-generated bytecode
   * - guards
     - Output generated guards
   * - recompiles
     - Output recompilation reasons (only the first guard check that fails)
   * - recompiles_verbose
     - Output all guard checks that fail when a recompilation occurs
   * - aot_graphs
     - Output graph generated by AOTAutograd
   * - aot_joint_graphs
     - Output the joint forward-backward graph generated by AOTAutograd
   * - output_code
     - Output code generated by Inductor
   * - kernel_code
     - Output code generated by Inductor on a per-kernel basis
   * - schedule
     - Output Inductor scheduling logs
   * - perf_hints
     - Output Inductor perf hint logs
   * - fusion
     - Output Inductor fusion logs
```
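
As a sketch of how these options combine on the command line (the script name is a placeholder; per the table, the `+` prefix requests debug-level output from a component):

```bash
# debug-level TorchDynamo logs, plus the generated FX graph code and guards
TORCH_LOGS="+dynamo,graph_code,guards" python foo.py
```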

For the full list of options, see [torch.\_logging](https://pytorch.org/docs/stable/logging.html)
and [torch.\_logging.set_logs](https://pytorch.org/docs/stable/generated/torch._logging.set_logs.html#torch._logging.set_logs).

## tlparse vs. TORCH_LOGS

Generally, we suggest first using `tlparse` when encountering issues.
`tlparse` is ideal for debugging large models and gaining a high-level overview of how your model was compiled.
On the other hand, `TORCH_LOGS` is preferred for small examples and fine-grained debugging detail,
when you already have an idea of which `torch.compile` component is causing the problem.