[Profiler] Update README (#159816)

Summary: Updated the README with the code structure and an explanation of the profiler's core features

Test Plan:
N/A

Rollback Plan:

Differential Revision: D79604189

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159816
Approved by: https://github.com/sanrise, https://github.com/aaronenyeshi
Shivam Raikundalia 2025-08-07 16:44:41 +00:00 committed by PyTorch MergeBot
parent e1cf0d496e
commit b1a602762e


@@ -13,14 +13,49 @@ The profiler instruments PyTorch to collect information about the model's execution
- [Codebase Structure](#codebase-structure)
- [`RecordFunction`](#recordfunction)
- [Autograd Integration](#autograd-integration)
- [Collection and Post-Processing](#collection-and-post-processing)
- [Torch Operation Collection](#torch-operation-collection)
- [Allocation Event Collection](#allocation-event-collection)
- [Kineto Integration](#kineto-integration)
- [Python Tracing](#python-tracing)
- [Clock Alignment](#clock-alignment)
## Codebase Structure ##
This section highlights the directories and files that are most significant to the profiler. Less relevant files, directories, and modules are omitted.
```
torch/
├── profiler/                            # Main package containing the core frontend logic
│   ├── __init__.py                      # Initialization file for the profiler package
│   ├── profiler.py                      # Main profiler frontend class
│   └── _utils.py                        # FunctionEvent utils
├── autograd/                            # Autograd package
│   ├── __init__.py                      # Initialization file for the autograd package
│   ├── profiler.py                      # Main profiler backend class
│   └── profiler_util.py                 # FunctionEvent utils
├── csrc/                                # C and C++ source code
│   ├── profiler/                        # Profiler C++ source code
│   │   ├── collection.cpp               # Main collection logic
│   │   ├── collection.h                 # Collection definitions
│   │   ├── kineto_client_interface.cpp  # Interface to call the profiler from Kineto (on-demand only)
│   │   ├── kineto_client_interface.h    # Client interface definitions
│   │   ├── kineto_shim.cpp              # Shim to call Kineto from the profiler
│   │   ├── kineto_shim.h                # Shim definitions
│   │   ├── util.cpp                     # Utils for handling args in profiler events
│   │   ├── util.h                       # Util definitions
│   │   └── README.md                    # This file
│   ├── autograd/                        # Autograd C++ source code
│   │   ├── profiler_python.cpp          # Main Python stack collection logic
│   │   ├── profiler_python.h            # Python stack collection definitions
│   │   ├── profiler_kineto.cpp          # Profiler backend logic for starting collection/Kineto
│   │   └── profiler_kineto.h            # Profiler backend definitions for starting collection/Kineto
│   └── ATen/                            # ATen C++ source code
│       ├── record_function.cpp          # RecordFunction collection logic
│       └── record_function.h            # RecordFunction definitions
└── LICENSE                              # License information
```
## `RecordFunction` ##
[aten/src/ATen/record_function.h](../../../aten/src/ATen/record_function.h)
@@ -43,14 +78,39 @@ The profiler records two pieces of information from the autograd engine:
(\*) Note that only op invocations whose inputs require gradients are assigned a sequence number
## Collection and Post-Processing ##
## Torch Operation Collection ##
This section describes the general flow for collecting torch operations during auto-trace (in-process, synchronous tracing). For details on on-demand tracing (out-of-process, asynchronous), please refer to the Libkineto README.
When a trace begins, the autograd/profiler backend calls into `profiler_kineto.cpp` to prepare, start, or stop collection. At the start of tracing, the `onFunctionEnter` and `onFunctionExit` callbacks defined in `profiler_kineto.cpp` are registered.
Callback registration can be either global or local, depending on the `ExperimentalConfig` used:
- **Global:** The callback is registered to all threads throughout execution.
- **Local:** The callback is registered only to threads present *at the start* of tracing.
Within `onFunctionEnter`, the profiler creates a `ThreadLocalSubqueue` instance for each thread, ensuring that each CPU operation is associated with the thread on which it was executed. When a torch operation is entered, the profiler calls `begin_op` (defined in `collection.cpp`) to record the necessary information. The `begin_op` routine is intentionally lightweight, as it is on the "hot path" during profiling. Excessive overhead here would distort the profile and reduce its usefulness. Therefore, only minimal information is collected during the callback; most logic occurs during post-processing.
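As a rough sketch, this auto-trace path is what runs beneath a standard `torch.profiler.profile` context on the CPU; the snippet below only illustrates the user-facing entry point, not the internal callbacks themselves:
```python
# Minimal auto-trace sketch: entering the context registers the
# onFunctionEnter/onFunctionExit callbacks, and each torch op executed
# inside the block is enqueued on its thread's ThreadLocalSubqueue.
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(128, 128)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    y = torch.mm(x, x)   # recorded via the lightweight begin_op on the calling thread
    z = torch.relu(y)

# Heavier post-processing happens only after the hot path has finished.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```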
## Allocation Event Collection ##
Unlike torch operations, which have a start and stop, allocation events are represented as `cpu_instant_event` (zero duration). As a result, `RecordFunction` is bypassed for these events. Instead, `emplace_allocation_event` is called directly to enqueue the event into the appropriate `ThreadLocalSubqueue`.
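From the frontend, these instant events are requested with the public `profile_memory=True` flag; a small sketch:
```python
# Sketch: with profile_memory=True, allocator activity is recorded as
# zero-duration [memory] events via emplace_allocation_event, bypassing
# RecordFunction entirely.
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
    buf = torch.empty(1024, 1024)  # allocation -> instant event
    del buf                        # free -> instant event

print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=5))
```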
## Kineto Integration ##
Kineto serves as an abstraction layer for collecting events across multiple architectures. It interacts with libraries such as CUPTI to receive GPU and accelerator events, which are then forwarded to the frontend profiler. Kineto requires time to "prepare" (also referred to as "warmup") these third-party modules to avoid distorting the profile with initialization routines. While this could theoretically be done at job startup, keeping a heavy library like CUPTI running unnecessarily introduces significant overhead.
As previously mentioned, `profiler_kineto.cpp` is used in the backend to invoke the appropriate profiler stage. It also calls into `kineto_shim.cpp`, which triggers the corresponding routines in Kineto. Once a trace is complete, all events collected by Kineto are forwarded to the profiler for two main reasons:
1. To coalesce all data and complete any post-processing between profiler and Kineto events.
2. To forward these events to the Python frontend as `FunctionEvents`.
The final step in integration is file export. After all events have been collected and post-processed, they can be exported to a JSON file for visualization in Perfetto or Chrome Tracer. This is done by calling Kineto's `ActivityTraceInterface::save`, which writes all event information to disk.
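From the Python frontend, this export step corresponds to `export_chrome_trace`; a brief sketch:
```python
# Sketch of the export step: once the context exits, the merged profiler
# and Kineto events can be written to a Chrome/Perfetto-compatible JSON file.
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.mm(torch.randn(64, 64), torch.randn(64, 64))

# Backed by Kineto's ActivityTraceInterface::save, as described above.
prof.export_chrome_trace("trace.json")
```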
## Python Tracing ##
When `with_stack=True` is set in the profiler, the Python stack tracer is generated using the `make` function defined in `PythonTracerBase`. The implementation resides in `profiler_python.cpp`.
To profile the stack, `PyEval_SetProfile` is used to trace and handle various execution events within a Python program. This enables comprehensive profiling by monitoring and responding to specific cases:
- **Python Function Calls (`PyTrace_CALL`):** The `recordPyCall` method logs each Python function call, capturing essential details for later analysis.
- **C Function Calls (`PyTrace_C_CALL`):** The `recordCCall` method documents calls to C functions, including relevant arguments, providing a complete view of the program's execution flow.
- **Python Function Returns (`PyTrace_RETURN`):** Exit times of Python functions are recorded, enabling precise measurement of function execution durations.
- **C Function Returns and Exceptions (`PyTrace_C_RETURN` and `PyTrace_C_EXCEPTION`):** Exit times for C functions are tracked, whether they conclude normally or due to an exception, ensuring all execution paths are accounted for.
This setup allows for detailed and accurate data collection on both Python and C function executions, facilitating thorough post-processing and analysis. After profiling, the accumulated event stacks are processed to match entrances and exits, constructing complete events for further analysis by the profiler.
**Note:** For Python 3.12.0–3.12.4, a bug in CPython requires the use of `sys.monitoring` as a workaround.
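As a usage sketch, stack collection is enabled from the frontend with `with_stack=True`, and the resulting stacks can be grouped during post-processing:
```python
# Sketch: with_stack=True enables the Python tracer described above, so
# each recorded event also carries the Python call stack that produced it.
import torch
from torch.profiler import profile, ProfilerActivity

def inner(x):
    return torch.relu(x @ x)

def outer():
    return inner(torch.randn(64, 64))

with profile(activities=[ProfilerActivity.CPU], with_stack=True) as prof:
    outer()

# Aggregate results grouped by the innermost 5 Python stack frames.
print(prof.key_averages(group_by_stack_n=5).table(sort_by="cpu_time_total", row_limit=5))
```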
## Clock Alignment ##
Depending on the system environment, the profiler uses the most efficient clock available when creating timestamps. On most Linux systems the default is TSC, which records time in CPU cycles. To convert these readings to Unix time in nanoseconds, the profiler creates a clock converter. If Kineto is included in the build, the converter is also passed to Kineto to ensure the two sets of timestamps are aligned.
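Conceptually the converter is a linear map from TSC counts to Unix nanoseconds calibrated against reference samples; the sketch below uses hypothetical names and made-up constants purely to illustrate the idea, not the profiler's actual implementation:
```python
# Illustrative only: convert raw TSC readings to Unix nanoseconds using a
# single calibration point and an assumed cycles-per-nanosecond rate.
import time

def make_tsc_converter(ref_tsc: int, ref_unix_ns: int, cycles_per_ns: float):
    """Return a function mapping a TSC reading to Unix time in nanoseconds."""
    def to_unix_ns(tsc: int) -> int:
        return ref_unix_ns + int((tsc - ref_tsc) / cycles_per_ns)
    return to_unix_ns

# Hypothetical 3 GHz TSC calibrated once at trace start.
convert = make_tsc_converter(ref_tsc=0, ref_unix_ns=time.time_ns(), cycles_per_ns=3.0)
print(convert(3_000_000_000))  # roughly one second after the reference point
```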