Commit Graph

7 Commits

Author SHA1 Message Date
PyTorch MergeBot
846963fa9b Revert "[Inductor] addmm + activation function fusion (#158137)"
This reverts commit b9d7de3a09.

Reverted https://github.com/pytorch/pytorch/pull/158137 on behalf of https://github.com/malfet due to Broke inductor torchbench, see 663da17b62/1 ([comment](https://github.com/pytorch/pytorch/pull/158137#issuecomment-3191841298))
2025-08-15 15:34:09 +00:00
AaronWang04
b9d7de3a09 [Inductor] addmm + activation function fusion (#158137)
This PR implements a post_grad pass that fuses activation(add + mm) into `_addmm_activation`.

Something similar was previously done in #106912 but was reverted for performance reasons; it was replaced with a pass that unfuses the activation and add from addmm/addmm_activation and lets Inductor handle the fusion.

However, the cuBLAS team has since made significant performance improvements in this area. This post will be updated with more benchmarks, but preliminary results already look good.

Perf dashboard (screenshot from 2025-08-07): https://github.com/user-attachments/assets/d44d6205-b33a-4a20-9f0f-d9db176b3738

ReLU works with both training and inference, but GELU only works in inference mode due to a fundamental limitation: GELU's derivative depends on its input while ReLU's does not. This does not appear to be fixable with the current `_addmm_activation` API.
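
As a quick sanity check (not part of the PR itself), the fused op can be compared numerically against the unfused composition used in the benchmark script below; the shapes here are illustrative:

```py
import torch

# Illustrative shapes: a bias of length N broadcasting against the (M, N) result.
bias = torch.randn(1024, device="cuda", dtype=torch.float16)
mat1 = torch.randn(512, 2048, device="cuda", dtype=torch.float16)
mat2 = torch.randn(2048, 1024, device="cuda", dtype=torch.float16)

fused = torch._addmm_activation(bias, mat1, mat2, use_gelu=True)
unfused = torch.nn.functional.gelu(torch.addmm(bias, mat1, mat2), approximate="tanh")

# The cuBLAS GELU epilogue should track the tanh-approximate GELU closely.
print((fused - unfused).abs().max().item())
```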

Graph module before and after this pass

ReLU(addmm)
```
graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=2] = placeholder[target=primals_2]
    %primals_3 : [num_users=2] = placeholder[target=primals_3]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
    %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%addmm,), kwargs = {})
    %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
    return (relu, primals_2, le, permute_1)
graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=2] = placeholder[target=primals_2]
    %primals_3 : [num_users=2] = placeholder[target=primals_3]
    %_addmm_activation_default : [num_users=2] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
    %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%_addmm_activation_default, 0), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
    return (_addmm_activation_default, primals_2, le, permute_1)
```
GELU(addmm)
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %addmm : [num_users=4] = call_function[target=torch.ops.aten.addmm.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %addmm), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %addmm), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_1, 0.044715), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%addmm, %mul_2), kwargs = {})
    %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 0.7978845608028654), kwargs = {})
    %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, 0.5), kwargs = {})
    %tanh : [num_users=1] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_3,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh, 1), kwargs = {})
    %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_4, %add_1), kwargs = {})
    return (mul_5,)
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_addmm_activation_default : [num_users=1] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {use_gelu: True})
    return (_addmm_activation_default,)
```
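
For reference, graphs like the ones above can be reproduced on a toy model; the module and shapes below are illustrative, and running with `TORCH_LOGS=post_grad_graphs` set in the environment should print the post_grad graphs so the rewritten call can be inspected. Whether the fusion actually fires depends on device, dtype, and inductor configuration.

```py
import torch

# bias add + matmul + activation, the shape of graph this pass targets
# (nn.Linear contributes the addmm, ReLU the activation).
mod = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU()).cuda().half()
x = torch.randn(64, 1024, device="cuda", dtype=torch.float16)

compiled = torch.compile(mod, dynamic=False)
with torch.no_grad():
    compiled(x)
```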

Benchmark setup:
- NGC PyTorch 25.06 container
- cuBLAS version: 12.9.1.4
- torch.compile run with dynamic=False and max_autotune

H100
```
Testing with M=1024, N=1024, K=1024, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):	 0.0107 ms
Average Time per Iteration (torch compile):	 0.0296 ms

============================================================
Testing with M=2048, N=2048, K=2048, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):	 0.0262 ms
Average Time per Iteration (torch compile):	 0.0327 ms

============================================================
Testing with M=4096, N=4096, K=4096, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):	 0.1763 ms
Average Time per Iteration (torch compile):	 0.2457 ms

============================================================
Testing with M=8192, N=8192, K=8192, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):	 1.5280 ms
Average Time per Iteration (torch compile):	 1.9437 ms
```

A100
```
############################################################
Testing with dtype: float16
############################################################

============================================================
Testing with M=1024, N=1024, K=1024, dtype=float16
============================================================
Average Time per Iteration (cublas):	 0.0313 ms
Average Time per Iteration (torch compile):	 0.0643 ms

============================================================
Testing with M=2048, N=2048, K=2048, dtype=float16
============================================================
Average Time per Iteration (cublas):	 0.1149 ms
Average Time per Iteration (torch compile):	 0.1255 ms

============================================================
Testing with M=4096, N=4096, K=4096, dtype=float16
============================================================
Average Time per Iteration (cublas):	 0.6297 ms
Average Time per Iteration (torch compile):	 0.7547 ms

============================================================
Testing with M=8192, N=8192, K=8192, dtype=float16
============================================================
Average Time per Iteration (cublas):	 4.3821 ms
Average Time per Iteration (torch compile):	 5.0740 ms
```

Script
```py
import torch
torch.manual_seed(0)

warmup, numrun = 10, 100

sizes = [1024, 2048, 4096, 8192]
dtypes = [torch.float16, torch.bfloat16, torch.float32]

device = torch.device("cuda")

for dtype in dtypes:
    dtype_name = str(dtype).split('.')[-1]
    print(f"\n{'#'*60}")
    print(f"Testing with dtype: {dtype_name}")
    print(f"{'#'*60}")

    for size in sizes:
        M, N, K = size, size, size
        print(f"\n{'='*60}")
        print(f"Testing with M={M}, N={N}, K={K}, dtype={dtype_name}")
        print(f"{'='*60}")

        A = torch.randn(M, K, device=device, dtype=dtype)
        B = torch.randn(K, N, device=device, dtype=dtype)
        C = torch.randn(M, device=device, dtype=dtype)

        def func1():
            return torch._addmm_activation(C, A, B, use_gelu=True)

        def func2():
            return torch.nn.functional.gelu(torch.add(C, torch.mm(A, B)), approximate="tanh")

        func2_compiled = torch.compile(
            func2,
            dynamic=False,
            options={
                "force_disable_caches": True,
                "max_autotune": True,
                "max_autotune_gemm": True,
                "max_autotune_gemm_backends": "TRITON",
                "autotune_fallback_to_aten": False,
            }
        )

        for _ in range(warmup): func1()
        torch.cuda.synchronize(device=device)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        total_time_ms = 0.0
        start_event.record()
        for _ in range(numrun): func1()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)
        avg_time_ms = total_time_ms / numrun

        print(f"Average Time per Iteration (cublas):\t {avg_time_ms:.4f} ms")

        for _ in range(warmup): func2_compiled()
        torch.cuda.synchronize(device=device)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        total_time_ms = 0.0
        start_event.record()
        for _ in range(numrun): func2_compiled()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)
        avg_time_ms = total_time_ms / numrun

        print(f"Average Time per Iteration (torch compile):\t {avg_time_ms:.4f} ms")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158137
Approved by: https://github.com/eellison
2025-08-14 20:41:38 +00:00
Xuehai Pan
93a33bf3ac [BE] update type annotations for basic utilities in torch/__init__.py (#129001)
Changes:

1. Make some arguments positional-only as we only support Python 3.8+
2. Clean up `torch.typename(obj)` implementation.
3. Update type annotations, especially annotating `is_tensor()` and `is_masked_tensor()` with `TypeGuard` (see the sketch below).
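
A minimal sketch (not from the PR) of what the `TypeGuard` annotation buys in practice: after the check, static type checkers narrow the argument to `torch.Tensor`.

```py
import torch

def describe(obj: object) -> str:
    if torch.is_tensor(obj):
        # With is_tensor() annotated as a TypeGuard, checkers treat `obj`
        # as torch.Tensor here, so .shape type-checks without a cast.
        return f"tensor with shape {tuple(obj.shape)}"
    return torch.typename(obj)
```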

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129001
Approved by: https://github.com/malfet
2024-06-24 18:04:38 +00:00
PyTorch MergeBot
cb4919344a Revert "[BE] update type annotations for basic utilities in torch/__init__.py (#129001)"
This reverts commit e53d959028.

Reverted https://github.com/pytorch/pytorch/pull/129001 on behalf of https://github.com/XuehaiPan due to lint failure ([comment](https://github.com/pytorch/pytorch/pull/129001#issuecomment-2186944549))
2024-06-24 16:18:43 +00:00
Xuehai Pan
e53d959028 [BE] update type annotations for basic utilities in torch/__init__.py (#129001)
Changes:

1. Make some arguments positional-only as we only support Python 3.8+
2. Clean up `torch.typename(obj)` implementation.
3. Update type annotations, especially annotating `is_tensor()` and `is_masked_tensor()` with `TypeGuard`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129001
Approved by: https://github.com/malfet
2024-06-24 14:35:41 +00:00
Xuehai Pan
b697808056 [BE][Easy] eliminate relative import in torchgen (#128872)
Fix generated by:

```bash
ruff check --config 'lint.flake8-tidy-imports.ban-relative-imports="all"' --fix --select=TID $(fd '.pyi?$' torchgen)
```
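
For illustration only (the module names are examples, not necessarily files touched by this PR), the rule rewrites relative imports inside torchgen into absolute ones:

```py
# before:
# from .model import NativeFunction
# after:
from torchgen.model import NativeFunction
```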

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128872
Approved by: https://github.com/zou3519
2024-06-21 14:11:46 +00:00
Aaron Orenstein
4044e93a51 Add mm_pattern and bmm_pattern to serialized_patterns (#121313)
Make it easier to serialize patterns by adding `pattern_matcher.gen_register_replacement()` which is like `pattern_matcher.register_replacement()` but also requires the replacement to be precompiled.
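
For context, a search/replacement pair for the pattern matcher is just a pair of traceable functions; such a pair, together with example inputs, is what gets registered (and, with `gen_register_replacement`, precompiled and serialized). The pair below is purely illustrative and is not the mm_pattern/bmm_pattern added here.

```py
import torch

def search_fn(bias: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
    # graph shape to look for: a matmul followed by a bias add
    return torch.mm(mat1, mat2) + bias

def replace_fn(bias: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
    # graph to substitute when the pattern matches
    return torch.addmm(bias, mat1, mat2)
```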

To precompile patterns (and save to disk) run:
```
torchgen/fuse_attention_patterns/gen_attention_patterns.py
```

- Updated the sfdp patterns to use `gen_register_replacement`.
- Added serialized patterns for mm_pattern and bmm_pattern (the 'misc' patterns don't serialize cleanly, so they can't be added).
- Updated the testing so it checks that the round-tripped patterns match, not just that they serialize the same way.
- Checking that the patterns round-trip properly revealed that the `users` field wasn't being serialized correctly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121313
Approved by: https://github.com/eellison
2024-04-09 19:42:19 +00:00