mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor][decompose-k] make part of template heuristics (#161098)
# why - enable it to go through common template heuristics point - make it easier to use in common extension point e.g. lookup table # what - break template heuristic into base + triton - move k_split generation logic into a template heuristic for decompose k - register through normal mechanism - to make testing work, add a context manager to temporarily set template heuristics for a template/op to empty (effectively skipping it). This is used for decompose k test to disable triton choices # testing ``` python3 -bb -m pytest test/inductor/test_max_autotune.py -v ``` Differential Revision: [D80670918](https://our.internmc.facebook.com/intern/diff/D80670918) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161098 Approved by: https://github.com/jansel ghstack dependencies: #161026, #161097
This commit is contained in:
parent
f641effe19
commit
496052faf6
|
|
@ -12,7 +12,6 @@ import tempfile
|
|||
import unittest
|
||||
from typing import Callable, Optional
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
from torch import multiprocessing as mp, nn
|
||||
|
|
@ -37,6 +36,7 @@ from torch._inductor.select_algorithm import (
|
|||
TritonTemplate,
|
||||
TritonTemplateCaller,
|
||||
)
|
||||
from torch._inductor.template_heuristics.registry import override_template_heuristics
|
||||
from torch._inductor.template_heuristics.triton import (
|
||||
CUDAMMTemplateConfigHeuristic,
|
||||
GemmConfig,
|
||||
|
|
@ -1271,16 +1271,14 @@ class TestMaxAutotune(TestCase):
|
|||
|
||||
# Force only decomposeK choice
|
||||
with (
|
||||
mock.patch(
|
||||
"torch._inductor.kernel.mm.V.choices.get_mm_configs"
|
||||
) as base_mm_mock,
|
||||
override_template_heuristics(
|
||||
device_type=GPU_TYPE,
|
||||
template_op_pairs=[(torch._inductor.kernel.mm.mm_template.name, "mm")],
|
||||
),
|
||||
mock.patch(
|
||||
"torch._inductor.kernel.mm.use_decompose_k_choice"
|
||||
) as decompose_mock,
|
||||
):
|
||||
mm_configs_mock = MagicMock()
|
||||
mm_configs_mock.return_value = []
|
||||
base_mm_mock.return_value = mm_configs_mock
|
||||
decompose_mock.return_value = True
|
||||
compiled_f = torch.compile(f)
|
||||
out, code = run_and_get_code(compiled_f, a, b)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from .kernel_inputs import KernelInputs # noqa: TC001
|
|||
from .metrics import get_metric_table, is_metric_table_enabled
|
||||
from .runtime.hints import DeviceProperties, ReductionHint
|
||||
from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
|
||||
from .template_heuristics.registry import get_template_heuristic
|
||||
from .template_heuristics import get_template_heuristic
|
||||
from .template_heuristics.triton import (
|
||||
BaseConfigHeuristic,
|
||||
CPUConfigHeuristic,
|
||||
|
|
|
|||
|
|
@ -41,7 +41,6 @@ from ..select_algorithm import (
|
|||
)
|
||||
from ..utils import (
|
||||
_use_cutlass_for_op,
|
||||
get_k_splits,
|
||||
get_tma_workspace_arg,
|
||||
use_aten_gemm_kernels,
|
||||
use_ck_gemm_template,
|
||||
|
|
@ -763,32 +762,16 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
from torch._inductor.ir import get_free_symbols
|
||||
|
||||
# Only do split-k optimization if K is much larger than m, n and m, n are small
|
||||
# and if there aren't any unbacked symbols
|
||||
unbacked_symbols = any(
|
||||
len(get_free_symbols(itr, unbacked_only=True)) > 0
|
||||
for itr in (
|
||||
mat1.get_size(),
|
||||
mat1.get_stride(),
|
||||
mat2.get_size(),
|
||||
mat2.get_stride(),
|
||||
)
|
||||
)
|
||||
if use_decompose_k_choice(m, n, k) and not unbacked_symbols:
|
||||
k_splits = get_k_splits(m, n, k)
|
||||
for k_split in k_splits:
|
||||
if not V.graph.sizevars.statically_known_true(
|
||||
sympy.Eq(sympy.Mod(k, k_split), 0)
|
||||
):
|
||||
continue
|
||||
|
||||
if use_decompose_k_choice(m, n, k):
|
||||
for kwargs in V.choices.get_mm_configs(
|
||||
kernel_inputs, layout, decompose_k_subgraph_template.name, "mm"
|
||||
):
|
||||
decompose_k_subgraph_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
layout=layout,
|
||||
k_split=k_split,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
# NOTE: add new template heuristics here, so they get imported and registered
|
||||
# TODO: write a simple glob if there are many heuristics to auto import them in the right order
|
||||
from . import base, decompose_k, registry, triton
|
||||
|
||||
# expose the entry function
|
||||
from .registry import get_template_heuristic
|
||||
26
torch/_inductor/template_heuristics/base.py
Normal file
26
torch/_inductor/template_heuristics/base.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from ..ir import Layout
|
||||
from ..kernel_inputs import KernelInputs
|
||||
|
||||
|
||||
class TemplateConfigHeuristics:
    """Base template heuristic whose default implementation yields no configs."""

    def get_template_configs(
        self,
        kernel_inputs: KernelInputs,
        layout: Layout,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """
        Produce the template configs for the given inputs.

        This is the main entry point for template-specific logic. The base
        implementation is an empty generator: it accepts the arguments but
        never yields a config.

        Args:
            kernel_inputs: inputs of the kernel being configured (unused here)
            layout: output layout (unused here)
            op_name: name of the operation (unused here)

        Yields:
            Nothing in the base implementation.
        """
        # NOTE: intentionally a concrete class rather than an abstract one;
        # making it abstract clashed with the mixin functionality. This can be
        # adjusted, but it is not a high priority.
        yield from ()
|
||||
58
torch/_inductor/template_heuristics/decompose_k.py
Normal file
58
torch/_inductor/template_heuristics/decompose_k.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
|
||||
from ..ir import get_free_symbols
|
||||
from ..kernel_inputs import KernelInputs, MMKernelInputs
|
||||
from ..utils import get_k_splits
|
||||
from ..virtualized import V
|
||||
from .base import TemplateConfigHeuristics
|
||||
from .registry import register_template_heuristic
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from ..ir import Layout
|
||||
|
||||
|
||||
# Registered for the "decompose_k" template / "mm" op on "cuda"; the
# ``register`` flag is only true when torch.version.hip is None.
@register_template_heuristic(
    "decompose_k", "cuda", register=torch.version.hip is None, op_name="mm"
)
class DecomposeKConfigHeuristics(TemplateConfigHeuristics):
    def get_template_configs(
        self,
        kernel_inputs: KernelInputs,
        layout: Layout,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """
        Yield a ``{"k_split": ...}`` kwargs dict for every valid k_split.

        Candidates come from ``get_k_splits(m, n, k)``; a candidate is kept
        only when ``k % k_split == 0`` is statically known to be true.
        Nothing is yielded when any symbolic shape or stride of the inputs
        contains unbacked symbols.

        Args:
            kernel_inputs: must be an ``MMKernelInputs`` instance
            layout: output layout (not used by this heuristic)
            op_name: name of the operation (not used by this heuristic)
        """
        assert isinstance(kernel_inputs, MMKernelInputs), (
            f"{self.__class__.__name__} requires MMKernelInputs"
        )

        # Bail out without yielding as soon as any shape/stride expression
        # carries an unbacked symbol.
        symbolic_exprs = (
            *kernel_inputs.shapes_symbolic(),
            *kernel_inputs.strides_symbolic(),
        )
        for expr in symbolic_exprs:
            if len(get_free_symbols(expr, unbacked_only=True)) > 0:
                return

        m, n, k = kernel_inputs.mnk_symbolic()
        for candidate in get_k_splits(m, n, k):
            # Only keep splits that provably divide k evenly.
            divides_k = V.graph.sizevars.statically_known_true(
                sympy.Eq(sympy.Mod(k, candidate), 0)
            )
            if divides_k:
                yield {"k_split": candidate}
|
||||
|
|
@ -8,13 +8,17 @@ for CUDA vs ROCm based on torch.version.hip.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from functools import cache
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
from .base import TemplateConfigHeuristics
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .triton import TemplateConfigHeuristics
|
||||
from collections.abc import Iterator
|
||||
|
||||
|
||||
# Module-wide registry for template heuristics
|
||||
_TEMPLATE_HEURISTIC_REGISTRY: dict[tuple[str, ...], type[TemplateConfigHeuristics]] = {}
|
||||
|
|
@ -96,3 +100,42 @@ def get_template_heuristic(
|
|||
f"Available combinations: {list(_TEMPLATE_HEURISTIC_REGISTRY.keys())}"
|
||||
)
|
||||
return heuristic_class()
|
||||
|
||||
|
||||
@contextlib.contextmanager
def override_template_heuristics(
    device_type: str,
    template_op_pairs: list[tuple[str, str]],
) -> Iterator[None]:
    """
    Context manager to temporarily override template heuristics with an empty heuristic.

    This is useful for testing purposes, where we want to ensure a specific
    template/op pair is not used. While the context is active, every
    ``(device_type, template_name, op_name)`` key is mapped to the base
    ``TemplateConfigHeuristics`` class, which yields no configs. On exit the
    registry is put back the way it was: keys that had an entry get it back,
    keys that did not exist are removed.

    Args:
        device_type: Device type ("cuda", "cpu", "xpu")
        template_op_pairs: List of (template_name, op_name) pairs to override.
            The same pair may appear more than once; the entry that was
            registered before entering the context is still restored.
    """
    # Snapshot of the registry entry for each overridden key, taken the first
    # time the key is seen. ``None`` marks a key that had no entry before.
    saved: dict[tuple[str, ...], Optional[type[TemplateConfigHeuristics]]] = {}
    # Drop cached lookups so the overrides are visible immediately
    get_template_heuristic.cache_clear()
    try:
        for template_name, op_name in template_op_pairs:
            assert op_name is not None
            key = (device_type, template_name, op_name)
            # Only snapshot on first sight: a repeated pair would otherwise
            # record the override itself as the "original" and leak it past
            # the end of the context.
            if key not in saved:
                saved[key] = _TEMPLATE_HEURISTIC_REGISTRY.get(key)
            # TemplateConfigHeuristics base class returns no entries
            # so we use it for overriding
            _TEMPLATE_HEURISTIC_REGISTRY[key] = TemplateConfigHeuristics
        yield
    finally:
        # Restore original entries or remove if they didn't exist before
        for key, original in saved.items():
            if original is None:
                _TEMPLATE_HEURISTIC_REGISTRY.pop(key, None)
            else:
                _TEMPLATE_HEURISTIC_REGISTRY[key] = original
        # Drop lookups cached while the overrides were active
        get_template_heuristic.cache_clear()
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from .. import config, config as inductor_config
|
|||
from ..kernel_inputs import KernelInputs, MMKernelInputs
|
||||
from ..utils import get_backend_num_stages, get_num_sms, TMA_DESCRIPTOR_SIZE
|
||||
from ..virtualized import V
|
||||
from .base import TemplateConfigHeuristics
|
||||
from .registry import register_template_heuristic
|
||||
|
||||
|
||||
|
|
@ -1244,24 +1245,6 @@ class MTIAConfigHeuristic(BaseConfigHeuristic):
|
|||
|
||||
|
||||
# Template-specific mixin classes
|
||||
|
||||
|
||||
class TemplateConfigHeuristics:
    """Base template heuristic whose default implementation yields no configs."""

    def get_template_configs(
        self,
        kernel_inputs: KernelInputs,
        layout: Any,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """
        Produce the template configs for the given inputs.

        This is the main entry point for template-specific logic. The base
        implementation is an empty generator: it accepts the arguments but
        never yields a config.
        """
        # NOTE: intentionally a concrete class rather than an abstract one;
        # making it abstract clashed with the mixin functionality. This can be
        # adjusted, but it is not a high priority.
        yield from ()
|
||||
|
||||
|
||||
class MMTemplateConfigMixin(TemplateConfigHeuristics):
|
||||
"""
|
||||
Mixin class that converts config lists to template kwargs.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user