[inductor][decompose-k] make part of template heuristics (#161098)

# why

- enable it to go through the common template heuristics entry point
- make it easier to use in common extension points, e.g. the lookup table

# what

- break the template heuristics into base + triton modules
- move the k_split generation logic into a template heuristic for decompose_k
- register it through the normal registry mechanism

- to make testing work, add a context manager that temporarily sets the
  heuristic for a template/op pair to the empty base heuristic (effectively
  skipping that template). The decompose_k test uses this to disable the
  Triton choices; a minimal sketch of the pattern follows below
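
For orientation, here is a minimal, self-contained sketch of the registration pattern this PR adopts. The registry, decorator, and heuristic below are toys; the real versions live in `torch._inductor.template_heuristics.registry` and differ in their exact signatures:

```python
from typing import Any, Callable, Generator

# toy (device, template, op) -> heuristic-class registry
_REGISTRY: dict[tuple[str, str, str], type] = {}


def register_template_heuristic(
    template_name: str, device: str, register: bool = True, op_name: str = ""
) -> Callable[[type], type]:
    # register `cls` under (device, template, op); `register=False` skips it,
    # the way the real decorator skips ROCm via register=torch.version.hip is None
    def decorator(cls: type) -> type:
        if register:
            _REGISTRY[(device, template_name, op_name)] = cls
        return cls
    return decorator


@register_template_heuristic("decompose_k", "cuda", op_name="mm")
class ToyDecomposeKHeuristics:
    def get_template_configs(
        self, m: int, n: int, k: int
    ) -> Generator[dict[str, Any], None, None]:
        # one kwargs dict per candidate split that divides K evenly
        for k_split in (2, 4, 8):
            if k % k_split == 0:
                yield {"k_split": k_split}


# lookup mirrors what get_template_heuristic does: resolve, instantiate, iterate
heuristic = _REGISTRY[("cuda", "decompose_k", "mm")]()
print(list(heuristic.get_template_configs(32, 32, 65536)))
# [{'k_split': 2}, {'k_split': 4}, {'k_split': 8}]
```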

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```

Differential Revision: [D80670918](https://our.internmc.facebook.com/intern/diff/D80670918)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161098
Approved by: https://github.com/jansel
ghstack dependencies: #161026, #161097
Author: Ruben Rodriguez Buchillon, 2025-08-27 18:44:21 -07:00 (committed by PyTorch MergeBot)
Commit: 496052faf6 (parent: f641effe19)
8 changed files with 146 additions and 49 deletions

```diff
@@ -12,7 +12,6 @@ import tempfile
 import unittest
 from typing import Callable, Optional
 from unittest import mock
-from unittest.mock import MagicMock
 
 import torch
 from torch import multiprocessing as mp, nn
@@ -37,6 +36,7 @@ from torch._inductor.select_algorithm import (
     TritonTemplate,
     TritonTemplateCaller,
 )
+from torch._inductor.template_heuristics.registry import override_template_heuristics
 from torch._inductor.template_heuristics.triton import (
     CUDAMMTemplateConfigHeuristic,
     GemmConfig,
@@ -1271,16 +1271,14 @@ class TestMaxAutotune(TestCase):
         # Force only decomposeK choice
         with (
-            mock.patch(
-                "torch._inductor.kernel.mm.V.choices.get_mm_configs"
-            ) as base_mm_mock,
+            override_template_heuristics(
+                device_type=GPU_TYPE,
+                template_op_pairs=[(torch._inductor.kernel.mm.mm_template.name, "mm")],
+            ),
             mock.patch(
                 "torch._inductor.kernel.mm.use_decompose_k_choice"
             ) as decompose_mock,
         ):
-            mm_configs_mock = MagicMock()
-            mm_configs_mock.return_value = []
-            base_mm_mock.return_value = mm_configs_mock
             decompose_mock.return_value = True
 
             compiled_f = torch.compile(f)
             out, code = run_and_get_code(compiled_f, a, b)
```

```diff
@@ -13,7 +13,7 @@ from .kernel_inputs import KernelInputs  # noqa: TC001
 from .metrics import get_metric_table, is_metric_table_enabled
 from .runtime.hints import DeviceProperties, ReductionHint
 from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
-from .template_heuristics.registry import get_template_heuristic
+from .template_heuristics import get_template_heuristic
 from .template_heuristics.triton import (
     BaseConfigHeuristic,
     CPUConfigHeuristic,
```

```diff
@@ -41,7 +41,6 @@ from ..select_algorithm import (
 )
 from ..utils import (
     _use_cutlass_for_op,
-    get_k_splits,
     get_tma_workspace_arg,
     use_aten_gemm_kernels,
     use_ck_gemm_template,
@@ -763,32 +762,16 @@ def tuned_mm(mat1, mat2, *, layout=None):
                 **kwargs,
             )
 
-    from torch._inductor.ir import get_free_symbols
-
-    # Only do split-k optimization if K is much larger than m, n and m, n are small
-    # and if there aren't any unbacked symbols
-    unbacked_symbols = any(
-        len(get_free_symbols(itr, unbacked_only=True)) > 0
-        for itr in (
-            mat1.get_size(),
-            mat1.get_stride(),
-            mat2.get_size(),
-            mat2.get_stride(),
-        )
-    )
-    if use_decompose_k_choice(m, n, k) and not unbacked_symbols:
-        k_splits = get_k_splits(m, n, k)
-        for k_split in k_splits:
-            if not V.graph.sizevars.statically_known_true(
-                sympy.Eq(sympy.Mod(k, k_split), 0)
-            ):
-                continue
+    if use_decompose_k_choice(m, n, k):
+        for kwargs in V.choices.get_mm_configs(
+            kernel_inputs, layout, decompose_k_subgraph_template.name, "mm"
+        ):
             decompose_k_subgraph_template.maybe_append_choice(
                 choices,
                 input_nodes=kernel_inputs.nodes(),
                 layout=layout,
-                k_split=k_split,
+                **kwargs,
             )
 
     if (
```
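
With this change, `tuned_mm` no longer computes the k_splits inline; the kwargs arrive from whichever heuristic is registered for the template/op pair. Roughly, the dispatch looks like the sketch below. This is a simplified illustration, not the actual `V.choices` implementation, and the argument order of `get_template_heuristic` is assumed:

```python
from torch._inductor.template_heuristics import get_template_heuristic


def get_mm_configs(kernel_inputs, layout, template_name, op_name):
    # resolve the registered heuristic; signature and the device_type
    # attribute are assumptions for illustration
    heuristic = get_template_heuristic(
        template_name, kernel_inputs.device_type, op_name
    )
    # for decompose_k this yields dicts like {"k_split": 64}, which tuned_mm
    # splats into decompose_k_subgraph_template.maybe_append_choice(**kwargs)
    yield from heuristic.get_template_configs(kernel_inputs, layout, op_name)
```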

```diff
@@ -0,0 +1,6 @@
+# NOTE: add new template heuristics here, so they get imported and registered
+# TODO: write a simple glob if there are many heuristics to auto import them in the right order
+from . import base, decompose_k, registry, triton
+
+# expose the entry function
+from .registry import get_template_heuristic
```

```diff
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from collections.abc import Generator
+
+    from ..ir import Layout
+    from ..kernel_inputs import KernelInputs
+
+
+class TemplateConfigHeuristics:
+    def get_template_configs(
+        self,
+        kernel_inputs: KernelInputs,
+        layout: Layout,
+        op_name: str,
+    ) -> Generator[dict[str, Any], None, None]:
+        """
+        Get template configs for the given inputs.
+        This is the main entry point for template-specific logic.
+        """
+        # NOTE: not an abstract class, because that clashed below for the mixin
+        # functionality. Can be adjusted, but not a high priority
+        yield from []
```

```diff
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+import sympy
+
+import torch
+
+from ..ir import get_free_symbols
+from ..kernel_inputs import KernelInputs, MMKernelInputs
+from ..utils import get_k_splits
+from ..virtualized import V
+from .base import TemplateConfigHeuristics
+from .registry import register_template_heuristic
+
+if TYPE_CHECKING:
+    from collections.abc import Generator
+
+    from ..ir import Layout
+
+
+@register_template_heuristic(
+    "decompose_k", "cuda", register=torch.version.hip is None, op_name="mm"
+)
+class DecomposeKConfigHeuristics(TemplateConfigHeuristics):
+    def get_template_configs(
+        self,
+        kernel_inputs: KernelInputs,
+        layout: Layout,
+        op_name: str,
+    ) -> Generator[dict[str, Any], None, None]:
+        """
+        Get all the valid k_splits for the given m, n, k.
+        """
+        assert isinstance(kernel_inputs, MMKernelInputs), (
+            f"{self.__class__.__name__} requires MMKernelInputs"
+        )
+        # Check for unbacked symbols - if found, yield nothing
+        unbacked_symbols = any(
+            len(get_free_symbols(itr, unbacked_only=True)) > 0
+            for itr in (
+                *kernel_inputs.shapes_symbolic(),
+                *kernel_inputs.strides_symbolic(),
+            )
+        )
+        if unbacked_symbols:
+            return
+
+        m, n, k = kernel_inputs.mnk_symbolic()
+        k_splits = get_k_splits(m, n, k)
+        for k_split in k_splits:
+            if not V.graph.sizevars.statically_known_true(
+                sympy.Eq(sympy.Mod(k, k_split), 0)
+            ):
+                continue
+            yield {"k_split": k_split}
```
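
For intuition: only candidate splits that provably divide K survive the `statically_known_true` guard. A quick standalone illustration of that filter, with made-up candidates since `get_k_splits` itself is not part of this diff:

```python
k = 65536
candidates = [32, 48, 64, 256]  # hypothetical stand-in for get_k_splits(m, n, k)
# mirrors the sympy.Eq(sympy.Mod(k, k_split), 0) guard above; 48 is filtered out
print([s for s in candidates if k % s == 0])  # [32, 64, 256]
```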

```diff
@@ -8,13 +8,17 @@ for CUDA vs ROCm based on torch.version.hip.
 from __future__ import annotations
 
+import contextlib
 import logging
 from functools import cache
 from typing import Any, Optional, TYPE_CHECKING
 
+from .base import TemplateConfigHeuristics
+
 if TYPE_CHECKING:
-    from .triton import TemplateConfigHeuristics
+    from collections.abc import Iterator
 
 # Module-wide registry for template heuristics
 _TEMPLATE_HEURISTIC_REGISTRY: dict[tuple[str, ...], type[TemplateConfigHeuristics]] = {}
@@ -96,3 +100,42 @@ def get_template_heuristic(
         f"Available combinations: {list(_TEMPLATE_HEURISTIC_REGISTRY.keys())}"
     )
     return heuristic_class()
+
+
+@contextlib.contextmanager
+def override_template_heuristics(
+    device_type: str,
+    template_op_pairs: list[tuple[str, str]],
+) -> Iterator[None]:
+    """
+    Context manager to temporarily override template heuristics with an empty heuristic.
+
+    This is useful for testing purposes, where we want to ensure a specific template/op pair
+    is not used.
+
+    Args:
+        device_type: Device type ("cuda", "cpu", "xpu")
+        template_op_pairs: List of (template_name, op_name) pairs to override.
+    """
+    # Save original entries to restore later
+    original_entries = {}
+    new_keys = []
+    get_template_heuristic.cache_clear()
+    try:
+        for template_name, op_name in template_op_pairs:
+            assert op_name is not None
+            key = (device_type, template_name, op_name)
+            if key in _TEMPLATE_HEURISTIC_REGISTRY:
+                original_entries[key] = _TEMPLATE_HEURISTIC_REGISTRY[key]
+            # TemplateConfigHeuristics base class returns no entries
+            # so we use it for overriding
+            _TEMPLATE_HEURISTIC_REGISTRY[key] = TemplateConfigHeuristics
+            new_keys.append(key)
+        yield
+    finally:
+        # Restore original entries or remove if they didn't exist before
+        for key in new_keys:
+            _TEMPLATE_HEURISTIC_REGISTRY.pop(key, None)
+            if key in original_entries:
+                _TEMPLATE_HEURISTIC_REGISTRY[key] = original_entries[key]
+        get_template_heuristic.cache_clear()
```
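
Typical usage mirrors the max-autotune test above; the template/op names here are illustrative:

```python
from torch._inductor.template_heuristics.registry import override_template_heuristics

# Temporarily map ("cuda", "mm", "mm") to the empty base heuristic, so the
# Triton mm template contributes no choices inside the block.
with override_template_heuristics(
    device_type="cuda",
    template_op_pairs=[("mm", "mm")],  # (template_name, op_name); names assumed
):
    ...  # run compilation; only non-overridden templates generate configs
```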

```diff
@@ -17,6 +17,7 @@ from .. import config, config as inductor_config
 from ..kernel_inputs import KernelInputs, MMKernelInputs
 from ..utils import get_backend_num_stages, get_num_sms, TMA_DESCRIPTOR_SIZE
 from ..virtualized import V
+from .base import TemplateConfigHeuristics
 from .registry import register_template_heuristic
 
@@ -1244,24 +1245,6 @@ class MTIAConfigHeuristic(BaseConfigHeuristic):
 
 # Template-specific mixin classes
-class TemplateConfigHeuristics:
-    def get_template_configs(
-        self,
-        kernel_inputs: KernelInputs,
-        layout: Any,
-        op_name: str,
-    ) -> Generator[dict[str, Any], None, None]:
-        """
-        Get template configs for the given inputs.
-        This is the main entry point for template-specific logic.
-        """
-        # NOTE: not an abstract class, because that clashed below for the mixin
-        # functionality. Can be adjusted, but not a high priority
-        yield from {}
-
-
 class MMTemplateConfigMixin(TemplateConfigHeuristics):
     """
     Mixin class that converts config lists to template kwargs.
```