Improves the GEMM overview logging in PyTorch Inductor to properly display batch size information for batched matrix operations like `torch.bmm` and `torch.baddbmm`.
**Fixes #155307**
## Problem
The current GEMM overview logging omits the batch size for `torch.bmm`. The following script reproduces it:
```python
# Repro
import os
os.environ["TORCH_LOGS"] = "inductor"
import torch
M, N, K = 1024, 1024, 1024
dtype = torch.bfloat16
A = torch.randn(10, M, K, device="cuda", dtype=dtype)
B = torch.randn(10, K, N, device="cuda", dtype=dtype)
compiled_model = torch.compile(torch.bmm, fullgraph=True)
_ = compiled_model(A, B)
```
**Before:**
```
Name | M | N | K | Count
----------------------------------------------------------------------------------------------------
aten.bmm | 1024 | 1024 | 1024 | 1
----------------------------------------------------------------------------------------------------
```
The batch size (10) is missing from the logs, making it unclear what the actual operation dimensions were.
## Solution
**After:**
```
Name | B | M | N | K | Count
----------------------------------------------------------------------------------------------------------------------------------
aten.bmm | 10 | 1024 | 1024 | 1024 | 1
aten.mm | - | 1024 | 1024 | 1024 | 2
----------------------------------------------------------------------------------------------------------------------------------
```
## Changes Made
### 1. Enhanced Parsing Logic in compile_fx.py
- Detects batched operations by checking whether the operation name ends with `'bmm'` or `'baddbmm'`
- For batched operations: takes the last 4 underscore-separated parts of the counter key as `batch, m, n, k`
- For non-batched operations: takes the last 3 parts as `m, n, k`
- **Dedicated "B" column**: Adds a separate column for batch size instead of embedding it in the operation name
- Shows the batch size for batched operations and "-" for non-batched operations
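The parsing rules above can be illustrated with a minimal standalone sketch. This is not the code in `compile_fx.py`; the helper name `parse_gemm_key` and the returned tuple layout are assumptions made for the example, and only the key formats listed in this description are taken from the PR.

```python
def parse_gemm_key(key: str) -> tuple[str, str, str, str, str]:
    """Split a counter key into (name, batch, m, n, k).

    Hypothetical helper: batched keys look like 'aten.bmm_10_1024_1024_1024',
    non-batched keys like 'aten.mm_1024_1024_1024'.
    """
    # Try the batched layout first: name + 4 trailing size fields.
    parts = key.rsplit("_", 4)
    if len(parts) == 5 and parts[0].endswith("bmm"):
        # 'baddbmm' also ends with 'bmm', so one check covers both ops.
        name, batch, m, n, k = parts
        return name, batch, m, n, k
    # Otherwise fall back to the non-batched layout: name + 3 size fields.
    name, m, n, k = key.rsplit("_", 3)
    return name, "-", m, n, k


# Names that contain underscores themselves still parse, because the
# split is done from the right.
print(parse_gemm_key("aten.bmm_10_1024_1024_1024"))
print(parse_gemm_key("aten._int_mm_64_128_256"))
```

Splitting from the right keeps operation names such as `aten._scaled_mm.default` intact, which is why the sketch uses `rsplit` rather than `split`.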
### 2. Updated All MM Operations for Consistency
- **bmm.py**:
  - Extract batch size from `mat1.get_size()[0]` for both `tuned_bmm` and `tuned_baddbmm`
  - Use positional counter keys: `aten.bmm_{batch_size}_{m}_{n}_{k}`
  - Enhanced log messages to include batch size information
- **mm.py**: Updated counter keys for consistency:
  - `aten.mm_{m}_{n}_{k}` (no batch dimension)
  - `aten.addmm_{m}_{n}_{k}` (no batch dimension)
  - `aten._int_mm_{m}_{n}_{k}` (no batch dimension)
  - `aten._scaled_mm.default_{m}_{n}_{k}` (no batch dimension)
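As a rough illustration of the positional key scheme listed above, the sketch below builds keys in both layouts and tallies them. The helper `gemm_key` and the plain `collections.Counter` are stand-ins chosen for this example; they are not Inductor's actual counter object or API.

```python
from collections import Counter
from typing import Optional


def gemm_key(op_name: str, m: int, n: int, k: int,
             batch: Optional[int] = None) -> str:
    """Build a positional counter key (hypothetical helper).

    Batched ops put the batch size before m, n, k; other ops omit it.
    """
    if batch is None:
        return f"{op_name}_{m}_{n}_{k}"
    return f"{op_name}_{batch}_{m}_{n}_{k}"


# Stand-in for the per-op counter that the logging code reads.
gemm_counts = Counter()
gemm_counts[gemm_key("aten.bmm", 1024, 1024, 1024, batch=10)] += 1
gemm_counts[gemm_key("aten.mm", 1024, 1024, 1024)] += 1
gemm_counts[gemm_key("aten.mm", 1024, 1024, 1024)] += 1

for key, count in gemm_counts.items():
    print(key, count)
```

Because the batch size occupies a fixed position in the key, a reader of the counter only needs to know whether the op is batched to recover every dimension.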
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155544
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng