# pytorch/torch
penknife6153 59eb61b2d1 [inductor] Improve GEMM logging to display batch size for batched operations (#155544)
Improves the GEMM overview logging in PyTorch Inductor to properly display batch size information for batched matrix operations like `torch.bmm` and `torch.baddbmm`.

**Fixes #155307**

## Problem

The current GEMM logging omits the batch dimension for `torch.bmm`, as the following repro shows:
```python
# Repro
import os
os.environ["TORCH_LOGS"] = "inductor"
import torch

M, N, K = 1024, 1024, 1024
dtype = torch.bfloat16
A = torch.randn(10, M, K, device="cuda", dtype=dtype)
B = torch.randn(10, K, N, device="cuda", dtype=dtype)

compiled_model = torch.compile(torch.bmm, fullgraph=True)
_ = compiled_model(A, B)
```

**Before:**
```
Name                 | M                    | N                    | K                    | Count
----------------------------------------------------------------------------------------------------
aten.bmm             | 1024                 | 1024                 | 1024                 | 1
----------------------------------------------------------------------------------------------------
```

The batch size (10) is missing from the logs, making it unclear what the actual operation dimensions were.

## Solution

**After:**
```
Name                          | B                    | M                    | N                    | K                    | Count
---------------------------------------------------------------------------------------------------------------------------------
aten.bmm                      | 10                   | 1024                 | 1024                 | 1024                 | 1
aten.mm                       | -                    | 1024                 | 1024                 | 1024                 | 2
```

## Changes Made

### 1. Enhanced Parsing Logic in compile_fx.py
- Detects batched operations by checking whether the operation name ends with `'bmm'` or `'baddbmm'`
- For batched operations: takes the last 4 underscore-separated parts of the counter key as `batch, m, n, k`
- For non-batched operations: takes the last 3 parts as `m, n, k`
- **Dedicated "B" column**: adds a separate column for the batch size instead of embedding it in the operation name
- Shows the batch size for batched operations and "-" for non-batched operations (see the sketch after this list)
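
To make the splitting rule concrete, here is a minimal sketch of the parsing scheme described above. `parse_gemm_counter_key` is a hypothetical helper, not the actual compile_fx.py code; it assumes counter keys shaped like the ones listed under item 2 below.

```python
# Hypothetical sketch of the parsing rule, not the actual compile_fx.py code.
# Assumes keys like "aten.bmm_10_1024_1024_1024" or "aten.mm_1024_1024_1024".
def parse_gemm_counter_key(key: str) -> tuple[str, str, str, str, str]:
    parts = key.split("_")
    # Batched ops ("bmm", "baddbmm") carry 4 trailing dims; others carry 3.
    if "_".join(parts[:-4]).endswith(("bmm", "baddbmm")):
        name = "_".join(parts[:-4])
        b, m, n, k = parts[-4:]          # batch, m, n, k
    else:
        name = "_".join(parts[:-3])
        b = "-"                          # non-batched: "-" fills the B column
        m, n, k = parts[-3:]
    return name, b, m, n, k

assert parse_gemm_counter_key("aten.bmm_10_1024_1024_1024") == (
    "aten.bmm", "10", "1024", "1024", "1024"
)
assert parse_gemm_counter_key("aten._scaled_mm.default_1024_1024_1024") == (
    "aten._scaled_mm.default", "-", "1024", "1024", "1024"
)
```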

### 2. Updated All MM Operations for Consistency
- **bmm.py**:
  - Extracts the batch size from `mat1.get_size()[0]` for both `tuned_bmm` and `tuned_baddbmm`
  - Uses positional counter keys: `aten.bmm_{batch_size}_{m}_{n}_{k}` (see the sketch after this section)
  - Extends the log messages to include the batch size

- **mm.py**: Updated counter keys for consistency:
  - `aten.mm_{m}_{n}_{k}` (no batch dimension)
  - `aten.addmm_{m}_{n}_{k}` (no batch dimension)
  - `aten._int_mm_{m}_{n}_{k}` (no batch dimension)
  - `aten._scaled_mm.default_{m}_{n}_{k}` (no batch dimension)
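
As a rough illustration of the emission side, here is a hedged sketch under the key scheme above. The `record_*` helpers and the plain `Counter` are assumptions for the example, not the actual `tuned_bmm`/`tuned_mm` lowerings, where the batch size comes from `mat1.get_size()[0]`.

```python
from collections import Counter

# Hypothetical counters standing in for Inductor's real overview counters.
gemm_counters: Counter[str] = Counter()

def record_bmm(mat1_size, mat2_size) -> None:
    """Record a batched matmul; mat1 is (B, M, K), mat2 is (B, K, N)."""
    b, m, k = mat1_size              # batch size taken from mat1's first dim
    n = mat2_size[-1]
    gemm_counters[f"aten.bmm_{b}_{m}_{n}_{k}"] += 1

def record_mm(mat1_size, mat2_size) -> None:
    """Record a plain matmul; mat1 is (M, K), mat2 is (K, N)."""
    m, k = mat1_size
    n = mat2_size[-1]
    gemm_counters[f"aten.mm_{m}_{n}_{k}"] += 1

record_bmm((10, 1024, 1024), (10, 1024, 1024))
record_mm((1024, 1024), (1024, 1024))
print(gemm_counters)
# e.g. Counter({'aten.bmm_10_1024_1024_1024': 1, 'aten.mm_1024_1024_1024': 1})
```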

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155544
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng
2025-06-11 16:57:40 +00:00
| Name | Last commit | Date |
| --- | --- | --- |
| _awaits | | |
| _C | [BE]: Backport runtime_checkable perf improvements/behavior from 3.12 (#155130) | 2025-06-06 13:28:05 +00:00 |
| _C_flatbuffer | | |
| _custom_op | | |
| _decomp | Fix clamp type promotion in inductor decomposition (#154471) | 2025-05-28 23:24:25 +00:00 |
| _dispatch | | |
| _dynamo | Replace frame_traced_fn hook with get_traced_code() util (#155249) | 2025-06-10 22:40:58 +00:00 |
| _export | [Export] Add math module for deserialization (#154643) | 2025-05-30 17:29:25 +00:00 |
| _functorch | [invoke_subgraph] Use eager input vals to constrain input strides (#155291) | 2025-06-10 04:06:09 +00:00 |
| _higher_order_ops | [user triton] mutation analysis for on-device TMA (#155380) | 2025-06-10 00:07:18 +00:00 |
| _inductor | [inductor] Improve GEMM logging to display batch size for batched operations (#155544) | 2025-06-11 16:57:40 +00:00 |
| _lazy | | |
| _library | Custom Op handle 1-element tuples (#155447) | 2025-06-11 03:43:40 +00:00 |
| _logging | [invoke_subgraph] Add logging (#155284) | 2025-06-07 11:31:53 +00:00 |
| _numpy | fix numpy compatibility for 2d small list indices (#154806) | 2025-06-04 01:58:52 +00:00 |
| _prims | [dynamic shapes] unbacked safe unsqueeze (#154087) | 2025-05-30 01:41:57 +00:00 |
| _prims_common | [export] support linear & layer_norm unbacked (#155260) | 2025-06-11 16:47:34 +00:00 |
| _refs | [export] support linear & layer_norm unbacked (#155260) | 2025-06-11 16:47:34 +00:00 |
| _strobelight | | |
| _subclasses | [export] support linear & layer_norm unbacked (#155260) | 2025-06-11 16:47:34 +00:00 |
| _vendor | | |
| accelerator | Add torch.accelerator.device_index as accelerator's device switch context (#148864) | 2025-04-25 09:45:25 +00:00 |
| amp | | |
| ao | Fix incorrect get_default_qat_qconfig in prepare_qat_fx docs. (#155100) | 2025-06-04 18:51:40 +00:00 |
| autograd | set_grad_enabled add str and repr for prints (#155681) | 2025-06-11 16:01:03 +00:00 |
| backends | Revert "refine fp32 precision api (#125888)" | 2025-05-11 00:35:46 +00:00 |
| compiler | Combine sticky pgo key with job id (#154863) | 2025-06-03 07:58:38 +00:00 |
| contrib | | |
| cpu | [device_mesh] improve device selection logic (#150897) | 2025-05-14 06:29:16 +00:00 |
| csrc | Replace TORCH_INTERNAL_ASSERT with TORCH_CHECK in set_history (#155453) | 2025-06-11 03:46:48 +00:00 |
| cuda | [Memory Snapshot] Add Flag to Toggle Global and Local Callbacks for Annotations (#154932) | 2025-06-04 23:15:19 +00:00 |
| distributed | Changes to HFStorageWriter to support saving shards of tensors (#154742) (#155566) | 2025-06-10 23:37:47 +00:00 |
| distributions | Type hints for distributions/utils (#154712) | 2025-05-30 15:50:31 +00:00 |
| export | [Export] Fix some typos in docstring (#155485) | 2025-06-11 16:44:38 +00:00 |
| fft | | |
| func | | |
| futures | Render Example: and not Example:: in docs (#153978) | 2025-05-21 01:03:26 +00:00 |
| fx | Include c++ stack traces when we hit constraint violation (#155603) | 2025-06-11 05:00:36 +00:00 |
| jit | Fix broken URLs (#152237) | 2025-04-27 09:56:42 +00:00 |
| legacy | | |
| lib | Revert "Use 3.27 as the minimum CMake version (#153153)" | 2025-05-31 02:14:24 +00:00 |
| linalg | Fix for ambiguity in linalg.norm()'s ord argument of +2 & -2 (#155148) | 2025-06-04 21:15:20 +00:00 |
| masked | Add torch.Tensor._make_wrapper_subclass to torch/_C/__init__.pyi (#154022) | 2025-05-27 14:10:00 +00:00 |
| monitor | | |
| mps | | |
| mtia | Add getDeviceProperties api to torch mtia device (#153577) | 2025-05-27 11:55:58 +00:00 |
| multiprocessing | | |
| nativert | [nativert] move execution planner to torch (#155374) | 2025-06-10 22:36:06 +00:00 |
| nested | Add torch.Tensor._make_wrapper_subclass to torch/_C/__init__.pyi (#154022) | 2025-05-27 14:10:00 +00:00 |
| nn | Fix docs build (#155129) | 2025-06-09 22:25:20 +00:00 |
| onnx | [ONNX] Set the name of the producing node using the value name (#155413) | 2025-06-09 13:03:58 +00:00 |
| optim | Fix lr_scheduler unexpectedly calls step() when init argument last_epoch is larger than -1 (#149312) | 2025-05-22 08:42:37 +00:00 |
| package | [BE]: Enable ruff YTT linter for Python version checks (#153547) | 2025-05-14 21:09:16 +00:00 |
| profiler | [Profiler] Induce Inductor Import before Profiling (#155243) | 2025-06-07 23:58:50 +00:00 |
| quantization | | |
| signal | | |
| sparse | fix numpy compatibility for 2d small list indices (#154806) | 2025-06-04 01:58:52 +00:00 |
| special | Add doc for missing functions for torch.special module (#155074) | 2025-06-09 22:28:26 +00:00 |
| standalone/macros | Move c10/macros/Export.h to torch/standalone (#154850) | 2025-06-03 06:18:59 +00:00 |
| testing | [triton pin][tests] refactor test_triton_kernel.py tests to test new & old API (#155510) | 2025-06-11 13:52:15 +00:00 |
| utils | Revert "Add Intel GPU info collection to the collect env script (#137846)" | 2025-06-11 15:18:47 +00:00 |
| xpu | Correct torch.xpu.is_bf16_supported return False if no XPU detected (#152317) | 2025-05-06 10:03:17 +00:00 |
| __config__.py | | |
| __future__.py | | |
| __init__.py | Replace frame_traced_fn hook with get_traced_code() util (#155249) | 2025-06-10 22:40:58 +00:00 |
| _appdirs.py | Fix broken URLs (#152237) | 2025-04-27 09:56:42 +00:00 |
| _classes.py | | |
| _compile.py | [precompile] Ensure @disable()-ed function won't trigger recompile from precompile bytecode. (#155363) | 2025-06-10 16:13:38 +00:00 |
| _custom_ops.py | Render Example: and not Example:: in docs (#153978) | 2025-05-21 01:03:26 +00:00 |
| _deploy.py | | |
| _environment.py | | |
| _guards.py | Replace frame_traced_fn hook with get_traced_code() util (#155249) | 2025-06-10 22:40:58 +00:00 |
| _jit_internal.py | BE: Type previously untyped decorators (#154515) | 2025-05-29 00:36:34 +00:00 |
| _linalg_utils.py | | |
| _lobpcg.py | Fixed rerr computation in lobpcg (#152789) | 2025-05-08 12:22:31 +00:00 |
| _lowrank.py | | |
| _meta_registrations.py | Revert "Update auto-tuning support for _scaled_grouped_mm (#150944)" | 2025-06-09 23:12:56 +00:00 |
| _namedtensor_internals.py | | |
| _ops.py | Revert "Improve torch.ops typing (#153558)" | 2025-05-19 23:32:36 +00:00 |
| _python_dispatcher.py | | |
| _size_docs.py | Render Example: and not Example:: in docs (#153978) | 2025-05-21 01:03:26 +00:00 |
| _sources.py | | |
| _storage_docs.py | Fix docstring for torch.UntypedStorage.from_file (#155067) | 2025-06-05 14:30:49 +00:00 |
| _streambase.py | | |
| _tensor_docs.py | [docs] Add docstring indicating UB for converting inf to int (#154781) | 2025-06-10 14:04:50 +00:00 |
| _tensor_str.py | | |
| _tensor.py | Avoid triggering ignored requires_grad warning in our code (#152686) | 2025-05-05 23:56:40 +00:00 |
| _thread_safe_fork.py | | |
| _torch_docs.py | Render Example: and not Example:: in docs (#153978) | 2025-05-21 01:03:26 +00:00 |
| _utils_internal.py | Revert "Inductor logging + analysis of torch.profile (#149697)" | 2025-06-10 15:38:40 +00:00 |
| _utils.py | User-controlled sparse tensor validation when loading data from external storage (#154610) | 2025-06-02 10:17:07 +00:00 |
| _VF.py | | |
| _vmap_internals.py | Fix broken URLs (#152237) | 2025-04-27 09:56:42 +00:00 |
| _weights_only_unpickler.py | | |
| CMakeLists.txt | Move c10/macros/Export.h to torch/standalone (#154850) | 2025-06-03 06:18:59 +00:00 |
| custom_class_detail.h | | |
| custom_class.h | | |
| extension.h | | |
| functional.py | | |
| header_only_apis.txt | Move c10/macros/Export.h to torch/standalone (#154850) | 2025-06-03 06:18:59 +00:00 |
| hub.py | | |
| library.h | | |
| library.py | Render Example: and not Example:: in docs (#153978) | 2025-05-21 01:03:26 +00:00 |
| overrides.py | | |
| py.typed | | |
| quasirandom.py | | |
| random.py | | |
| return_types.py | | |
| script.h | | |
| serialization.py | Update serialization docs (#153631) | 2025-05-19 20:22:07 +00:00 |
| storage.py | | |
| torch_version.py | | |
| types.py | | |
| version.py.tpl | | |