[Inductor] Add Additional Configs for persistent+TMA version of Triton mm and addmm (#150587)
Summary: This PR introduces additional autotuning configurations for the persistent+TMA version of the Triton `mm` and `addmm` operations. The new configurations are:

* `(128, 128, 64, 5, 8)`
* `(256, 128, 64, 4, 8)`
* `(128, 128, 64, 5, 4)`

These configurations were selected based on exhaustive autotuning performed on commonly used shapes from an internal foundational model. While the new configs are generally more performant across the board, we see notable gains in a few specific cases:

* In scenarios where `n >> m, k`, the configurations `(128, 128, 64, 5, 8)` and `(256, 128, 64, 4, 8)` tend to produce an additional 5-10% speedup over the aten baseline compared to the original configurations.
* Similarly, the configuration `(128, 128, 64, 5, 4)` yields approximately an 8% improvement in scenarios where `k >> m, n`.

These enhancements are expected to provide performance benefits across diverse use cases, particularly when compared to the original set of configurations.

Test Plan: contbuild & OSS CI

Reviewers: paulzhan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150587
Approved by: https://github.com/PaulZhang12, https://github.com/drisspg, https://github.com/eellison
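For context, each tuple above becomes a `GemmConfig` entry in the heuristic's config list (see the diff below). The following is a minimal sketch of how those entries read; only the `GemmConfig` name and the numeric values come from this commit, while the field names and their order (block_m, block_n, block_k, num_stages, num_warps) are assumptions made for illustration.

# Minimal sketch -- field names are assumed for illustration; only the
# GemmConfig name and the numeric values come from the diff below.
from dataclasses import dataclass

@dataclass(frozen=True)
class GemmConfig:
    block_m: int      # tile size along M
    block_n: int      # tile size along N
    block_k: int      # tile size along K
    num_stages: int   # software-pipelining depth
    num_warps: int    # warps per CTA

# The three configs this PR adds to the persistent+TMA mm/addmm search space.
new_persistent_tma_configs = [
    GemmConfig(128, 128, 64, 5, 8),  # strongest when n >> m, k
    GemmConfig(256, 128, 64, 4, 8),  # strongest when n >> m, k
    GemmConfig(128, 128, 64, 5, 4),  # strongest when k >> m, n
]

With entries like these registered, compiling a matmul-heavy model via `torch.compile(model, mode="max-autotune")` lets Inductor benchmark them alongside the existing configurations and pick the fastest one per shape.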
This commit is contained in:
parent 4d78e19365
commit 99aeee2c5f
@@ -170,6 +170,9 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
             GemmConfig(128, 128, 128, 3, 8),
             GemmConfig(128, 128, 128, 3, 4),
             GemmConfig(128, 128, 64, 4, 8),
+            GemmConfig(128, 128, 64, 5, 8),
+            GemmConfig(256, 128, 64, 4, 8),
+            GemmConfig(128, 128, 64, 5, 4),
         ]
 
         self.scaled_mm_configs: list[BaseConfig] = [