[inductor] Improve GEMM logging to display batch size for batched operations (#155544)

Improves the GEMM overview logging in PyTorch Inductor to properly display batch size information for batched matrix operations like `torch.bmm` and `torch.baddbmm`.

**Fixes #155307**

## Problem

The current GEMM overview logging omits the batch size for `torch.bmm`. Repro:
```python
# Repro
import os
os.environ["TORCH_LOGS"] = "inductor"
import torch

M, N, K = 1024, 1024, 1024
dtype = torch.bfloat16
A = torch.randn(10, M, K, device="cuda", dtype=dtype)
B = torch.randn(10, K, N, device="cuda", dtype=dtype)

compiled_model = torch.compile(torch.bmm, fullgraph=True)
_ = compiled_model(A, B)
```

**Before:**
```
Name                 | M                    | N                    | K                    | Count
----------------------------------------------------------------------------------------------------
aten.bmm             | 1024                 | 1024                 | 1024                 | 1
----------------------------------------------------------------------------------------------------
```

The batch size (10) is missing from the logs, making it unclear what the actual operation dimensions were.

## Solution

The overview table gains a dedicated `B` column that reports the batch size for batched operations; non-batched operations, such as the `aten.mm` rows here, show `-`.

**After:**
```
Name                           | B                    | M                    | N                    | K                    | Count
----------------------------------------------------------------------------------------------------------------------------------
aten.bmm                      | 10                   | 1024                 | 1024                 | 1024                 | 1
aten.mm                       | -                    | 1024                 | 1024                 | 1024                 | 2
----------------------------------------------------------------------------------------------------------------------------------
```

## Changes Made

### 1. Enhanced Parsing Logic in compile_fx.py
- Detects batched operations by checking whether the operation name ends with `'bmm'` or `'baddbmm'`
- For batched operations: takes the last 4 parts of the counter key as `batch, m, n, k`
- For non-batched operations: takes the last 3 parts as `m, n, k`
- **Dedicated "B" column**: the batch size gets its own column instead of being embedded in the operation name
- Shows the batch size for batched operations and "-" for non-batched operations (see the sketch after this list)
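
For illustration, here is a minimal standalone sketch of that parsing scheme. The `demo_counters` data is made up for the example; the real keys live in `counters["aten_mm_info"]`:

```python
demo_counters = {
    "aten.bmm_10_1024_1024_1024": 1,  # batched: B=10, M=N=K=1024
    "aten.mm_1024_1024_1024": 2,      # non-batched
}

for key, count in demo_counters.items():
    parts = key.split("_")
    # Tentatively strip the last 4 fields; for batched keys this leaves "aten.bmm" etc.
    name = "_".join(parts[:-4]) if len(parts) >= 4 else "_".join(parts[:-3])
    if name.endswith(("bmm", "baddbmm")) and len(parts) >= 4:
        batch, m, n, k = parts[-4:]   # batched key: B, M, N, K
    else:
        name = "_".join(parts[:-3])   # non-batched key: only M, N, K follow the name
        batch = "-"
        m, n, k = parts[-3:]
    print(f"{name:<30} | {batch:<20} | {m:<20} | {n:<20} | {k:<20} | {count}")
```

Running this prints rows matching the "After" table above.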

### 2. Updated All MM Operations for Consistency
- **bmm.py**:
  - Extracts the batch size from `mat1.get_size()[0]` in both `tuned_bmm` and `tuned_baddbmm`
  - Uses counter keys that encode the batch size positionally: `aten.bmm_{batch_size}_{m}_{n}_{k}`
  - Includes the batch size in the tuning log messages

- **mm.py**: Updated counter keys for consistency; none of these carry a batch dimension (see the sketch below):
  - `aten.mm_{m}_{n}_{k}`
  - `aten.addmm_{m}_{n}_{k}`
  - `aten._int_mm_{m}_{n}_{k}`
  - `aten._scaled_mm.default_{m}_{n}_{k}`
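
Taken together, the key convention boils down to the following minimal sketch. `record_mm` is a hypothetical helper for illustration only; in Inductor the actual store is `counters["aten_mm_info"]` from `torch._dynamo.utils`:

```python
from collections import Counter

counters = {"aten_mm_info": Counter()}  # stand-in for torch._dynamo.utils.counters

def record_mm(name, m, n, k, batch_size=None):
    # Hypothetical helper: bmm/baddbmm keys carry B before M/N/K; plain mms omit it.
    dims = [m, n, k] if batch_size is None else [batch_size, m, n, k]
    counters["aten_mm_info"]["_".join([name, *map(str, dims)])] += 1

record_mm("aten.bmm", 1024, 1024, 1024, batch_size=10)
record_mm("aten.mm", 1024, 1024, 1024)
print(counters["aten_mm_info"])
# Counter({'aten.bmm_10_1024_1024_1024': 1, 'aten.mm_1024_1024_1024': 1})
```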

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155544
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng
Commit `59eb61b2d1` (parent `7b7cd56f5e`), authored by penknife6153 on 2025-06-11 16:57:40 +00:00 and committed by PyTorch MergeBot. 2 changed files with 34 additions and 12 deletions.

**compile_fx.py**

```diff
@@ -999,19 +999,37 @@ def _compile_fx_inner(
     if log.isEnabledFor(logging.INFO):
         mm_table_data = []
         for key, value in counters["aten_mm_info"].items():
-            m, n, k = key.split("_")[-3:]
-            name = "_".join(key.split("_")[:-3])
-            mm_table_data.append([name, m, n, k, value])
+            parts = key.split("_")
+            if len(parts) < 3:
+                # Unexpected format, show as-is
+                mm_table_data.append([key, "-", "?", "?", "?", value])
+                continue
+            # Determine if this is a batched operation by checking the operation name
+            name = "_".join(parts[:-4]) if len(parts) >= 4 else "_".join(parts[:-3])
+            is_batched = name.endswith(("bmm", "baddbmm"))
+            if is_batched and len(parts) >= 4:
+                # Batched operation: last 4 parts are batch, m, n, k
+                batch, m, n, k = parts[-4:]
+                name = "_".join(parts[:-4])
+                mm_table_data.append([name, batch, m, n, k, value])
+            else:
+                # Non-batched operation: last 3 parts are m, n, k
+                m, n, k = parts[-3:]
+                name = "_".join(parts[:-3])
+                mm_table_data.append([name, "-", m, n, k, value])
         log.info("Overview info of inductor aten mms: ")
         log.info(
-            "{:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(  # noqa: G001
-                "Name", "M", "N", "K", "Count"
+            "{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(  # noqa: G001
+                "Name", "B", "M", "N", "K", "Count"
             )
         )
-        log.info("-" * 100)
+        log.info("-" * 130)
         for row in mm_table_data:
-            log.info("{:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row))  # noqa: G001
-        log.info("-" * 100)
+            log.info("{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row))  # noqa: G001
+        log.info("-" * 130)
     # Not strictly necessary, but good to clean up straggling futures
     # that are unused to reclaim memory.
```

**bmm.py**

```diff
@@ -179,9 +179,11 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
     )
     # below is for getting an overview logging info of inductor mms
-    counters["aten_mm_info"][f"aten.bmm_{m}_{n}_{k}"] += 1
+    batch_size = mat1.get_size()[0]  # Extract batch dimension
+    counters["aten_mm_info"][f"aten.bmm_{batch_size}_{m}_{n}_{k}"] += 1
     log.info(
-        "Tuned aten.bmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
+        "Tuned aten.bmm: batch=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
+        batch_size,
         m,
         n,
         k,
@@ -241,9 +243,11 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
     m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
     # below is for getting an overview logging info of inductor mms
-    counters["aten_mm_info"][f"aten.baddbmm_{m}_{n}_{k}"] += 1
+    batch_size = mat1.get_size()[0]
+    counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1
     log.info(
-        "Tuned aten.baddbmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, inp=%s, output_layout=%s",
+        "Tuned aten.baddbmm: batch_size=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, inp=%s, output_layout=%s",
+        batch_size,
         m,
         n,
         k,
```