[Inductor] Add Additional Configs for persistent+TMA version of Triton mm and addmm (#150587)

Summary:
This PR introduces additional autotuning configurations for the persistent+TMA version of Triton `mm` and `addmm` operations. The new configurations are as follows:
* `(128, 128, 64, 5, 8)`
* `(256, 128, 64, 4, 8)`
* `(128, 128, 64, 5, 4)`

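For reference, each tuple appears to follow the `GemmConfig` field order visible in the diff below, i.e. `(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)`. A minimal sketch illustrating that interpretation (the `GemmConfig` dataclass here is a local stand-in for Inductor's class, not an import of it):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class GemmConfig:
    # Stand-in mirroring the assumed field order of Inductor's GemmConfig:
    # Triton GEMM tile sizes plus pipelining/parallelism knobs.
    block_m: int      # rows of the output tile computed per program
    block_n: int      # columns of the output tile computed per program
    block_k: int      # reduction-dimension tile size per loop iteration
    num_stages: int   # software-pipelining stages for the K loop
    num_warps: int    # warps launched per Triton program

# The three configs added by this PR, under the assumed field order.
new_configs = [
    GemmConfig(128, 128, 64, 5, 8),
    GemmConfig(256, 128, 64, 4, 8),
    GemmConfig(128, 128, 64, 5, 4),
]
```
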
These configurations were selected based on exhaustive autotuning performed on commonly used shapes from an internal foundational model.

While these new configs are generally more performant across the board, we see notable gains in a few specific cases:
* In scenarios where `n >> m, k`, the configurations `(128, 128, 64, 5, 8)` and `(256, 128, 64, 4, 8)` tend to produce an additional 5-10% speedup over the aten baseline compared to the original configurations.
* Similarly, the configuration `(128, 128, 64, 5, 4)` yields approximately an 8% improvement in scenarios where `k >> m, n`.

These additions are expected to provide performance benefits across a diverse range of shapes relative to the original set of configurations.
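
To exercise the new configs end to end, one would compile a matmul with max-autotune so that Inductor benchmarks its Triton templates against the aten baseline. A rough sketch, assuming a TMA-capable GPU and that the persistent+TMA path is toggled via `torch._inductor.config.triton.enable_persistent_tma_matmul` (verify the exact knob name for your release):

```python
import torch
import torch._inductor.config as inductor_config

# Assumed knob for the persistent+TMA Triton matmul path; check
# torch/_inductor/config.py for the exact name in your version.
inductor_config.triton.enable_persistent_tma_matmul = True
inductor_config.max_autotune_gemm_backends = "ATEN,TRITON"
# Optionally widen the search, as in the exhaustive tuning mentioned above:
# inductor_config.max_autotune_gemm_search_space = "EXHAUSTIVE"

def addmm(bias, a, b):
    return torch.addmm(bias, a, b)

compiled = torch.compile(addmm, mode="max-autotune")

# An n >> m, k shape, where (128, 128, 64, 5, 8) and (256, 128, 64, 4, 8)
# showed the largest gains over the aten baseline in the PR's measurements.
m, k, n = 256, 512, 16384
a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(n, device="cuda", dtype=torch.bfloat16)
out = compiled(bias, a, b)
```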

Test Plan:
contbuild & OSS CI

Reviewers: paulzhan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150587
Approved by: https://github.com/PaulZhang12, https://github.com/drisspg, https://github.com/eellison
Nikhil Anil Patel 2025-04-21 19:30:57 +00:00 committed by PyTorch MergeBot
parent 4d78e19365
commit 99aeee2c5f


@@ -170,6 +170,9 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
             GemmConfig(128, 128, 128, 3, 8),
             GemmConfig(128, 128, 128, 3, 4),
             GemmConfig(128, 128, 64, 4, 8),
+            GemmConfig(128, 128, 64, 5, 8),
+            GemmConfig(256, 128, 64, 4, 8),
+            GemmConfig(128, 128, 64, 5, 4),
         ]
         self.scaled_mm_configs: list[BaseConfig] = [