Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Improves the GEMM overview logging in PyTorch Inductor to properly display batch size information for batched matrix operations like `torch.bmm` and `torch.baddbmm`.
**Fixes #155307**
## Problem
The current GEMM overview logging for `torch.bmm` drops the batch dimension. It can be reproduced with:
```python
# Repro
import os
os.environ["TORCH_LOGS"] = "inductor"
import torch
M, N, K = 1024, 1024, 1024
dtype = torch.bfloat16
A = torch.randn(10, M, K, device="cuda", dtype=dtype)
B = torch.randn(10, K, N, device="cuda", dtype=dtype)
compiled_model = torch.compile(torch.bmm, fullgraph=True)
_ = compiled_model(A, B)
```
**Before:**
```
Name | M | N | K | Count
----------------------------------------------------------------------------------------------------
aten.bmm | 1024 | 1024 | 1024 | 1
----------------------------------------------------------------------------------------------------
```
The batch size (10) is missing from the log, so the row reads like a single unbatched 1024 x 1024 x 1024 GEMM and hides the actual operation dimensions.
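A rough sketch of why the missing dimension matters, assuming the common estimate of 2·M·N·K FLOPs per GEMM (one multiply and one add per multiply-accumulate; this estimate is not part of the PR): the logged row accounts for only a tenth of the work the repro's `bmm` actually does. `gemm_flops` is a hypothetical helper.

```python
def gemm_flops(m: int, n: int, k: int, batch: int = 1) -> int:
    """Approximate FLOPs for a (batched) GEMM: 2 FLOPs per multiply-accumulate."""
    return 2 * batch * m * n * k


M = N = K = 1024
single = gemm_flops(M, N, K)             # what the "Before" row suggests
batched = gemm_flops(M, N, K, batch=10)  # what torch.bmm in the repro actually computes
print(f"logged shape: {single / 1e9:.2f} GFLOP")   # logged shape: 2.15 GFLOP
print(f"actual bmm:   {batched / 1e9:.2f} GFLOP")  # actual bmm:   21.47 GFLOP
```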
## Solution
**After:**
```
Name | B | M | N | K | Count
----------------------------------------------------------------------------------------------------------------------------------
aten.bmm | 10 | 1024 | 1024 | 1024 | 1
aten.mm | - | 1024 | 1024 | 1024 | 2
----------------------------------------------------------------------------------------------------------------------------------
```
## Changes Made
### 1. Enhanced Parsing Logic in compile_fx.py
- Detects batched operations by checking whether the operation name ends with `'bmm'` or `'baddbmm'`
- For batched operations: takes the last 4 `_`-separated parts of the counter key as `batch, m, n, k`
- For non-batched operations: takes the last 3 parts as `m, n, k`
- **Dedicated "B" column**: adds a separate column for batch size instead of embedding it in the operation name
- Shows the batch size for batched operations and "-" for non-batched operations
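The parsing rule above can be sketched as follows. This is a minimal, hypothetical reconstruction, not the code in `compile_fx.py`: it assumes counter keys of the form `<name>_<m>_<n>_<k>` (or `<name>_<batch>_<m>_<n>_<k>` for batched ops) split on `_`, and the helper names `parse_gemm_key` and `format_overview` are illustrative.

```python
from collections import Counter
from typing import Optional, Tuple

BATCHED_SUFFIXES = ("bmm", "baddbmm")


def parse_gemm_key(key: str) -> Tuple[str, Optional[str], str, str, str]:
    """Split a counter key into (name, batch, m, n, k); batch is None for 2-D GEMMs."""
    parts = key.split("_")
    # Try the batched layout first: the name is everything before the last 4 parts.
    batched_name = "_".join(parts[:-4])
    if batched_name.endswith(BATCHED_SUFFIXES):
        batch, m, n, k = parts[-4:]
        return batched_name, batch, m, n, k
    # Otherwise the name is everything before the last 3 parts (handles "aten._int_mm").
    m, n, k = parts[-3:]
    return "_".join(parts[:-3]), None, m, n, k


def format_overview(counts: Counter) -> str:
    """Render the overview table with a dedicated B column ("-" when unbatched)."""
    header = f"{'Name':<30} | {'B':>6} | {'M':>6} | {'N':>6} | {'K':>6} | {'Count':>5}"
    lines = [header, "-" * len(header)]
    for key, count in counts.items():
        name, batch, m, n, k = parse_gemm_key(key)
        lines.append(
            f"{name:<30} | {batch or '-':>6} | {m:>6} | {n:>6} | {k:>6} | {count:>5}"
        )
    lines.append("-" * len(header))
    return "\n".join(lines)


counts = Counter({"aten.bmm_10_1024_1024_1024": 1, "aten.mm_1024_1024_1024": 2})
print(format_overview(counts))
```

Column widths here are arbitrary; the sketch only illustrates the suffix check and the `B`/`-` convention.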
### 2. Updated All MM Operations for Consistency
- **bmm.py**:
  - Extract batch size from `mat1.get_size()[0]` for both `tuned_bmm` and `tuned_baddbmm`
  - Use positional counter keys: `aten.bmm_{batch_size}_{m}_{n}_{k}`
  - Enhanced log messages to include batch size information
- **mm.py**: Updated counter keys for consistency:
  - `aten.mm_{m}_{n}_{k}` (no batch dimension)
  - `aten.addmm_{m}_{n}_{k}` (no batch dimension)
  - `aten._int_mm_{m}_{n}_{k}` (no batch dimension)
  - `aten._scaled_mm.default_{m}_{n}_{k}` (no batch dimension)
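The key formats listed above can be sketched as follows. This is a hypothetical illustration, not the code in `bmm.py` or `mm.py`: plain size tuples stand in for Inductor IR nodes (the first element of `mat1_size` plays the role of `mat1.get_size()[0]`), and a local `Counter` named `gemm_counts` stands in for whatever counter Inductor actually increments. `record_mm` and `record_bmm` are illustrative names, and the `name` parameter is an assumed convenience for reusing them with other ops.

```python
from collections import Counter
from typing import Sequence

# Hypothetical stand-in for Inductor's internal counter object.
gemm_counts: Counter = Counter()


def record_mm(mat1_size: Sequence[int], mat2_size: Sequence[int], name: str = "aten.mm") -> str:
    """Record a 2-D GEMM (m, k) @ (k, n); the key carries no batch dimension."""
    m, k = mat1_size
    _, n = mat2_size
    key = f"{name}_{m}_{n}_{k}"
    gemm_counts[key] += 1
    return key


def record_bmm(mat1_size: Sequence[int], mat2_size: Sequence[int], name: str = "aten.bmm") -> str:
    """Record a batched GEMM (b, m, k) @ (b, k, n); batch is taken from mat1's first dim."""
    batch_size, m, k = mat1_size
    _, _, n = mat2_size
    key = f"{name}_{batch_size}_{m}_{n}_{k}"
    gemm_counts[key] += 1
    return key


record_bmm((10, 1024, 1024), (10, 1024, 1024))
record_mm((1024, 1024), (1024, 1024))
record_mm((1024, 1024), (1024, 1024))
print(dict(gemm_counts))
# {'aten.bmm_10_1024_1024_1024': 1, 'aten.mm_1024_1024_1024': 2}
```

Running it records one `bmm` and two `mm` calls, which line up with the rows of the **After** table.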
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155544
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng