diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py index fe6476f317f..84bd26ed1dd 100644 --- a/torch/_inductor/template_heuristics.py +++ b/torch/_inductor/template_heuristics.py @@ -170,6 +170,9 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton): GemmConfig(128, 128, 128, 3, 8), GemmConfig(128, 128, 128, 3, 4), GemmConfig(128, 128, 64, 4, 8), + GemmConfig(128, 128, 64, 5, 8), + GemmConfig(256, 128, 64, 4, 8), + GemmConfig(128, 128, 64, 5, 4), ] self.scaled_mm_configs: list[BaseConfig] = [