pytorch/torch/_inductor
Commit 59eb61b2d1 by penknife6153: [inductor] Improve GEMM logging to display batch size for batched operations (#155544)
Improves the GEMM overview logging in PyTorch Inductor to properly display batch size information for batched matrix operations like `torch.bmm` and `torch.baddbmm`.

**Fixes #155307**

## Problem

The current GEMM overview logging omits the batch size for `torch.bmm`:
```python
# Repro
import os
os.environ["TORCH_LOGS"] = "inductor"
import torch

M, N, K = 1024, 1024, 1024
dtype = torch.bfloat16
A = torch.randn(10, M, K, device="cuda", dtype=dtype)
B = torch.randn(10, K, N, device="cuda", dtype=dtype)

compiled_model = torch.compile(torch.bmm, fullgraph=True)
_ = compiled_model(A, B)
```

**Before:**
```
Name                 | M                    | N                    | K                    | Count
----------------------------------------------------------------------------------------------------
aten.bmm             | 1024                 | 1024                 | 1024                 | 1
----------------------------------------------------------------------------------------------------
```

The batch size (10) is missing from the logs, making it unclear what the actual operation dimensions were.

## Solution

**After** (non-batched operations show "-" in the new `B` column):
```
Name                           | B                    | M                    | N                    | K                    | Count
----------------------------------------------------------------------------------------------------------------------------------
aten.bmm                       | 10                   | 1024                 | 1024                 | 1024                 | 1
aten.mm                        | -                    | 1024                 | 1024                 | 1024                 | 2
----------------------------------------------------------------------------------------------------------------------------------
```

## Changes Made

### 1. Enhanced Parsing Logic in compile_fx.py
- Detects batched operations by checking whether the operation name ends with `'bmm'` or `'baddbmm'`
- For batched operations: takes the last 4 underscore-separated parts of the counter key as `batch, m, n, k`
- For non-batched operations: takes the last 3 parts as `m, n, k`; parsing from the end keeps op names that themselves contain underscores (e.g. `aten._int_mm`) intact
- **Dedicated "B" column**: adds a separate column for the batch size instead of embedding it in the operation name
- Shows the batch size for batched operations and "-" for non-batched ones (see the sketch after this list)
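
A minimal sketch of this parsing, assuming counter keys of the form listed under "2. Updated All MM Operations for Consistency" below; `parse_mm_counter_key` is an illustrative name, not the exact helper in `compile_fx.py`:

```python
def parse_mm_counter_key(key: str) -> tuple[str, str, str, str, str]:
    """Split e.g. 'aten.bmm_10_1024_1024_1024' into (name, B, M, N, K)."""
    parts = key.split("_")
    # Batched ops carry 4 trailing dimensions; everything before them is the name.
    if "_".join(parts[:-4]).endswith(("bmm", "baddbmm")):
        name = "_".join(parts[:-4])
        b, m, n, k = parts[-4:]
    else:
        # Non-batched ops carry 3 trailing dimensions; show "-" in the B column.
        name = "_".join(parts[:-3])
        b = "-"
        m, n, k = parts[-3:]
    return name, b, m, n, k
```

Taking the dimensions from the end of the split is what tolerates underscores inside op names such as `aten._int_mm`.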

### 2. Updated All MM Operations for Consistency
- **bmm.py**:
  - Extracts the batch size from `mat1.get_size()[0]` in both `tuned_bmm` and `tuned_baddbmm`
  - Uses counter keys that encode the dimensions positionally: `aten.bmm_{batch_size}_{m}_{n}_{k}`
  - Extends the log messages to include the batch size (a hedged sketch follows this list)
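
For illustration, a minimal sketch of what the counter update in `tuned_bmm` could look like. `record_bmm_counter` is a hypothetical helper, and the `aten_mm_info` bucket name is an assumption modeled on the existing mm counters; only `counters` (from `torch._dynamo.utils`) and `mat1.get_size()` come from the description above:

```python
import logging

from torch._dynamo.utils import counters  # a defaultdict(Counter)

log = logging.getLogger(__name__)

def record_bmm_counter(mat1, mat2) -> None:
    # For bmm, mat1 has shape (B, M, K) and mat2 has shape (B, K, N).
    b, m, k = mat1.get_size()
    n = mat2.get_size()[-1]
    # Counter key encodes the batch first: aten.bmm_{batch_size}_{m}_{n}_{k}.
    counters["aten_mm_info"][f"aten.bmm_{b}_{m}_{n}_{k}"] += 1
    log.info("Tuned aten.bmm: batch=%s, m=%s, n=%s, k=%s", b, m, n, k)
```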

- **mm.py**: Updated counter keys for consistency (the demo after this list exercises each format):
  - `aten.mm_{m}_{n}_{k}` (no batch dimension)
  - `aten.addmm_{m}_{n}_{k}` (no batch dimension)
  - `aten._int_mm_{m}_{n}_{k}` (no batch dimension)
  - `aten._scaled_mm.default_{m}_{n}_{k}` (no batch dimension)
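
Feeding each of these key formats through the parsing sketch above round-trips the name and dimensions (expected output written as comments, not captured tool output):

```python
for key in [
    "aten.bmm_10_1024_1024_1024",
    "aten.mm_1024_1024_1024",
    "aten._int_mm_1024_1024_1024",
    "aten._scaled_mm.default_1024_1024_1024",
]:
    print(parse_mm_counter_key(key))
# ('aten.bmm', '10', '1024', '1024', '1024')
# ('aten.mm', '-', '1024', '1024', '1024')
# ('aten._int_mm', '-', '1024', '1024', '1024')
# ('aten._scaled_mm.default', '-', '1024', '1024', '1024')
```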

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155544
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng
2025-06-11 16:57:40 +00:00
| Name | Last commit | Last updated |
|------|-------------|--------------|
| `autoheuristic` | [BE][Ez]: Optimize unnecessary lambda with operator (#154722) | 2025-05-30 23:47:10 +00:00 |
| `codegen` | [Cutlass] Include fp8 headers in aoti cpp wrapper (#155173) | 2025-06-11 01:21:16 +00:00 |
| `compile_worker` | torch.compile: Supress stdout / stderr output from subprocesses when local (#153837) | 2025-05-22 05:49:43 +00:00 |
| `fx_passes` | [Graph Partition] move cpu scalar tensor to gpu (#154464) | 2025-06-11 10:22:45 +00:00 |
| `kernel` | [inductor] Improve GEMM logging to display batch size for batched operations (#155544) | 2025-06-11 16:57:40 +00:00 |
| `package` | [export] Refactor pt2 save/load (#152495) | 2025-06-04 06:04:29 +00:00 |
| `runtime` | Revert "[flex attention][triton pin] triton_helpers shim for TMA apis (#154858)" (#155640) | 2025-06-11 07:37:47 +00:00 |
| `__autotune_main__.py` | Improve subproc autotuning implementation (#149700) | 2025-03-28 01:06:39 +00:00 |
| `__init__.py` | Add optional device index to AOTIModelPackageLoader (#152093) | 2025-05-04 11:40:12 +00:00 |
| `analyze_preserves_zero_mask.py` | Revert two recent prologue prs (#151013) | 2025-04-10 23:48:41 +00:00 |
| `aoti_eager.py` | | |
| `async_compile.py` | Redo D75092426: [internal] Expose additional metadata to compilation callbacks (#155063) | 2025-06-05 23:40:31 +00:00 |
| `autotune_process.py` | Add torch.profile benchmarking function to feedback_fns (#153579) | 2025-05-29 21:43:45 +00:00 |
| `bounds.py` | | |
| `choices.py` | | |
| `codecache.py` | inductor codecache: include private inductor configs in cache key (#153672) | 2025-06-11 01:33:24 +00:00 |
| `comm_analysis.py` | | |
| `comm_lowering.py` | Revert "[inductor] Add typing to _inductor/ir.py (#149958)" | 2025-06-06 15:19:16 +00:00 |
| `comms.py` | [PT2][comms] put visualize_overlap in a try-except block (#155222) | 2025-06-05 23:39:48 +00:00 |
| `compile_fx_async.py` | | |
| `compile_fx_ext.py` | Re-enable FakeTensor caching for SymInts (#152662) | 2025-05-30 17:23:36 +00:00 |
| `compile_fx_subproc.py` | | |
| `compile_fx.py` | [inductor] Improve GEMM logging to display batch size for batched operations (#155544) | 2025-06-11 16:57:40 +00:00 |
| `compiler_bisector.py` | | |
| `config.py` | inductor codecache: include private inductor configs in cache key (#153672) | 2025-06-11 01:33:24 +00:00 |
| `constant_folding.py` | Add dont constant fold flag (#154945) | 2025-06-10 14:52:26 +00:00 |
| `cpp_builder.py` | [AOTI] Fix embed_kernel_binary error when max_autotune is ON (#155569) | 2025-06-11 12:27:36 +00:00 |
| `cpu_vec_isa.py` | Allow to set custom PYTHONPATH for torch.inductor (#152832) | 2025-05-15 06:35:41 +00:00 |
| `cudagraph_trees.py` | Redo D75092426: [internal] Expose additional metadata to compilation callbacks (#155063) | 2025-06-05 23:40:31 +00:00 |
| `cudagraph_utils.py` | [CUDAGraph] support meta tensor (#150478) | 2025-04-02 07:21:50 +00:00 |
| `custom_graph_pass.py` | Revert "Custom FX pass for inductor's backend registration (#154841)" | 2025-06-09 16:56:45 +00:00 |
| `debug.py` | Rename the provenance tracing artifact name for kernel <-> post_grad nodes mapping (#154046) | 2025-05-22 19:20:56 +00:00 |
| `decomposition.py` | Fix clamp type promotion in inductor decomposition (#154471) | 2025-05-28 23:24:25 +00:00 |
| `dependencies.py` | Revert "[inductor] Add typing to _inductor/ir.py (#149958)" | 2025-06-06 15:19:16 +00:00 |
| `dtype_propagation.py` | Remove libdevice ops in inductor (#151562) | 2025-04-17 22:18:00 +00:00 |
| `exc.py` | | |
| `extern_node_serializer.py` | Back out "[AOTI] Always use oss schema for ExternKernelNodes serialization" (#151026) | 2025-04-10 22:36:35 +00:00 |
| `freezing_utils.py` | | |
| `freezing.py` | [cudagraphs] Fix issue in collecting static_input_idxs (#152287) | 2025-04-30 03:24:05 +00:00 |
| `fuzzer.py` | [AOTI][reland] Add an option to specify custom op C shim (#153968) | 2025-05-21 15:57:57 +00:00 |
| `fx_utils.py` | Revert "Inductor logging + analysis of torch.profile (#149697)" | 2025-06-10 15:38:40 +00:00 |
| `graph.py` | Revert "Inductor logging + analysis of torch.profile (#149697)" | 2025-06-10 15:38:40 +00:00 |
| `hooks.py` | | |
| `index_propagation.py` | | |
| `inductor_prims.py` | [inductor] lowering for fractional_max_pool3d (#148630) | 2025-05-22 16:06:29 +00:00 |
| `ir.py` | [inductor] use int64 for large index (#154575) | 2025-06-10 18:30:43 +00:00 |
| `jagged_lowerings.py` | Revert "[inductor] Add typing to _inductor/ir.py (#149958)" | 2025-06-06 15:19:16 +00:00 |
| `loop_body.py` | [Tiling rewrite pt1] Normalize reads and writes to common iter space (#153723) | 2025-06-03 14:04:34 +00:00 |
| `lowering.py` | [invoke_subgraph] Use eager input vals to constrain input strides (#155291) | 2025-06-10 04:06:09 +00:00 |
| `memory.py` | [PT2][memory] correct wait tensor output size (#153569) | 2025-06-04 17:49:25 +00:00 |
| `metrics.py` | Replace runtime type parameterization (#155221) | 2025-06-05 21:43:54 +00:00 |
| `mkldnn_ir.py` | Revert "[Inductor] Improve typing, and prepare for ABI-compatible AOTI C-shim dispatching (#154371)" | 2025-06-08 17:37:29 +00:00 |
| `mkldnn_lowerings.py` | Revert "[inductor] Add typing to _inductor/ir.py (#149958)" | 2025-06-06 15:19:16 +00:00 |
| `mock_cache.py` | | |
| `ops_handler.py` | Revert "[inductor] Add typing to _inductor/ir.py (#149958)" | 2025-06-06 15:19:16 +00:00 |
| `optimize_indexing.py` | | |
| `output_code.py` | Reflect back mutation if we clone misaligned tensors (#154442) | 2025-05-29 13:36:48 +00:00 |
| `pattern_matcher.py` | Replace runtime type parameterization (#155221) | 2025-06-05 21:43:54 +00:00 |
| `quantized_lowerings.py` | [Inductor]Cleanup autotune_fallback_to_aten post-deprecation (#154331) | 2025-05-29 20:29:58 +00:00 |
| `remote_cache.py` | [Indcutor Remote Cache] Raise an exception if redis module is required but not available (#151779) | 2025-04-26 11:21:54 +00:00 |
| `scheduler.py` | Revert "Inductor logging + analysis of torch.profile (#149697)" | 2025-06-10 15:38:40 +00:00 |
| `script.ld` | | |
| `select_algorithm.py` | [logs] Change autotune data into separate items (#155525) | 2025-06-10 21:47:07 +00:00 |
| `sizevars.py` | Revert "[inductor] Add typing to _inductor/ir.py (#149958)" | 2025-06-06 15:19:16 +00:00 |
| `standalone_compile.py` | Add logging for guard miss failure (#153125) | 2025-05-09 16:51:04 +00:00 |
| `subgraph_lowering.py` | | |
| `template_heuristics.py` | [Inductor] Add Additional Configs for persistent+TMA version of Triton mm and addmm (#150587) | 2025-04-23 18:21:35 +00:00 |
| `test_case.py` | | |
| `test_operators.py` | [CI] Fix GPUTests.test_scheduler_vertical_fusion1 (#151166) | 2025-04-13 00:41:51 +00:00 |
| `tiling_utils.py` | Turn on new tiling by default (#154768) | 2025-06-06 21:19:35 +00:00 |
| `triton_bundler.py` | Keep raw cubin file around in case it gets deleted underneath us (#153064) | 2025-05-08 14:29:19 +00:00 |
| `utils.py` | [inductor] use int64 for large index (#154575) | 2025-06-10 18:30:43 +00:00 |
| `virtualized.py` | | |
| `wrapper_benchmark.py` | Revert "Inductor logging + analysis of torch.profile (#149697)" | 2025-06-10 15:38:40 +00:00 |