This PR implements a post_grad pass that fuses activation(add + mm) into `_addmm_activation`.
This was previously done similarly in #106912 but was reverted for performance reasons; it was replaced with a pass that unfuses the activation and add from addmm/addmm_activation and lets Inductor handle the fusion.
However, the cuBLAS team has since made significant performance improvements here. I will update this post with more benchmarks, but preliminary benchmarks show good results.
Perf dashboard:
<img width="3371" height="1240" alt="Screenshot from 2025-08-07 13-41-35" src="https://github.com/user-attachments/assets/d44d6205-b33a-4a20-9f0f-d9db176b3738" />
ReLU works with both training and inference, but GELU only works in inference mode due to a fundamental limitation: GELU's derivative depends on its input while ReLU's does not. I don't think this is fixable with the current addmm_activation API.
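For context, here is a minimal sketch (not taken from the PR; the module and shapes are made up) of the kind of model whose post_grad graph contains the relu(add + mm) pattern this pass rewrites:
```py
# Hypothetical example: a Linear + ReLU model whose compiled post_grad graph
# contains addmm followed by relu, which this pass can rewrite into
# _addmm_activation.
import torch

class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(1024, 1024)

    def forward(self, x):
        return torch.nn.functional.relu(self.fc(x))

model = TinyMLP().cuda().to(torch.bfloat16)
compiled = torch.compile(model)
x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
out = compiled(x)
```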
Graph module before and after this pass
Relu(addmm)
```
graph():
%primals_1 : [num_users=1] = placeholder[target=primals_1]
%primals_2 : [num_users=2] = placeholder[target=primals_2]
%primals_3 : [num_users=2] = placeholder[target=primals_3]
%addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
%relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%addmm,), kwargs = {})
%le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
%permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
return (relu, primals_2, le, permute_1)
graph():
%primals_1 : [num_users=1] = placeholder[target=primals_1]
%primals_2 : [num_users=2] = placeholder[target=primals_2]
%primals_3 : [num_users=2] = placeholder[target=primals_3]
%_addmm_activation_default : [num_users=2] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
%le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%_addmm_activation_default, 0), kwargs = {})
%permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
return (_addmm_activation_default, primals_2, le, permute_1)
```
Gelu (addmm)
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%addmm : [num_users=4] = call_function[target=torch.ops.aten.addmm.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %addmm), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %addmm), kwargs = {})
%mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_1, 0.044715), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%addmm, %mul_2), kwargs = {})
%mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 0.7978845608028654), kwargs = {})
%mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, 0.5), kwargs = {})
%tanh : [num_users=1] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_3,), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh, 1), kwargs = {})
%mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_4, %add_1), kwargs = {})
return (mul_5,)
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%_addmm_activation_default : [num_users=1] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {use_gelu: True})
return (_addmm_activation_default,)
```
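The fused op should be numerically equivalent to the decomposed tanh-approximation GELU graph above; a quick sanity check (illustrative only, not part of the PR):
```py
# Illustrative check: _addmm_activation(use_gelu=True) should match the
# tanh-approximated GELU applied to addmm, up to floating-point error.
import torch

A = torch.randn(256, 128, device="cuda", dtype=torch.float32)
B = torch.randn(128, 64, device="cuda", dtype=torch.float32)
bias = torch.randn(64, device="cuda", dtype=torch.float32)

fused = torch._addmm_activation(bias, A, B, use_gelu=True)
ref = torch.nn.functional.gelu(torch.addmm(bias, A, B), approximate="tanh")
print((fused - ref).abs().max())  # expected to be close to zero
```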
Benchmark setup:
- NGC PyTorch 25.06 container
- cuBLAS version: 12.9.1.4
- `torch.compile` run with `dynamic=False` and max-autotune
H100
```
Testing with M=1024, N=1024, K=1024, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.0107 ms
Average Time per Iteration (torch compile): 0.0296 ms
============================================================
Testing with M=2048, N=2048, K=2048, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.0262 ms
Average Time per Iteration (torch compile): 0.0327 ms
============================================================
Testing with M=4096, N=4096, K=4096, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.1763 ms
Average Time per Iteration (torch compile): 0.2457 ms
============================================================
Testing with M=8192, N=8192, K=8192, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 1.5280 ms
Average Time per Iteration (torch compile): 1.9437 ms
```
A100
```
############################################################
Testing with dtype: float16
############################################################
============================================================
Testing with M=1024, N=1024, K=1024, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.0313 ms
Average Time per Iteration (torch compile): 0.0643 ms
============================================================
Testing with M=2048, N=2048, K=2048, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.1149 ms
Average Time per Iteration (torch compile): 0.1255 ms
============================================================
Testing with M=4096, N=4096, K=4096, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.6297 ms
Average Time per Iteration (torch compile): 0.7547 ms
============================================================
Testing with M=8192, N=8192, K=8192, dtype=float16
============================================================
Average Time per Iteration (cublas): 4.3821 ms
Average Time per Iteration (torch compile): 5.0740 ms
```
Script
```py
import torch

torch.manual_seed(0)
warmup, numrun = 10, 100
sizes = [1024, 2048, 4096, 8192]
dtypes = [torch.float16, torch.bfloat16, torch.float32]
device = torch.device("cuda")

for dtype in dtypes:
    dtype_name = str(dtype).split('.')[-1]
    print(f"\n{'#' * 60}")
    print(f"Testing with dtype: {dtype_name}")
    print(f"{'#' * 60}")
    for size in sizes:
        M, N, K = size, size, size
        print(f"\n{'=' * 60}")
        print(f"Testing with M={M}, N={N}, K={K}, dtype={dtype_name}")
        print(f"{'=' * 60}")
        A = torch.randn(M, K, device=device, dtype=dtype)
        B = torch.randn(K, N, device=device, dtype=dtype)
        C = torch.randn(M, device=device, dtype=dtype)

        # Fused cuBLAS path: addmm with a tanh-approximated GELU epilogue.
        def func1():
            return torch._addmm_activation(C, A, B, use_gelu=True)

        # Unfused path: add + mm + gelu, left for Inductor to fuse.
        def func2():
            return torch.nn.functional.gelu(torch.add(C, torch.mm(A, B)), approximate="tanh")

        func2_compiled = torch.compile(
            func2,
            dynamic=False,
            options={
                "force_disable_caches": True,
                "max_autotune": True,
                "max_autotune_gemm": True,
                "max_autotune_gemm_backends": "TRITON",
                "autotune_fallback_to_aten": False,
            },
        )

        # Time the cuBLAS-fused kernel.
        for _ in range(warmup):
            func1()
        torch.cuda.synchronize(device=device)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        total_time_ms = 0.0
        start_event.record()
        for _ in range(numrun):
            func1()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)
        avg_time_ms = total_time_ms / numrun
        print(f"Average Time per Iteration (cublas):\t {avg_time_ms:.4f} ms")

        # Time the torch.compile (Triton) version.
        for _ in range(warmup):
            func2_compiled()
        torch.cuda.synchronize(device=device)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        total_time_ms = 0.0
        start_event.record()
        for _ in range(numrun):
            func2_compiled()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)
        avg_time_ms = total_time_ms / numrun
        print(f"Average Time per Iteration (torch compile):\t {avg_time_ms:.4f} ms")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158137
Approved by: https://github.com/eellison
Changes:
1. Make some arguments positional-only as we only support Python 3.8+
2. Clean up `torch.typename(obj)` implementation.
3. Update type annotations, especially `is_tensor()` and `is_masked_tensor()` using `TypeGuard` (see the sketch below).
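As a rough illustration of the `TypeGuard` change (a sketch; the exact PyTorch signatures may differ):
```py
# Sketch only: a TypeGuard-annotated is_tensor() lets static type checkers
# narrow the argument's type after the check. Signatures here are illustrative.
from typing import Any

import torch
from typing_extensions import TypeGuard

def is_tensor(obj: Any, /) -> TypeGuard[torch.Tensor]:
    # `obj` is positional-only (the "/" marker), matching change 1 above.
    return isinstance(obj, torch.Tensor)

def maybe_double(x: Any) -> Any:
    if is_tensor(x):
        return x * 2  # type checkers now treat x as torch.Tensor here
    return x
```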
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129001
Approved by: https://github.com/malfet
Make it easier to serialize patterns by adding `pattern_matcher.gen_register_replacement()`, which is like `pattern_matcher.register_replacement()` but also requires the replacement to be precompiled (a rough usage sketch follows the list below).
To precompile patterns (and save to disk) run:
```
torchgen/fuse_attention_patterns/gen_attention_patterns.py
```
- Updated the sfdp patterns to use `gen_register_replacement`.
- Add serialized patterns for mm_pattern and bmm_pattern (The 'misc' patterns don't serialize cleanly so can't be added).
- Updated the testing so it checks that round-tripped patterns match, not just that they serialize the same way.
- Checking that patterns round-trip properly revealed that the `users` field wasn't being serialized correctly.
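A rough usage sketch of the new helper (argument names and order are assumptions based on this description, not the verified signature; see `torch/_inductor/pattern_matcher.py` for the real API):
```py
# Illustrative sketch only: gen_register_replacement() is described as behaving
# like register_replacement() but keyed by a unique name whose precompiled
# (serialized) pattern must already exist on disk. Exact arguments are assumed.
import torch
from torch._inductor import pattern_matcher

def mm_pattern(mat1, mat2):
    return torch.ops.aten.mm(mat1, mat2)

def mm_replacement(mat1, mat2):
    # Hypothetical replacement body, for illustration only.
    return torch.ops.aten.mm(mat1, mat2)

# pattern_matcher.gen_register_replacement(
#     "mm_pattern",                              # name used to look up the serialized pattern
#     mm_pattern,
#     mm_replacement,
#     [torch.empty(4, 4), torch.empty(4, 4)],    # example inputs for tracing
#     pattern_matcher.fwd_only,                  # trace function (assumed)
#     target_pattern_pass,                       # PatternMatcherPass to register into (assumed)
# )
```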
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121313
Approved by: https://github.com/eellison