mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Convert compiler rst files to markdown (#155335)
Convert following compiler rst files to md file. torch.compiler_inductor_profiling.rst torch.compiler_ir.rst torch.compiler_nn_module.rst torch.compiler_performance_dashboard.rst torch.compiler_profiling_torch_compile.rst Fixes #155039 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155335 Approved by: https://github.com/svekars Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
This commit is contained in:
parent
1851f50866
commit
f34335bf33
|
|
@ -1,7 +1,4 @@
|
|||
.. _torchinductor-gpu-profiling:
|
||||
|
||||
TorchInductor GPU Profiling
|
||||
===========================
|
||||
# TorchInductor GPU Profiling
|
||||
|
||||
This section lists useful commands and workflows that can help
|
||||
you dive into a model’s performance in TorchInductor. When a model is not
|
||||
|
|
@ -11,8 +8,7 @@ GPU time are the most interesting ones. After that, you
|
|||
may also want to run individual kernels directly and inspect its perf.
|
||||
PyTorch provides tools to cover everything mentioned above.
|
||||
|
||||
Relevant Environment Variables
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
## Relevant Environment Variables
|
||||
|
||||
You can use the following environment variables in your analysis:
|
||||
|
||||
|
|
@ -37,29 +33,28 @@ You can use the following environment variables in your analysis:
|
|||
one with the best performance results. This will increase compilation
|
||||
time with the hope to improve performance.
|
||||
|
||||
Breakdown Model GPU Time
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
## Breakdown Model GPU Time
|
||||
|
||||
Below are the steps to breakdown execution time of a model into
|
||||
individual kernels. We take ``mixnet_l`` as an example.
|
||||
|
||||
1. Run the benchmark script for the model:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
```bash
|
||||
TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCHINDUCTOR_BENCHMARK_KERNEL=1
|
||||
python -u benchmarks/dynamo/timm_models.py –backend inductor –amp
|
||||
–performance –dashboard –only mixnet_l –disable-cudagraphs –training
|
||||
|
||||
.. note:: The tool relies on kernel name to decide its category. Enabling
|
||||
```
|
||||
```{note}
|
||||
The tool relies on kernel name to decide its category. Enabling
|
||||
``TORCHINDUCTOR_UNIQUE_KERNEL_NAMES`` is crucial for that.
|
||||
|
||||
```
|
||||
2. In the output log, look for lines:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
```bash
|
||||
**Compiled module path:
|
||||
/tmp/torchinductor_shunting/qz/cqz7hvhood7y3psp7fy6msjxsxyli7qiwiybizdwtjw6ffyq5wwd.py**
|
||||
```
|
||||
|
||||
We have one line for each compiled module. If there are no extra graph
|
||||
breaks, we would see 2 such lines in the log, one for the forward graph
|
||||
|
|
@ -68,34 +63,33 @@ and one for the backward graph.
|
|||
For our example command, we get the following compiled module for the
|
||||
forward and backward graphs respectively:
|
||||
|
||||
- https://gist.github.com/shunting314/c2a4d8a28b00fcb5586d0e9d9bf77f9f
|
||||
- https://gist.github.com/shunting314/48efc83b12ec3ead950052e4a0220b10
|
||||
- [Forward graph compiled module](https://gist.github.com/shunting314/c2a4d8a28b00fcb5586d0e9d9bf77f9f)
|
||||
- [Backward graph compiled module](https://gist.github.com/shunting314/48efc83b12ec3ead950052e4a0220b10)
|
||||
|
||||
3. Now we can dive into the perf for each individual compiled module.
|
||||
Let’s pick the one for the forward graph for illustration purposes.
|
||||
I’ll name it ``fwd.py`` for convenience. Run it directly with the
|
||||
``-p`` argument:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
```bash
|
||||
**> python fwd.py -p**
|
||||
```
|
||||
|
||||
See the full output log in this
|
||||
`example gist <https://gist.github.com/shunting314/8243734a38b5733ea78479209c0ae893>`__.
|
||||
See the full output log in this [example gist](https://gist.github.com/shunting314/8243734a38b5733ea78479209c0ae893)
|
||||
|
||||
In the output, you can notice the following:
|
||||
|
||||
* We write a chrome trace file for the profile so we can load the trace and interact with it. In the log, look for lines as follows to find the path of the trace file.
|
||||
|
||||
**Chrome trace for the profile is written to
|
||||
/tmp/compiled_module_profile.json**
|
||||
|
||||
Loading the trace into Chrome (visit chrome://tracing in the chrome
|
||||
browser and load the file as the UI suggested) will show UI as follows:
|
||||
**Chrome trace for the profile is written to /tmp/compiled_module_profile.json**
|
||||
|
||||
.. image:: _static/img/inductor_profiling/trace.png
|
||||
Loading the trace into Chrome (visit chrome://tracing in the chrome browser and load the file as the UI suggested) will show UI as follows:
|
||||
|
||||
You can zoom in and out to check the profile.
|
||||
```{image} _static/img/inductor_profiling/trace.png
|
||||
```
|
||||
|
||||
You can zoom in and out to check the profile.
|
||||
|
||||
* We report the percent of GPU time regarding to the wall time by log line like:
|
||||
|
||||
|
|
@ -109,9 +103,9 @@ In the output, you can notice the following:
|
|||
If we run the model like ``densenet121`` with a small batch size, we would see
|
||||
low percent of time when GPU is busy:
|
||||
|
||||
::
|
||||
|
||||
```bash
|
||||
(Forward graph) Percent of time when GPU is busy: 32.69%
|
||||
```
|
||||
|
||||
This means the model has a lot of CPU overhead. This is consistent with
|
||||
the fact that enabling cudagraphs improve densenet121’s perf a lot.
|
||||
|
|
@ -130,7 +124,8 @@ In the output, you can notice the following:
|
|||
* We also call zoom into a certain category of kernels. For example,
|
||||
let’s check reduction kernels:
|
||||
|
||||
.. image:: _static/img/inductor_profiling/kernel_breakdown.png
|
||||
```{image} _static/img/inductor_profiling/kernel_breakdown.png
|
||||
```
|
||||
|
||||
We can see an ordered table of execution time for each individual
|
||||
reduction kernel. We also see how many times a kernel is executed. This
|
||||
|
|
@ -142,8 +137,7 @@ In the output, you can notice the following:
|
|||
- Ff a kernel takes 2% of time, improving it by 2x will bring in 1%
|
||||
overall gain which justifies the effort.
|
||||
|
||||
Benchmark Individual Triton Kernel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
## Benchmark Individual Triton Kernel
|
||||
|
||||
Let’s say we want to take a closer look at
|
||||
``triton_red_fused\__native_batch_norm_legit_functional_16`` which is the
|
||||
|
|
@ -155,23 +149,23 @@ We can lookup the kernel name in the ``fwd.py``, and find comment like:
|
|||
**# kernel path:
|
||||
/tmp/torchinductor_shunting/jk/cjk2vm3446xrk7rth7hr6pun7xxo3dnzubwcn6ydrpifal4eykrz.py**
|
||||
|
||||
.. image:: _static/img/inductor_profiling/inductor_code.png
|
||||
```{image} _static/img/inductor_profiling/inductor_code.png
|
||||
```
|
||||
|
||||
I’ll rename it k.py for convenience. Here is a paste for this
|
||||
`file <https://gist.github.com/shunting314/96a0afef9dce53d6357bf1633094f358>`__.
|
||||
I’ll rename it k.py for convenience. Here is a paste for this [file](https://gist.github.com/shunting314/96a0afef9dce53d6357bf1633094f358).
|
||||
|
||||
``k.py`` is a standalone Python module containing the kernel code and its
|
||||
benchmark.
|
||||
|
||||
Run ``k.py`` directly will report its execution time and bandwidth:
|
||||
|
||||
.. image:: _static/img/inductor_profiling/terminal_printout.png
|
||||
```{image} _static/img/inductor_profiling/terminal_printout.png
|
||||
```
|
||||
|
||||
We can check if max-autotune helps this kernel, by running:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
```bash
|
||||
**TORCHINDUCTOR_MAX_AUTOTUNE=1 python /tmp/k.py**
|
||||
|
||||
```
|
||||
We may also temporarily add more reduction heuristics and run the script
|
||||
again to check how that helps with the kernel.
|
||||
|
|
@ -1,12 +1,8 @@
|
|||
.. _torch.compiler_ir:
|
||||
|
||||
IRs
|
||||
===============
|
||||
# IRs
|
||||
|
||||
PyTorch 2.0 offers two set of IRs for backends to interface with: Core Aten IR and Prims IR.
|
||||
|
||||
Core Aten IR
|
||||
--------------------
|
||||
## Core Aten IR
|
||||
|
||||
Core aten ops is the core subset of aten operators that can be used to compose other operators.
|
||||
Core aten IR is fully functional, and there is no `inplace` or `_out` variants in this opset.
|
||||
|
|
@ -14,26 +10,29 @@ In contrast to Prims IR, core aten ops reuses the existing aten ops in "native_f
|
|||
and it doesn't further decompose ops into explicit type promotion and broadcasting ops.
|
||||
This opset is designed to serve as the functional IR to interface with backends.
|
||||
|
||||
.. warning::
|
||||
```{warning}
|
||||
This opset is still under active development, more ops will be added in the future.
|
||||
```
|
||||
|
||||
.. csv-table::
|
||||
```{csv-table}
|
||||
:file: ../build/ir/aten_ops.csv
|
||||
:widths: auto
|
||||
:header-rows: 1
|
||||
```
|
||||
|
||||
Prims IR
|
||||
-----------
|
||||
## Prims IR
|
||||
|
||||
Prims IR is a set of primitive operators that can be used to compose other operators.
|
||||
Prims IR is a lower level opset than core aten IR, and it further decomposes ops into explicit
|
||||
type promotion and broadcasting ops: prims.convert_element_type and prims.broadcast_in_dim.
|
||||
This opset is designed to interface with compiler backends.
|
||||
|
||||
.. warning::
|
||||
```{warning}
|
||||
This opset is still under active development, more ops will be added in the future.
|
||||
```
|
||||
|
||||
.. csv-table::
|
||||
```{csv-table}
|
||||
:file: ../build/ir/prims_ops.csv
|
||||
:widths: auto
|
||||
:header-rows: 1
|
||||
```
|
||||
|
|
@ -1,15 +1,14 @@
|
|||
PyTorch 2.0 NNModule Support
|
||||
============================
|
||||
# PyTorch 2.0 NNModule Support
|
||||
|
||||
**Author**: `Will Constable <https://github.com/wconstab>`_
|
||||
**Author**: [Will Constable](https://github.com/wconstab)
|
||||
|
||||
`torch.compile` has special handling for torch.nn.Module objects, tracing them differently than it traces
|
||||
arbitrary python classes, with the intent of producing faster code by making assumptions about the structure.
|
||||
|
||||
This doc describes some of the tradeoffs or edge cases that come up due to this specialization.
|
||||
|
||||
NNModule Hooks Support
|
||||
----------------------
|
||||
## NNModule Hooks Support
|
||||
|
||||
Previously, `torch.compile` had no support for hooks on nn.Modules, and if hooks were registered
|
||||
they would simply be ignored in the compiled program. Indeed many users do not
|
||||
use nn.Module hooks at all, or only use them for debug workflows, but there are valid use cases
|
||||
|
|
@ -22,8 +21,8 @@ These hooks are partially supported by `torch.compile` with limitations describe
|
|||
Another category of hooks includes `_state_dict_hooks` and its `pre` and `load_` variants, and are still
|
||||
unsupported by `torch.compile`.
|
||||
|
||||
`nn.Module.__call__` Hooks Usage and limitations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
## `nn.Module.__call__` Hooks Usage and limitations
|
||||
|
||||
By default, `torch.compile` will trace the contents of `nn.Module.__call__` which means it will encounter
|
||||
and run forward/pre-forward hooks. If you install hooks before calling `torch.compile` and then do not remove
|
||||
or alter the hooks later, your use case should be supported by default.
|
||||
|
|
@ -52,8 +51,8 @@ guards.
|
|||
|
||||
TODO: confirm if backward/pre_backward hooks are working or not and document accordingly
|
||||
|
||||
state_dict Hooks
|
||||
~~~~~~~~~~~~~~~~
|
||||
## state_dict Hooks
|
||||
|
||||
State dict hooks have not yet been supported in `torch.compile`.
|
||||
|
||||
|
||||
|
|
@ -1,15 +1,13 @@
|
|||
PyTorch 2.0 Performance Dashboard
|
||||
=================================
|
||||
# PyTorch 2.0 Performance Dashboard
|
||||
|
||||
**Author:** `Bin Bao <https://github.com/desertfire>`__ and `Huy Do <https://github.com/huydhn>`__
|
||||
**Author:** [Bin Bao](https://github.com/desertfire) and [Huy Do](https://github.com/huydhn)
|
||||
|
||||
PyTorch 2.0's performance is tracked nightly on this `dashboard <https://hud.pytorch.org/benchmark/compilers>`__.
|
||||
PyTorch 2.0's performance is tracked nightly on this [dashboard](https://hud.pytorch.org/benchmark/compilers).
|
||||
The performance collection runs on 12 GCP A100 nodes every night. Each node contains a 40GB A100 Nvidia GPU and
|
||||
a 6-core 2.2GHz Intel Xeon CPU. The corresponding CI workflow file can be found
|
||||
`here <https://github.com/pytorch/pytorch/blob/main/.github/workflows/inductor-perf-test-nightly.yml>`__.
|
||||
[here](https://github.com/pytorch/pytorch/blob/main/.github/workflows/inductor-perf-test-nightly.yml).
|
||||
|
||||
How to read the dashboard?
|
||||
---------------------------
|
||||
## How to read the dashboard?
|
||||
|
||||
The landing page shows tables for all three benchmark suites we measure, ``TorchBench``, ``Huggingface``, and ``TIMM``,
|
||||
and graphs for one benchmark suite with the default setting. For example, the default graphs currently show the AMP
|
||||
|
|
@ -21,32 +19,29 @@ Both ``Geometric mean speedup`` and ``Peak memory footprint compression ratio``
|
|||
the PyTorch eager performance, and the larger the better. Each individual performance number on those tables can be clicked,
|
||||
which will bring you to a view with detailed numbers for all the tests in that specific benchmark suite.
|
||||
|
||||
What is measured on the dashboard?
|
||||
-----------------------------------
|
||||
## What is measured on the dashboard?
|
||||
|
||||
All the dashboard tests are defined in this
|
||||
`function <https://github.com/pytorch/pytorch/blob/3e18d3958be3dfcc36d3ef3c481f064f98ebeaf6/.ci/pytorch/test.sh#L305>`__.
|
||||
[function](https://github.com/pytorch/pytorch/blob/3e18d3958be3dfcc36d3ef3c481f064f98ebeaf6/.ci/pytorch/test.sh#L305).
|
||||
The exact test configurations are subject to change, but at the moment, we measure both inference and training
|
||||
performance with AMP precision on the three benchmark suites. We also measure different settings of TorchInductor,
|
||||
including ``default``, ``with_cudagraphs (default + cudagraphs)``, and ``dynamic (default + dynamic_shapes)``.
|
||||
|
||||
Can I check if my PR affects TorchInductor's performance on the dashboard before merging?
|
||||
-----------------------------------------------------------------------------------------
|
||||
## Can I check if my PR affects TorchInductor's performance on the dashboard before merging?
|
||||
|
||||
Individual dashboard runs can be triggered manually by clicking the ``Run workflow`` button
|
||||
`here <https://github.com/pytorch/pytorch/actions/workflows/inductor-perf-test-nightly.yml>`__
|
||||
[here](https://github.com/pytorch/pytorch/actions/workflows/inductor-perf-test-nightly.yml)
|
||||
and submitting with your PR's branch selected. This will kick off a whole dashboard run with your PR's changes.
|
||||
Once it is done, you can check the results by selecting the corresponding branch name and commit ID
|
||||
on the performance dashboard UI. Be aware that this is an expensive CI run. With the limited
|
||||
resources, please use this functionality wisely.
|
||||
|
||||
How can I run any performance test locally?
|
||||
--------------------------------------------
|
||||
## How can I run any performance test locally?
|
||||
|
||||
The exact command lines used during a complete dashboard run can be found in any recent CI run logs.
|
||||
The `workflow page <https://github.com/pytorch/pytorch/actions/workflows/inductor-perf-test-nightly.yml>`__
|
||||
The [workflow page](https://github.com/pytorch/pytorch/actions/workflows/inductor-perf-test-nightly.yml)
|
||||
is a good place to look for logs from some of the recent runs.
|
||||
In those logs, you can search for lines like
|
||||
``python benchmarks/dynamo/huggingface.py --performance --cold-start-latency --inference --amp --backend inductor --disable-cudagraphs --device cuda``
|
||||
`python benchmarks/dynamo/huggingface.py --performance --cold-start-latency --inference --amp --backend inductor --disable-cudagraphs --device cuda`
|
||||
and run them locally if you have a GPU working with PyTorch 2.0.
|
||||
``python benchmarks/dynamo/huggingface.py -h`` will give you a detailed explanation on options of the benchmarking script.
|
||||
|
|
@ -1,25 +1,22 @@
|
|||
Profiling to understand torch.compile performance
|
||||
=================================================
|
||||
# Profiling to understand torch.compile performance
|
||||
|
||||
What to use torch.profiler for:
|
||||
-------------------------------
|
||||
## What to use torch.profiler for:
|
||||
|
||||
torch.profiler is helpful for understanding the performance of your program at a kernel-level granularity - for example, it can show graph breaks and resources utilization at the level of the program. The data provided by the profiler can often help users understand where to investigate further to understand model performance.
|
||||
|
||||
To understand kernel-level performance, other tools exist, such as `Nvidia Nsight compute tool <https://developer.nvidia.com/nsight-compute>`_, `AMD Omnitrace <https://rocm.docs.amd.com/projects/omnitrace/en/latest/>`_, Intel® VTune™ Profiler or :ref:`inductor's profiling tools <torchinductor-gpu-profiling>` can be used.
|
||||
To understand kernel-level performance, other tools exist, such as [Nvidia Nsight compute tool](https://developer.nvidia.com/nsight-compute), [AMD Omnitrace](https://rocm.docs.amd.com/projects/omnitrace/en/latest/), Intel® VTune™ Profiler or [inductor's profiling tools](https://docs.pytorch.org/docs/stable/torch.compiler_inductor_profiling.html#torchinductor-gpu-profiling) can be used.
|
||||
|
||||
See also the `general pytorch profiler guide <https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html>`_.
|
||||
See also the [general pytorch profiler guide](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html).
|
||||
|
||||
Basics of using torch.profiler and viewing traces
|
||||
-------------------------------------------------
|
||||
## Basics of using torch.profiler and viewing traces
|
||||
|
||||
**Example program**: We'll use this example of profiling resnet18. Notice the following parts of this example program:
|
||||
|
||||
* Include a warm-up run to wait for compilation to complete (this will warm up systems like the CUDA caching allocator)
|
||||
* Use :code:`torch.profiler.profile()` context for profiling the section we are interested in
|
||||
* Use :code:`prof.export_chrome_trace("trace.json")` to export the profiling artifact.
|
||||
* Use `torch.profiler.profile()` context for profiling the section we are interested in
|
||||
* Use `prof.export_chrome_trace("trace.json")` to export the profiling artifact.
|
||||
|
||||
.. code-block:: python
|
||||
```python
|
||||
|
||||
import torch
|
||||
from torchvision.models import resnet18
|
||||
|
|
@ -44,11 +41,13 @@ Basics of using torch.profiler and viewing traces
|
|||
prof.step()
|
||||
|
||||
prof.export_chrome_trace("trace.json")
|
||||
```
|
||||
|
||||
**Viewing chrome traces**: In the Chrome browser, open chrome://tracing and load the json file. Use the “w” and “s” keys to zoom in and out, and use “a” and “d” to scroll left and right. “?” will show a “help” screen with a list of shortcuts.
|
||||
|
||||
.. figure:: _static/img/profiling_torch_compile/basic_chrome_trace.png
|
||||
:alt: Example of a basic chrome trace, visualized in the chrome://tracing viewer
|
||||
```{figure} _static/img/profiling_torch_compile/basic_chrome_trace.png
|
||||
:alt: Example of a basic chrome trace, visualized in the chrome://tracing viewer
|
||||
```
|
||||
|
||||
Here, we observe:
|
||||
* CompiledFunction and CompiledFunctionBackward events, which correspond to the dynamo-compiled regions.
|
||||
|
|
@ -60,26 +59,26 @@ Every kernel on the accelerator occurs after being launched by code running on t
|
|||
|
||||
To view a flow connection, click on a GPU kernel and click “ac2g”:
|
||||
|
||||
.. figure:: _static/img/profiling_torch_compile/ac2g.png
|
||||
:alt: Visualization in the chrome://trace viewer, showing an async flow between a kernel and its launching location.
|
||||
```{figure} _static/img/profiling_torch_compile/ac2g.png
|
||||
:alt: Visualization in the chrome://trace viewer, showing an async flow between a kernel and its launching location.
|
||||
```
|
||||
|
||||
Alternatively, turn on *all* flows with the “Flow events” dropdown at the top.
|
||||
|
||||
Working around CUDA Graph profiling issues
|
||||
------------------------------------------
|
||||
## Working around CUDA Graph profiling issues
|
||||
|
||||
When CUDA graphs are enabled, some CUDA configurations (driver version under 525.85.12 or CUDA < 12) can encounter issues between the profiling tools and CUDA graphs. To fix these issues, add an empty profiling context at the top of your program:
|
||||
|
||||
.. code-block:: python
|
||||
```python
|
||||
|
||||
import torch
|
||||
|
||||
torch.profiler._utils._init_for_cuda_graphs()
|
||||
|
||||
# ... rest of program
|
||||
```
|
||||
|
||||
Understanding compilation time
|
||||
------------------------------
|
||||
## Understanding compilation time
|
||||
|
||||
To understand why compilation is taking a long time, you can profile the first invocation of a torch.compile-ed program. Keep in mind that profile traces of compilations can be distorted more than typical profiling, because compilation workloads can be quite different from typical PyTorch workloads. In some cases, trace files may also be quite large. Traces > 1GB can be difficult to open with the chrome tracing tool.
|
||||
|
||||
|
|
@ -87,7 +86,7 @@ Note: roughly the same information can also be obtained in non-graphical format
|
|||
|
||||
See an example below:
|
||||
|
||||
.. code-block:: python
|
||||
```python
|
||||
|
||||
import torch
|
||||
from torchvision.models import resnet18
|
||||
|
|
@ -120,18 +119,18 @@ See an example below:
|
|||
fwd_bwd(inputs[0])
|
||||
|
||||
prof.export_chrome_trace("trace_compile.json")
|
||||
```
|
||||
|
||||
.. figure:: _static/img/profiling_torch_compile/compilation_profiling.png
|
||||
:alt: A visualization in the chrome://trace viewer, showing dynamo and inductor compilation steps
|
||||
|
||||
```{figure} _static/img/profiling_torch_compile/compilation_profiling.png
|
||||
:alt: A visualization in the chrome://trace viewer, showing dynamo and inductor compilation steps
|
||||
```
|
||||
|
||||
Note a few things:
|
||||
|
||||
* The first invocation should occur *during* profiling in order to capture compilation
|
||||
* Add a warm-up compilation in order to initialize any systems that need to be lazily initialized.
|
||||
|
||||
Finding graph breaks: "Torch-Compiled Region" and "CompiledFunction"
|
||||
--------------------------------------------------------------------
|
||||
# Finding graph breaks: "Torch-Compiled Region" and "CompiledFunction"
|
||||
|
||||
Although there are logging tools for identifying graph breaks, the profiler provides a quick visual method of identifying :ref:`graph breaks <torch.compiler_graph_breaks>`. There are two profiler events to look for: **Torch-Compiled Region** and **CompiledFunction**.
|
||||
|
||||
|
|
@ -147,7 +146,7 @@ If your use case includes a graph that doesn't require grad and doesn't include
|
|||
|
||||
See the synthetic example below for a demonstration:
|
||||
|
||||
.. code-block:: python
|
||||
```python
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
|
@ -197,12 +196,13 @@ See the synthetic example below for a demonstration:
|
|||
prof.step()
|
||||
|
||||
prof.export_chrome_trace("trace_break.json")
|
||||
```
|
||||
|
||||
.. figure:: _static/img/profiling_torch_compile/graph_breaks_with_torch_compiled_region.png
|
||||
:alt: Visualization in the chrome://trace viewer, showing nested Torch-Compiled Region events and multiple CompiledFunction events - indicating graph breaks.
|
||||
```{figure} _static/img/profiling_torch_compile/graph_breaks_with_torch_compiled_region.png
|
||||
:alt: Visualization in the chrome://trace viewer, showing nested Torch-Compiled Region events and multiple CompiledFunction events - indicating graph breaks.
|
||||
```
|
||||
|
||||
Operator Kernels
|
||||
----------------
|
||||
## Operator Kernels
|
||||
|
||||
When an operator is launched, we expect to see a few events:
|
||||
|
||||
|
|
@ -210,15 +210,17 @@ When an operator is launched, we expect to see a few events:
|
|||
2. Kernel launch (if dealing with a GPU kernel)
|
||||
3. GPU-side event
|
||||
|
||||
.. figure:: _static/img/profiling_torch_compile/kernel_launch_labeled.png
|
||||
:alt: Visualization in the chrome://trace viewer, showing the three types of events: CPU-side event, kernel launch, and GPU-side event
|
||||
```{figure} _static/img/profiling_torch_compile/kernel_launch_labeled.png
|
||||
:alt: Visualization in the chrome://trace viewer, showing the three types of events - CPU-side event, kernel launch, and GPU-side event
|
||||
```
|
||||
|
||||
**Inductor-generated Triton kernels:**
|
||||
1. The **CPU-side event** should appear as an event prefixed with "triton\_". The events currently have minimal information - the kernel name and a launch, but less information than typical aten kernel launches (which contain input shapes, types, etc.).
|
||||
2. The **kernel launch** should appear as cuLaunchKernel instead of cudaLaunchKernel (cudaLaunchKernel is typical for aten ops)
|
||||
3. The **GPU-side event** should appear, and how descriptive the name will be depends on the inductor config for unique_kernel_names
|
||||
|
||||
.. figure:: _static/img/profiling_torch_compile/triton_kernel_launch.png
|
||||
```{figure} _static/img/profiling_torch_compile/triton_kernel_launch.png
|
||||
```
|
||||
|
||||
**Non-Inductor generated Triton kernels:**
|
||||
|
||||
|
|
@ -226,7 +228,8 @@ When an operator is launched, we expect to see a few events:
|
|||
2. The **kernel launch** should appear s cuLaunchKernel instead of cudaLaunchKernel (cudaLaunchKernel is typical for aten ops)
|
||||
3. The **GPU-side** event should appear, named similarly to the triton kernel that was authored.
|
||||
|
||||
.. figure:: _static/img/profiling_torch_compile/noninductor_triton_kernel.png
|
||||
```{figure} _static/img/profiling_torch_compile/noninductor_triton_kernel.png
|
||||
```
|
||||
|
||||
**Inductor-generated CPU kernels:**
|
||||
|
||||
|
|
@ -236,14 +239,14 @@ When an operator is launched, we expect to see a few events:
|
|||
**Non-Triton kernels** (i.e. aten kernels or custom ops) should also be expected to sometimes appear in traces. Sometimes, Inductor will fall back to the original op implementation, in which case you will see a call to the aten op.
|
||||
|
||||
|
||||
Launch overhead
|
||||
---------------
|
||||
## Launch overhead
|
||||
|
||||
One common issue is bad GPU utilization. A quick way to identify this is if there are large gaps between kernels on the GPU:
|
||||
|
||||
.. figure:: _static/img/profiling_torch_compile/cpu_bound.png
|
||||
:alt: Visualization in the chrome://trace viewer, showing large gaps between GPU kernels. This indicates that the model is CPU bound, likely due to overhead during kernel launches.
|
||||
```{figure} _static/img/profiling_torch_compile/cpu_bound.png
|
||||
:alt: Visualization in the chrome://trace viewer, showing large gaps between GPU kernels. This indicates that the model is CPU bound, likely due to overhead during kernel launches.
|
||||
```
|
||||
|
||||
This is often the result of CPU overhead, e.g. if the amount of time spent on the CPU between kernel launches is larger than the amount of time spent by the GPU to process the kernels. The issue is more common for small batch sizes.
|
||||
|
||||
When using inductor, enabling CUDA graphs can often help improve performance when launch overhead is a concern.
|
||||
When using inductor, enabling CUDA graphs can often help improve performance when launch overhead is a concern.
|
||||
Loading…
Reference in New Issue
Block a user