mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] Improve GEMM logging to display batch size for batched operations (#155544)
Improves the GEMM overview logging in PyTorch Inductor to properly display batch size information for batched matrix operations like `torch.bmm` and `torch.baddbmm`.
**Fixes #155307**
## Problem
The current GEMM logging for `torch.bmm` shows:
```python
# Repro
import os
os.environ["TORCH_LOGS"] = "inductor"
import torch
M, N, K = 1024, 1024, 1024
dtype = torch.bfloat16
A = torch.randn(10, M, K, device="cuda", dtype=dtype)
B = torch.randn(10, K, N, device="cuda", dtype=dtype)
compiled_model = torch.compile(torch.bmm, fullgraph=True)
_ = compiled_model(A, B)
```
**Before:**
```
Name | M | N | K | Count
----------------------------------------------------------------------------------------------------
aten.bmm | 1024 | 1024 | 1024 | 1
----------------------------------------------------------------------------------------------------
```
The batch size (10) is missing from the logs, making it unclear what the actual operation dimensions were.
## Solution
**After:**
```
Name | B | M | N | K | Count
----------------------------------------------------------------------------------------------------------------------------------
aten.bmm | 10 | 1024 | 1024 | 1024 | 1
aten.mm | - | 1024 | 1024 | 1024 | 2
----------------------------------------------------------------------------------------------------------------------------------
```
## Changes Made
### 1. Enhanced Parsing Logic in compile_fx.py
- Detects batched operations by checking whether the operation name ends with `'bmm'` or `'baddbmm'`
- For batched operations: takes the last 4 parts of the counter key as `batch, m, n, k`
- For non-batched operations: takes the last 3 parts as `m, n, k`
- **Dedicated "B" column**: adds a separate column for the batch size instead of embedding it in the operation name
- Shows the batch size for batched operations and "-" for non-batched operations
### 2. Updated All MM Operations for Consistency
- **bmm.py**:
- Extract batch size from `mat1.get_size()[0]` for both `tuned_bmm` and `tuned_baddbmm`
- Use positional counter keys: `aten.bmm_{batch_size}_{m}_{n}_{k}`
- Enhanced log messages to include batch size information
- **mm.py**: Updated counter keys for consistency:
- `aten.mm_{m}_{n}_{k}` (no batch dimension)
- `aten.addmm_{m}_{n}_{k}` (no batch dimension)
- `aten._int_mm_{m}_{n}_{k}` (no batch dimension)
- `aten._scaled_mm.default_{m}_{n}_{k}` (no batch dimension)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155544
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng
This commit is contained in:
parent
7b7cd56f5e
commit
59eb61b2d1
|
|
@ -999,19 +999,37 @@ def _compile_fx_inner(
|
|||
if log.isEnabledFor(logging.INFO):
|
||||
mm_table_data = []
|
||||
for key, value in counters["aten_mm_info"].items():
|
||||
m, n, k = key.split("_")[-3:]
|
||||
name = "_".join(key.split("_")[:-3])
|
||||
mm_table_data.append([name, m, n, k, value])
|
||||
parts = key.split("_")
|
||||
if len(parts) < 3:
|
||||
# Unexpected format, show as-is
|
||||
mm_table_data.append([key, "-", "?", "?", "?", value])
|
||||
continue
|
||||
|
||||
# Determine if this is a batched operation by checking the operation name
|
||||
name = "_".join(parts[:-4]) if len(parts) >= 4 else "_".join(parts[:-3])
|
||||
is_batched = name.endswith(("bmm", "baddbmm"))
|
||||
|
||||
if is_batched and len(parts) >= 4:
|
||||
# Batched operation: last 4 parts are batch, m, n, k
|
||||
batch, m, n, k = parts[-4:]
|
||||
name = "_".join(parts[:-4])
|
||||
mm_table_data.append([name, batch, m, n, k, value])
|
||||
else:
|
||||
# Non-batched operation: last 3 parts are m, n, k
|
||||
m, n, k = parts[-3:]
|
||||
name = "_".join(parts[:-3])
|
||||
mm_table_data.append([name, "-", m, n, k, value])
|
||||
|
||||
log.info("Overview info of inductor aten mms: ")
|
||||
log.info(
|
||||
"{:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format( # noqa: G001
|
||||
"Name", "M", "N", "K", "Count"
|
||||
"{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format( # noqa: G001
|
||||
"Name", "B", "M", "N", "K", "Count"
|
||||
)
|
||||
)
|
||||
log.info("-" * 100)
|
||||
log.info("-" * 130)
|
||||
for row in mm_table_data:
|
||||
log.info("{:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001
|
||||
log.info("-" * 100)
|
||||
log.info("{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001
|
||||
log.info("-" * 130)
|
||||
|
||||
# Not strictly necessary, but good to clean up straggling futures
|
||||
# that are unused to reclaim memory.
|
||||
|
|
|
|||
|
|
@ -179,9 +179,11 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None):
|
|||
)
|
||||
|
||||
# below is for getting an overview logging info of inductor mms
|
||||
counters["aten_mm_info"][f"aten.bmm_{m}_{n}_{k}"] += 1
|
||||
batch_size = mat1.get_size()[0] # Extract batch dimension
|
||||
counters["aten_mm_info"][f"aten.bmm_{batch_size}_{m}_{n}_{k}"] += 1
|
||||
log.info(
|
||||
"Tuned aten.bmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
|
||||
"Tuned aten.bmm: batch=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s",
|
||||
batch_size,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
|
@ -241,9 +243,11 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
|
|||
m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
|
||||
|
||||
# below is for getting an overview logging info of inductor mms
|
||||
counters["aten_mm_info"][f"aten.baddbmm_{m}_{n}_{k}"] += 1
|
||||
batch_size = mat1.get_size()[0]
|
||||
counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1
|
||||
log.info(
|
||||
"Tuned aten.baddbmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, inp=%s, output_layout=%s",
|
||||
"Tuned aten.baddbmm: batch_size=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, inp=%s, output_layout=%s",
|
||||
batch_size,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user