diff --git a/test/inductor/test_lookup_table.py b/test/inductor/test_lookup_table.py new file mode 100644 index 00000000000..250a8222678 --- /dev/null +++ b/test/inductor/test_lookup_table.py @@ -0,0 +1,1063 @@ +# Owner(s): ["module: inductor"] +import re +import unittest +from functools import partial +from typing import Any, Optional, Union +from unittest.mock import patch + +import torch +import torch.nn as nn +from torch._inductor import config as inductor_config +from torch._inductor.choices import InductorChoices +from torch._inductor.kernel_inputs import MMKernelInputs +from torch._inductor.lookup_table.choices import LookupTableChoices +from torch._inductor.select_algorithm import ( + add_preprocessing_fn, + clear_preprocessing_fns, + ExternKernelCaller, + TritonTemplateCaller, +) +from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import fresh_cache, get_num_sms, TMA_DESCRIPTOR_SIZE +from torch._inductor.virtualized import V +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + TEST_WITH_ROCM, +) +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON, HAS_GPU +from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device + + +class MockTensorNode: + """Mock input node that wraps a real tensor for testing""" + + def __init__(self, tensor: torch.Tensor): + self.tensor = tensor + + def get_device(self) -> torch.device: + return self.tensor.device + + def get_dtype(self) -> torch.dtype: + return self.tensor.dtype + + def get_size(self) -> tuple[int, ...]: + return tuple(self.tensor.shape) + + def get_stride(self) -> tuple[int, ...]: + return tuple(self.tensor.stride()) + + +class MockMMKernelInputs(MMKernelInputs): + """Mock MMKernelInputs that subclasses the real class and uses real tensors""" + + def __init__( + self, + tensors: list[torch.Tensor], + scalars: Optional[dict[str, Union[float, int]]] = None, + mat1_idx: int = -2, + 
mat2_idx: int = -1, + ): + """Initialize with real tensors, creating mock nodes for the base class""" + mock_nodes = [MockTensorNode(t) for t in tensors] + super().__init__(mock_nodes, scalars, mat1_idx=mat1_idx, mat2_idx=mat2_idx) + self.tensors = tensors # Keep reference to original tensors + + def shapes_hinted(self) -> tuple[tuple[int, ...], ...]: + """Delegate to symbolic since real tensors already have int shapes""" + return self.shapes_symbolic() + + def strides_hinted(self) -> tuple[tuple[int, ...], ...]: + """Delegate to symbolic since real tensors already have int strides""" + return self.strides_symbolic() # pyre-ignore + + def mnk_hinted(self) -> tuple[int, int, int]: + """Delegate to symbolic since real tensors already have int dimensions""" + return self.mnk_symbolic() # pyre-ignore + + @property + def device_type(self) -> Optional[str]: + return self.tensors[0].device.type + + +class BaseLookupTableTest(TestCase): + """Base class for lookup table tests with common setup and utilities""" + + def setUp(self): + super().setUp() + self.original_table = inductor_config.lookup_table.table + self.original_max_autotune = getattr(inductor_config, "max_autotune", False) + inductor_config.max_autotune = True + # Set the lookup table choices handler + V.set_choices_handler(LookupTableChoices()) + + def tearDown(self): + inductor_config.lookup_table.table = self.original_table + inductor_config.max_autotune = self.original_max_autotune + # Restore original choices handler + V.set_choices_handler(InductorChoices()) + super().tearDown() + + def create_mock_mm_kernel_inputs( + self, + shapes: Optional[list[tuple[int, ...]]] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.float32, + scalars: Optional[dict[str, Union[float, int]]] = None, + ) -> MockMMKernelInputs: + """Create MockMMKernelInputs with real tensors""" + if shapes is None: + shapes = [(128, 128), (128, 128)] # Default MM shapes + + tensors = [] + for shape in shapes: + 
# Create a real tensor with the specified shape, device, and dtype + tensor = torch.randn(shape, device=device, dtype=dtype) + tensors.append(tensor) + + return MockMMKernelInputs(tensors, scalars) + + def create_lookup_key(self, method, kernel_inputs): + """Create a lookup key using LookupTableChoices""" + choices = LookupTableChoices() + return choices.make_lookup_key(kernel_inputs, method) + + def create_config(self, template_id, **kwargs): + """Create a backend config dict: template_id plus per-template defaults, overridden by kwargs""" + config = {"template_id": template_id} + + # Add minimal defaults based on template type + if template_id == "triton": + config.update( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": 64, + "num_stages": 2, + "num_warps": 2, + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "GROUP_M": 8, + } + ) + elif template_id == "tma": + config.update( + { + "BLOCK_M": 256, + "BLOCK_N": 128, + "BLOCK_K": 64, + "num_stages": 4, + "num_warps": 8, + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "GROUP_M": 8, + } + ) + elif template_id == "decompose_k": + config.update({"k_split": 4}) # same key that create_basic_config and the filtering test use + + config.update(kwargs) + return config + + +@unittest.skipIf(not HAS_CUDA_AND_TRITON, "CUDA not available") +@instantiate_parametrized_tests +class TestLookupTable(BaseLookupTableTest): + """Consolidated tests for lookup table functionality""" + + def test_lookup_mismatch(self): + """Test mismatch scenario in lookup table""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + lookup_table_data = { + self.create_lookup_key("mm", kernel_inputs): [self.create_config("triton")] + } + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + # looking for addmm but created the entry with mm - should mismatch the key and return + # an empty result + result = test_choices.lookup_template_configs( + kernel_inputs, "addmm", ["triton"] + ) + self.assertEqual(result, {}) + + def 
test_successful_lookup_with_template_filtering(self): + """Test successful lookup that filters configs by template_id""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + config_list = [ + self.create_config("triton", BLOCK_M=128, BLOCK_N=128), + self.create_config("triton", BLOCK_M=64, BLOCK_N=64), + self.create_config("tma", BLOCK_M=256, BLOCK_N=128), + self.create_config("decompose_k", k_split=4), + ] + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): config_list} + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + + # Test triton template filtering + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"] + ) + assert result is not None, "Result should not be None" + self.assertEqual(len(result["triton"]), 2) + for config in result["triton"]: + self.assertNotIn("template_id", config) + self.assertIn("BLOCK_M", config) + + # Test tma template filtering + result = test_choices.lookup_template_configs(kernel_inputs, "mm", ["tma"]) + assert result is not None, "Result should not be None" + self.assertEqual(len(result["tma"]), 1) + self.assertNotIn("template_id", result["tma"][0]) + self.assertEqual(result["tma"][0]["BLOCK_M"], 256) + + # Test decompose_k template filtering + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["decompose_k"] + ) + assert result is not None, "Result should not be None" + self.assertEqual(len(result["decompose_k"]), 1) + self.assertNotIn("template_id", result["decompose_k"][0]) + self.assertEqual(result["decompose_k"][0]["k_split"], 4) + + def test_empty_table(self): + """Test when template lookup table is empty""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + with patch.object(inductor_config.lookup_table, "table", {}): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"] + ) + self.assertEqual(result, {}) + + def 
test_validation_error(self): + """Test validation error for invalid config""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + invalid_config = {"BLOCK_M": 128} # missing template_id + + lookup_table_data = { + self.create_lookup_key("mm", kernel_inputs): [invalid_config] + } + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + with self.assertRaises(ValueError) as cm: + test_choices.lookup_template_configs(kernel_inputs, "mm", ["triton"]) + self.assertIn("missing required 'template_id' field", str(cm.exception)) + + def test_cpu_input_returns_empty(self): + """Test that CPU tensor input returns empty dict""" + # Create kernel inputs with CPU tensors + kernel_inputs = self.create_mock_mm_kernel_inputs(device=torch.device("cpu")) + + lookup_table_data = { + self.create_lookup_key("mm", kernel_inputs): [self.create_config("triton")] + } + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"] + ) + self.assertEqual(result, {}) # Should return empty dict for CPU + + def test_multiple_calls_work(self): + """Test that calling lookup functions multiple times works correctly""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + config_list = [ + self.create_config("triton", BLOCK_M=128), + self.create_config("tma", BLOCK_M=256), + ] + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): config_list} + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + + # First calls + result1 = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"] + ) + result2 = test_choices.lookup_template_configs(kernel_inputs, "mm", ["tma"]) + assert result1 is not None, "Result1 should not be None" + assert result2 is not None, "Result2 should not be None" + 
self.assertEqual(len(result1["triton"]), 1) + self.assertEqual(len(result2["tma"]), 1) + + # Second calls should work the same + result3 = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"] + ) + result4 = test_choices.lookup_template_configs(kernel_inputs, "mm", ["tma"]) + assert result3 is not None, "Result3 should not be None" + assert result4 is not None, "Result4 should not be None" + self.assertEqual(len(result3["triton"]), 1) + self.assertEqual(len(result4["tma"]), 1) + + def test_batch_lookup_mixed_entries(self): + """Test batch lookup where some templates have entries and others don't""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + config_list = [ + self.create_config("triton", BLOCK_M=128), + self.create_config("tma", BLOCK_M=256), + # No decompose_k config in lookup table + ] + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): config_list} + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + + # Test batch lookup with mixed results + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton", "tma", "decompose_k"] + ) + assert result is not None, "Result should not be None" + + # Should have entries for triton and tma, but not decompose_k + self.assertIn("triton", result) + self.assertIn("tma", result) + self.assertNotIn("decompose_k", result) + + self.assertEqual(len(result["triton"]), 1) + self.assertEqual(len(result["tma"]), 1) + self.assertEqual(result["triton"][0]["BLOCK_M"], 128) + self.assertEqual(result["tma"][0]["BLOCK_M"], 256) + + @parametrize( + "config_hash,template_hash,expected_kept", + [ + # Hash matching (config kept) + ("hash123", "hash123", True), + # Hash mismatch (config filtered) + ("hash123", "hash456", False), + # Config without hash (config kept) + (None, "hash123", True), + # Template without hash (config kept) + ("hash123", None, True), + # Both None (config kept) + (None, None, True), + 
], + ) + def test_template_hash_checking(self, config_hash, template_hash, expected_kept): + """Test template hash validation behavior""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + config = self.create_config("triton", BLOCK_M=128, BLOCK_N=64) + if config_hash is not None: + config["template_hash"] = config_hash + + template_hash_map = ( + {"triton": template_hash} if template_hash is not None else {} + ) + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): [config]} + + with ( + patch.object(inductor_config.lookup_table, "table", lookup_table_data), + patch.object(inductor_config.lookup_table, "check_src_hash", True), + ): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"], template_hash_map + ) + + if expected_kept: + assert result is not None, "Result should not be None" + self.assertIn("triton", result) + self.assertEqual(len(result["triton"]), 1) + # template_hash should be removed from returned config + self.assertNotIn("template_hash", result["triton"][0]) + else: + # Config was filtered out due to hash mismatch + self.assertEqual(result, {}) + + def test_template_hash_checking_disabled(self): + """Test that hash checking is skipped when config flag is disabled""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + # Create config with mismatching hash + config = self.create_config("triton", BLOCK_M=128, template_hash="hash123") + + # Provide different template hash that would normally cause filtering + template_hash_map = {"triton": "hash456"} + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): [config]} + + with ( + patch.object(inductor_config.lookup_table, "table", lookup_table_data), + patch.object( + inductor_config.lookup_table, + "check_src_hash", + False, + ), + ): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"], template_hash_map + ) + + # Should keep 
config even with mismatching hash since checking is disabled + assert result is not None, "Result should not be None" + self.assertIn("triton", result) + self.assertEqual(len(result["triton"]), 1) + # template_hash should still be removed from returned config + self.assertNotIn("template_hash", result["triton"][0]) + + def test_template_hash_mixed_scenarios(self): + """Test mixed hash scenarios with multiple configs""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + config_list = [ + self.create_config( + "triton", BLOCK_M=128, template_hash="correct_hash" + ), # Should be kept + self.create_config( + "triton", BLOCK_M=64, template_hash="wrong_hash" + ), # Should be filtered + self.create_config("triton", BLOCK_M=32), # No hash, should be kept + ] + + template_hash_map = {"triton": "correct_hash"} + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): config_list} + + with ( + patch.object(inductor_config.lookup_table, "table", lookup_table_data), + patch.object(inductor_config.lookup_table, "check_src_hash", True), + ): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"], template_hash_map + ) + + assert result is not None, "Result should not be None" + self.assertIn("triton", result) + # Should keep 2 configs: the one with correct hash and the one without hash + self.assertEqual(len(result["triton"]), 2) + + # Check that kept configs have expected BLOCK_M values + kept_block_ms = [config["BLOCK_M"] for config in result["triton"]] + self.assertIn(128, kept_block_ms) # Config with correct hash + self.assertIn(32, kept_block_ms) # Config without hash + self.assertNotIn( + 64, kept_block_ms + ) # Config with wrong hash should be filtered + + # template_hash should be removed from returned configs + for config in result["triton"]: + self.assertNotIn("template_hash", config) + + @parametrize( + "config_hash,description", + [ + ("definitely_malformed_hash_!@#$%", "malformed 
hash"), + (12345, "non-string hash"), + ("", "empty string hash"), + (None, "missing hash field"), + ], + ) + def test_hash_checking_disabled_edge_cases(self, config_hash, description): + """Test that configs are kept when hash checking is disabled, regardless of hash validity""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + # Create config with potentially problematic hash + config = self.create_config("triton", BLOCK_M=128) + if config_hash is not None: + config["template_hash"] = config_hash + # If config_hash is None, don't add template_hash field at all + + # Provide a valid template hash that would normally be used for comparison + template_hash_map = {"triton": "valid_template_hash_abc123"} + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): [config]} + + with ( + patch.object(inductor_config.lookup_table, "table", lookup_table_data), + patch.object(inductor_config.lookup_table, "check_src_hash", False), + ): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"], template_hash_map + ) + + # Should keep config regardless of hash validity since checking is disabled + assert result is not None, f"Result should not be None for {description}" + self.assertIn( + "triton", result, f"Should have triton result for {description}" + ) + self.assertEqual( + len(result["triton"]), 1, f"Should have 1 config for {description}" + ) + # template_hash should be removed from returned config + self.assertNotIn( + "template_hash", + result["triton"][0], + f"template_hash should be removed from result for {description}", + ) + # Other config fields should be preserved + self.assertEqual( + result["triton"][0]["BLOCK_M"], + 128, + f"BLOCK_M should be preserved for {description}", + ) + + @parametrize( + "table_has_device_key,lookup_device_matches,expected_found", + [ + # Device-specific key in table, same device -> found + (True, True, True), + # Device-specific key in table, different 
device -> not found + (True, False, False), + # Device-agnostic key in table, same device -> found + (False, True, True), + # Device-agnostic key in table, different device -> found (device-agnostic) + (False, False, True), + ], + ) + def test_device_key_lookup_scenarios( + self, table_has_device_key, lookup_device_matches, expected_found + ): + """Test lookup behavior with device-specific vs device-agnostic keys""" + # Create kernel inputs for "device_1" (our reference device) + kernel_inputs_device1 = self.create_mock_mm_kernel_inputs() + + # Create config + config = self.create_config("triton", BLOCK_M=128) + + # Create a test choices class for generating the table key + class TableKeyChoices(LookupTableChoices): + @staticmethod + def _get_device_key(device): + if device.type != "cuda": + return None + return "device_1" # Always device_1 for table key generation + + table_key_choices = TableKeyChoices() + + # Generate table key based on whether it should include device + if table_has_device_key: + table_key = table_key_choices.make_lookup_key( + kernel_inputs_device1, "mm", include_device=True + ) + else: + table_key = table_key_choices.make_lookup_key( + kernel_inputs_device1, "mm", include_device=False + ) + + lookup_table_data = {table_key: [config]} + + # Create test choices class for the actual lookup with different device behavior + if lookup_device_matches: + + class TestChoices(LookupTableChoices): + @staticmethod + def _get_device_key(device): + if device.type != "cuda": + return None + return "device_1" + + else: + + class TestChoices(LookupTableChoices): + @staticmethod + def _get_device_key(device): + if device.type != "cuda": + return None + return "device_2" + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = TestChoices() + result = test_choices.lookup_template_configs( + kernel_inputs_device1, "mm", ["triton"] + ) + + if expected_found: + assert result is not None, ( + f"Result should not be None when 
expected_found={expected_found}" + ) + self.assertIn("triton", result, "Should have triton result when found") + self.assertEqual(len(result["triton"]), 1, "Should have exactly 1 config") + self.assertEqual( + result["triton"][0]["BLOCK_M"], 128, "Config should be preserved" + ) + else: + self.assertEqual( + result, + {}, + f"Should return empty dict when expected_found={expected_found}", + ) + + def test_device_key_priority(self): + """Test that device-specific keys take priority over device-agnostic keys""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + # Create two different configs + device_specific_config = self.create_config( + "triton", BLOCK_M=256 + ) # Different BLOCK_M + device_agnostic_config = self.create_config("triton", BLOCK_M=128) + + # Create a test choices instance to generate keys + key_choices = LookupTableChoices() + + # Create both key types for the same inputs + device_key = key_choices.make_lookup_key( + kernel_inputs, "mm", include_device=True + ) + device_agnostic_key = key_choices.make_lookup_key( + kernel_inputs, "mm", include_device=False + ) + + # Put both in the table + lookup_table_data = { + device_key: [device_specific_config], + device_agnostic_key: [device_agnostic_config], + } + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"] + ) + + # Should get device-specific config (BLOCK_M=256), not device-agnostic (BLOCK_M=128) + assert result is not None, "Result should not be None" + self.assertIn("triton", result) + self.assertEqual(len(result["triton"]), 1) + self.assertEqual( + result["triton"][0]["BLOCK_M"], + 256, + "Should use device-specific config when both exist", + ) + + def test_make_lookup_key_variants(self): + """Test the make_lookup_key_variants helper function""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + test_choices = LookupTableChoices() + 
device_key, device_agnostic_key = test_choices.make_lookup_key_variants( + kernel_inputs, "mm" + ) + + # Both should be strings + self.assertIsInstance(device_key, str) + self.assertIsInstance(device_agnostic_key, str) + + # Device key should be longer (contains device info) + self.assertGreater(len(device_key), len(device_agnostic_key)) + + # Device-agnostic key should be contained in device key (as a substring after device part) + self.assertIn(device_agnostic_key.split("+mm")[0], device_key) + + +class UnifiedModel(nn.Module): + """Unified model for different matrix operations""" + + def __init__(self, operation="mm"): + super().__init__() + self.operation = operation + + def forward(self, *args): + if self.operation == "mm": + return torch.mm(args[0], args[1]) + elif self.operation == "addmm": + return torch.addmm(args[0], args[1], args[2]) + elif self.operation == "bmm": + return torch.bmm(args[0], args[1]) + elif self.operation == "mm_plus_mm": + return torch.mm(args[0], args[1]) + torch.mm(args[2], args[3]) + else: + raise ValueError(f"Unsupported operation: {self.operation}") + + +def verify_choice_names(choices: list[Any], pattern: str, expected_count: int = 1): + """Verify choices match expected pattern and count""" + if len(choices) != expected_count: + raise ValueError(f"Expected {expected_count} choices, got {len(choices)}") + for choice in choices: + if not re.search(pattern, choice.name): + raise ValueError( + f"Choice name '{choice.name}' doesn't match pattern '{pattern}'" + ) + return choices + + +class BaseE2ELookupTableTest(BaseLookupTableTest): + """Base class for E2E lookup table tests""" + + def setUp(self): + torch._dynamo.reset() + clear_preprocessing_fns() + self.device = torch.device("cuda") + self.dev_key = LookupTableChoices._get_device_key(self.device) + self.original_lookup_table = inductor_config.lookup_table.table + # Set the lookup table choices handler + V.set_choices_handler(LookupTableChoices()) + + def tearDown(self): + 
inductor_config.lookup_table.table = self.original_lookup_table + # Restore original choices handler + V.set_choices_handler(InductorChoices()) + clear_preprocessing_fns() + + def create_tensors(self, operation, b=8, m=64, n=64, k=32): + """Create test tensors for operations with configurable dimensions""" + if operation in ["mm", "addmm", "mm_plus_mm"]: + A = torch.randn(m, k, device=self.device, dtype=torch.float16) + B = torch.randn(k, n, device=self.device, dtype=torch.float16) + if operation == "mm": + return [A, B] + if operation == "addmm": + return [ + torch.randn((m, n), device=self.device, dtype=torch.float16), + A, + B, + ] + elif operation == "mm_plus_mm": + return [ + A, + B, + torch.randn(m, k, device=self.device, dtype=torch.float16), + torch.randn(k, n, device=self.device, dtype=torch.float16), + ] + elif operation == "bmm": + return [ + torch.randn(b, m, k, device=self.device, dtype=torch.float16), + torch.randn(b, k, n, device=self.device, dtype=torch.float16), + ] + else: + raise ValueError(f"Unsupported operation: {operation}") + + def setup_lookup_table(self, operation, tensors, configs): + """Setup lookup table with configuration""" + scalars = {} + if operation in ["addmm", "baddbmm"]: + scalars["beta"] = 1 + scalars["alpha"] = 1 + mock_kernel_inputs = MockMMKernelInputs(tensors, scalars) + flat_key = self.create_lookup_key(operation, mock_kernel_inputs) + inductor_config.lookup_table.table = {flat_key: configs} + + def run_model(self, operation, tensors, config_patches=None): + """Run compiled model with configuration""" + config = {"max_autotune_gemm": True, "test_configs.max_mm_configs": 4} + if config_patches: + config.update(config_patches) + + model = UnifiedModel(operation) + with inductor_config.patch(config): + compiled_model = torch.compile(model.to(self.device)) + return compiled_model(*tensors) + + def create_basic_config(self, template_id): + """Create basic configuration for template""" + configs = { + 
torch._inductor.kernel.mm.mm_template.uid: { + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_K": 32, + "num_stages": 2, + "num_warps": 2, + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "GROUP_M": 8, + }, + torch._inductor.kernel.mm_plus_mm.mm_plus_mm_template.uid: { + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_K": 32, + "num_stages": 2, + "num_warps": 2, + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "GROUP_M": 8, + }, + torch._inductor.kernel.bmm.bmm_template.uid: { + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_K": 64, + "num_stages": 2, + "num_warps": 2, + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "GROUP_M": 8, + }, + torch._inductor.kernel.mm.persistent_tma_mm_template.uid: { + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_K": 32, + "num_stages": 2, + "num_warps": 2, + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "GROUP_M": 8, + "A_ROW_MAJOR": True, + "B_ROW_MAJOR": True, + "NUM_SMS": get_num_sms(), + "TMA_SIZE": TMA_DESCRIPTOR_SIZE, + "TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(), + }, + torch._inductor.kernel.mm.aten_bias_addmm.uid: {}, + torch._inductor.kernel.mm.decompose_k_subgraph_template.uid: {"k_split": 4}, + } + return {"template_id": template_id, **configs.get(template_id, {})} + + def _create_simple_matmul_model(self): + """Create a simple matmul model for recording tests""" + + class SimpleMatmul(nn.Module): + def forward(self, a, b): + return torch.mm(a, b) + + return SimpleMatmul() + + def _create_test_inputs(self, device="cuda"): + """Create test inputs for matmul""" + return [ + torch.randn(512, 512, device=device, dtype=torch.float32), + torch.randn(512, 512, device=device, dtype=torch.float32), + ] + + +@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support lookup table") +@unittest.skipIf(not HAS_CUDA_AND_TRITON, "CUDA not available") +@instantiate_parametrized_tests +class TestLookupTableE2E(BaseE2ELookupTableTest): + """E2E tests 
for lookup table functionality""" + + @parametrize("max_autotune", [True, False]) + @fresh_cache() + def test_no_lookup_table_entry_autotune_modes(self, max_autotune): + """Test when there's no lookup table entry with different autotune modes""" + tensors = self.create_tensors("mm") + + # Setup lookup table with different key to force no match + self.setup_lookup_table( + "mm", + [ + torch.randn(64, 64, device=self.device), + torch.randn(64, 64, device=self.device), + ], + [], + ) + + # Inline validation function + def validate_choices(choices): + if max_autotune: + assert len(choices) > 2, ( + f"Max-autotune should have >2 choices, got {len(choices)}" + ) + assert any(isinstance(c, ExternKernelCaller) for c in choices), ( + "Should have ExternKernelCaller" + ) + assert any(isinstance(c, TritonTemplateCaller) for c in choices), ( + "Should have TritonTemplateCaller" + ) + else: + assert len(choices) == 1, ( + f"No max-autotune should have 1 choice, got {len(choices)}" + ) + assert isinstance(choices[0], ExternKernelCaller), ( + f"Should be ExternKernelCaller, got {type(choices[0])}" + ) + return choices + + add_preprocessing_fn(validate_choices) + self.run_model( + "mm", + tensors, + {"max_autotune_gemm": max_autotune, "max_autotune": max_autotune}, + ) + + @parametrize("operation", ["mm", "addmm", "bmm", "mm_plus_mm"]) + @fresh_cache() + def test_valid_lookup_table_entry(self, operation): + """Test when there's a valid entry for the operation""" + k = 256 if operation == "mm_plus_mm" else 64 + tensors = self.create_tensors(operation, k=k) + + # Map operation to actual template UID + template_mapping = { + "mm": torch._inductor.kernel.mm.mm_template.uid, + "addmm": torch._inductor.kernel.mm.mm_template.uid, + "bmm": torch._inductor.kernel.bmm.bmm_template.uid, + "mm_plus_mm": torch._inductor.kernel.mm_plus_mm.mm_plus_mm_template.uid, + } + template_id = template_mapping[operation] + config = self.create_basic_config(template_id) + + 
self.setup_lookup_table(operation, tensors, [config]) + add_preprocessing_fn( + partial(verify_choice_names, pattern="triton_", expected_count=1) + ) + self.run_model(operation, tensors) + + @unittest.skipIf(not has_triton_tma_device(), "Need TMA support") + @parametrize("operation", ["mm", "addmm"]) + @fresh_cache() + def test_tma_lookup_table_entry(self, operation): + """Test TMA template entry""" + tensors = self.create_tensors(operation) + config = self.create_basic_config( + torch._inductor.kernel.mm.persistent_tma_mm_template.uid + ) + + self.setup_lookup_table(operation, tensors, [config]) + add_preprocessing_fn( + partial( + verify_choice_names, + pattern="triton_mm_persistent_tma_", + expected_count=1, + ) + ) + self.run_model( + operation, tensors, {"triton.enable_persistent_tma_matmul": True} + ) + + @fresh_cache() + def test_decompose_k_lookup_table_entry(self): + """Test decompose_k template entry""" + tensors = self.create_tensors("mm", m=32, n=32, k=32 * 32) + config = self.create_basic_config( + torch._inductor.kernel.mm.decompose_k_subgraph_template.uid + ) + + self.setup_lookup_table("mm", tensors, [config]) + add_preprocessing_fn( + partial( + verify_choice_names, pattern="decompose_k|bmm_dtype", expected_count=1 + ) + ) + self.run_model("mm", tensors) + + @fresh_cache() + def test_bias_addmm_lookup_table_entry(self): + """Test bias_addmm template entry""" + # Create bias with stride[0] == 0 for bias_addmm eligibility + bias_unexpanded = torch.randn(64, device=self.device, dtype=torch.float16) + expanded_bias = bias_unexpanded.expand(64, 64) + tensors = [ + expanded_bias, + torch.randn(64, 32, device=self.device, dtype=torch.float16), + torch.randn(32, 64, device=self.device, dtype=torch.float16), + ] + + config = self.create_basic_config(torch._inductor.kernel.mm.aten_bias_addmm.uid) + self.setup_lookup_table("addmm", tensors, [config]) + add_preprocessing_fn( + partial(verify_choice_names, pattern="bias_addmm", expected_count=1) + ) + + # Run 
with original unexpanded bias + with inductor_config.patch( + {"max_autotune_gemm": True, "triton.autotune_cublasLt": True} + ): + model = UnifiedModel("addmm") + compiled_model = torch.compile(model.to(self.device), mode="max-autotune") + compiled_model(bias_unexpanded, tensors[1], tensors[2]) + + @unittest.skipIf(not has_triton_tma_device(), "Need TMA support") + @fresh_cache() + def test_multiple_configs_same_template(self): + """Test multiple configurations for same template""" + tensors = self.create_tensors("mm") + + config1 = self.create_basic_config( + torch._inductor.kernel.mm.persistent_tma_mm_template.uid + ) + config1.update({"BLOCK_M": 128, "BLOCK_N": 128, "num_warps": 8}) + + config2 = self.create_basic_config( + torch._inductor.kernel.mm.persistent_tma_mm_template.uid + ) + config2.update({"BLOCK_M": 64, "BLOCK_N": 64, "num_warps": 4}) + + self.setup_lookup_table("mm", tensors, [config1, config2]) + add_preprocessing_fn( + partial( + verify_choice_names, + pattern="triton_mm_persistent_tma_", + expected_count=2, + ) + ) + self.run_model("mm", tensors, {"triton.enable_persistent_tma_matmul": True}) + + @unittest.skipIf(not has_triton_tma_device(), "Need TMA support") + @fresh_cache() + def test_mixed_template_configs(self): + """Test mixing different template types""" + tensors = self.create_tensors("mm") + + triton_config = self.create_basic_config( + torch._inductor.kernel.mm.mm_template.uid + ) + triton_config.update({"BLOCK_M": 128, "num_warps": 8}) + + tma_config = self.create_basic_config( + torch._inductor.kernel.mm.persistent_tma_mm_template.uid + ) + tma_config.update({"BLOCK_M": 256, "num_warps": 4}) + + self.setup_lookup_table("mm", tensors, [triton_config, tma_config]) + add_preprocessing_fn( + partial(verify_choice_names, pattern="triton_", expected_count=2) + ) + self.run_model("mm", tensors, {"triton.enable_persistent_tma_matmul": True}) + + @fresh_cache() + def test_template_hash_filtering_e2e(self): + """Test end-to-end template hash 
if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

    # Run the suite only when HAS_GPU, HAS_CPU and is_big_gpu() all hold.
    # The `and` chain short-circuits, so is_big_gpu() is queried only when
    # both flags are already true; otherwise the script exits without running
    # any tests.
    environment_ok = HAS_GPU and HAS_CPU and is_big_gpu()
    if environment_ok:
        run_tests()
+ check_src_hash: bool = True + + class test_configs: force_extern_kernel_in_multi_template: bool = False diff --git a/torch/_inductor/lookup_table/README.md b/torch/_inductor/lookup_table/README.md new file mode 100644 index 00000000000..6c87a365bd8 --- /dev/null +++ b/torch/_inductor/lookup_table/README.md @@ -0,0 +1,253 @@ +# Template Lookup Table System + +The template lookup table system provides a way to pre-configure kernel template parameters for specific operations and +input configurations, bypassing the default choice generation and autotuning process. + +## Overview + +The lookup table system replaces default choice generation with pre-configured template parameters for specific +operations and input configurations. It sits orthogonal to `max-autotune(-gemm)` in the following way: + +If a lookup table is provided and there is a match: + +- We check whether the template(s) in the match are currently in use +- If so, we use the pre-configured template(s) and config and bypass choice generation + - If more than one choice is provided, we run autotune among the pre-configured choices +- If not, we fall back to the default choice generation process, including max-autotune(-gemm) logic + +If there is no match, we fall back to the default choice generation process, including max-autotune(-gemm) logic + +## Configuration + +Enable the system by setting both: + +```python +from torch._inductor import config +config.lookup_table.table = your_table_dict +# You also need to set it as the default choice handler +from torch._inductor.lookup_table import LookupTableChoices +torch._inductor.virtualized.V.set_choices_handler(LookupTableChoices()) +``` + +### Device Key Handling + +The key schema format is described in detail in the [Key Schemas](#key-schemas) section below. 
+ +Configure device key behavior: + +```python +# Control whether entries include device-specific keys for lookups +# Device-agnostic entries work across different GPU models +``` + +**Lookup Behavior**: During lookup, the system automatically tries both key formats: + +1. **Device-specific key** (e.g., `"NVIDIA H100+input_data+mm"`) - tried first +1. **Device-agnostic key** (e.g., `"input_data+mm"`) - tried if device-specific fails + +**Priority**: If both device-specific and device-agnostic entries exist for the same inputs, the device-specific entry +takes priority. + +**NOTE**: Device-based keys simplify hardware-specific optimization without complex build rules. Currently limited to +device name only. If you need additional conditional key attributes (e.g., CUDA version filtering), please file an issue +or submit a patch. + +## Behavior + +When the table is active, the following behavior occurs for all supported operations: + +### Match Found + +- Uses pre-configured choices from the table instead of generating default choices +- Bypasses autotuning if only a single choice is provided +- If multiple choices are provided, autotuning occurs among those choices only + +### No Match Found + +- Standard default behavior - generates choices using heuristics and max-autotune settings + +### Table Not Set or Inactive + +- Standard default behavior - generates choices using heuristics and max-autotune settings + +## Supported Operations + +Currently supports: `mm`, `addmm`, `bmm`, `mm_plus_mm`, `scaled_mm` operations with + +- Triton +- ATEN +- DecomposeK + +## Table Format + +The table is a dictionary with keys in the format: + +``` +"input_key+op_name" +``` + +Where: + +- `input_key`: Generated from `KernelInputs.key` property, represents tensor shapes/dtypes/strides +- `op_name`: Operation name (`"mm"`, `"addmm"`, etc.) 
+ +Each value is a list of configuration dictionaries containing: + +- `template_id`: Template identifier (`"triton::mm"`, `"triton::mm_persistent_tma"`, `"decompose_k"`, etc.) +- Template-specific parameters (`BLOCK_M`, `BLOCK_N`, `BLOCK_K`, `num_warps`, etc.) + +## Key Schemas + +**NOTE**: The key schema format is subject to change as the system evolves. + +The lookup table uses composite keys to match kernel configurations. See +[Implementation Details](#implementation-details) below for more technical information about key generation. This +section describes the structure of these keys. + +### Key Format Structure + +Keys follow the pattern: + +``` +[device_name+]input_key+[additional_params+]op_name +``` + +Components: + +- **device_name** (optional): GPU device identifier (e.g., `"NVIDIA H100"`) + + - Obtained from `torch.cuda.get_device_properties().gcnArchName` + - Enables device-specific optimizations + - When omitted, creates device-agnostic entries that work across hardware + +- **input_key**: Tensor configuration representation from `KernelInputs.key` + + - Format: `((dtype, shape, stride), (dtype, shape, stride), ...)` + - Each tuple represents one input tensor's properties + - Example: `((torch.float16, [128, 256], [0, 1]), (torch.float16, [64, 256], [256, 1]))` + - Order matches the operation's input argument order + +- **additional_params** (optional): Operation-specific parameters + + - Format: `key1=value1&key2=value2` + - Example: `alpha=1&beta=1` for addmm operations + +- **op_name**: Operation identifier + + - Examples: `"mm"`, `"addmm"`, `"bmm"`, `"mm_plus_mm"`, `"scaled_mm"` + +### Key Examples + +**Device-specific key for addmm:** + +``` +"NVIDIA H100+((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))+alpha=1&beta=1+addmm" +``` + +**Device-agnostic key for mm:** + +``` +"((torch.float16, [64, 128], [128, 1]), (torch.float16, [128, 256], [256, 1]))+mm" +``` + +**Key with no 
additional parameters:** + +``` +"((torch.float32, [512, 512], [512, 1]), (torch.float32, [512, 512], [512, 1]))+bmm" +``` + +### Lookup Strategy + +During lookup, the system tries keys in priority order: + +1. **Device-specific key** - checked first if device information is available +1. **Device-agnostic key** - fallback if device-specific lookup fails + +This allows tables to contain: + +- Device-optimized configurations (higher priority) +- Portable configurations that work across devices +- Mix of both for flexible deployment + +## Example Table + +This is an example table for a single input showing two configurations: + +```python +table = { + "((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))+alpha=1&beta=1+addmm": [ + { + "template_id": "triton::mm", + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "num_stages": 2, + "num_warps": 4, + "BLOCK_M": 32, + "BLOCK_N": 32, + "BLOCK_K": 64, + "hint_override": None, + "GROUP_M": 8, + "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896" + }, + { + "template_id": "aten::bias_addmm" + }, + ] +} +``` + +## Source Hashing Safety + +The lookup table system includes source hashing to prevent using stale configurations when template code changes. 
+ +### Configuration + +- **Enabled by default**: `torch._inductor.config.lookup_table.check_src_hash = True` +- **Optional field**: Add `"template_hash"` to table entries for enhanced safety + +### Behavior + +When source hash checking is enabled: + +- Template configurations with `"template_hash"` fields are validated against current template source hashes +- Mismatched hashes indicate the template code has changed since the configuration was created +- Stale configurations are automatically filtered out with a warning message +- Configurations without hash fields are preserved for backward compatibility or if the user wants to fly looser + +### Example with Template Hash + +```python +{ + "template_id": "triton::mm", + "BLOCK_M": 32, + "BLOCK_N": 32, + "BLOCK_K": 16, + "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896" +} +``` + +## Performance Impact + +- **Lookup Hit**: Eliminates heuristic choice generation and autotuning overhead (if a single choice) +- **Lookup Miss**: Default behavior, including heuristic choice generation and autotuning +- **Memory**: Table stored in memory, minimal overhead for key generation and lookup + +## Implementation Details + +### Key Generation + +- Device key: Uses `torch.cuda.get_device_properties().gcnArchName` (e.g., "NVIDIA H100") +- Input key: Generated from `KernelInputs.key` containing tensor properties + +### Entry Points + +The system is accessed through: + +- `lookup_template_configs(kernel_inputs, op_name, template_uids)` - Main lookup function +- `LookupTableChoices._finalize_template_configs()` - Integration point with existing choice system + +### Error Handling + +- Validates config dictionaries contain required `template_id` field +- Gracefully handles non-CUDA devices by returning empty results diff --git a/torch/_inductor/lookup_table/__init__.py b/torch/_inductor/lookup_table/__init__.py new file mode 100644 index 00000000000..0ebb1d5618b --- /dev/null +++ 
b/torch/_inductor/lookup_table/__init__.py @@ -0,0 +1,32 @@ +""" +Template lookup table system for PyTorch Inductor. + +This package provides functionality for: +- Loading pre-configured template choices from lookup tables +- Managing template configurations and choices + +All functionality is contained within the LookupTableChoices class. +You can customize any aspect by subclassing LookupTableChoices and overriding methods. + +Usage: + # Basic usage + choices = LookupTableChoices() + V.set_choices_handler(choices) + + # Custom usage + class MyCustomChoices(LookupTableChoices): + def _get_lookup_table(self): + return my_custom_table + + def make_lookup_key(self, kernel_inputs, op_name, include_device=False): + return f"custom_{op_name}_{hash(str(kernel_inputs))}" + + V.set_choices_handler(MyCustomChoices()) +""" + +from .choices import LookupTableChoices + + +__all__ = [ + "LookupTableChoices", +] diff --git a/torch/_inductor/lookup_table/choices.py b/torch/_inductor/lookup_table/choices.py new file mode 100644 index 00000000000..46e54180114 --- /dev/null +++ b/torch/_inductor/lookup_table/choices.py @@ -0,0 +1,418 @@ +from __future__ import annotations + +import copy +import logging +from functools import lru_cache +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +from torch._inductor import config +from torch._inductor.choices import InductorChoices +from torch._inductor.kernel_template_choice import KernelTemplateChoice +from torch._inductor.template_heuristics.params import DictKernelTemplateParams + + +log = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from collections.abc import Generator + + from torch._inductor.codegen.common import KernelTemplate + from torch._inductor.kernel_inputs import KernelInputs + from torch._inductor.select_algorithm import ExternKernelChoice + + +class LookupTableChoices(InductorChoices): + """ + InductorChoices subclass that uses lookup table when available, otherwise falls back to parent. 
+ All lookup functionality is contained within this class and can be customized by overriding methods. + """ + + def _get_lookup_table(self) -> dict[str, list[dict[str, Any]]]: + """ + Get the template lookup table from config. + Override this method to use custom lookup table sources (database, API, etc.). + """ + if not torch.cuda.is_available() or config.lookup_table.table is None: + return {} + return config.lookup_table.table + + @staticmethod + @lru_cache + def _get_device_key(device: torch.device) -> Optional[str]: + """ + Generate a device key for lookup table indexing. + For CPU devices, returns None. + For CUDA devices, returns the props.gcnArchName string. + """ + if device.type != "cuda": + # only cuda devices are supported, this indicates that the system is not in use + # for this device + return None + + # Get CUDA device properties + props = torch.cuda.get_device_properties(device.index) + return props.gcnArchName + + @staticmethod + def _generate_kernel_inputs_key(kernel_inputs: KernelInputs) -> str: + """ + Generate a key based on input node properties and scalars. + The key includes dtype, size, and stride information for each input node, + plus scalar values as key=value pairs separated by & signs. 
+ """ + # Get node information using existing methods + dtypes = kernel_inputs.dtypes() + shapes = kernel_inputs.shapes_hinted() + strides = kernel_inputs.strides_hinted() + + # Create tuple of (dtype, shape_list, stride_list) for each node + node_info = tuple( + (dtype, list(shape), list(stride)) + for dtype, shape, stride in zip(dtypes, shapes, strides) + ) + + # Create base key from node information + fmt_key = str(node_info) + # Add scalar information if present + if kernel_inputs._scalars: + # Sort scalars for consistent key generation and join with & + scalar_parts = [ + f"{key}={value}" + for key, value in sorted(kernel_inputs._scalars.items()) + ] + scalars_key = "&".join(scalar_parts) + fmt_key = f"{fmt_key}+{scalars_key}" + + return f"{fmt_key}" + + def make_lookup_key( + self, kernel_inputs: KernelInputs, op_name: str, include_device: bool = False + ) -> Optional[str]: + """ + Create a flattened lookup key from kernel inputs and operation name. + Override this method to customize key generation. 
+ + Args: + kernel_inputs: KernelInputs object containing input nodes and scalars + op_name: Operation name (e.g., "mm", "addmm") + include_device: Whether to include device key in the generated key + + Returns: + A string key combining device (optional), operation, and input information + """ + device = kernel_inputs.device() + dev_key = self._get_device_key(device) + if dev_key is None: + # The system does not run when dev_key is None, regardless of + # whether include_device is True or False + return None + if not include_device: + dev_key = None + + # Generate input key using our staticmethod + input_key = self._generate_kernel_inputs_key(kernel_inputs) + + # Create the flattened lookup key + if dev_key is not None: + key_parts = [dev_key, input_key, op_name] + else: + key_parts = [input_key, op_name] + + return "+".join(key_parts) + + def make_lookup_key_variants( + self, kernel_inputs: KernelInputs, op_name: str + ) -> tuple[Optional[str], Optional[str]]: + """ + Generate both device-specific and device-agnostic lookup keys. + Override this method to customize key variant generation. + + Args: + kernel_inputs: KernelInputs object containing input nodes and scalars + op_name: Operation name (e.g., "mm", "addmm") + + Returns: + Tuple of (device_key, device_agnostic_key). Either may be None if generation fails. + """ + device_key = self.make_lookup_key(kernel_inputs, op_name, include_device=True) + device_agnostic_key = self.make_lookup_key( + kernel_inputs, op_name, include_device=False + ) + + return device_key, device_agnostic_key + + @staticmethod + def _entry_is_valid( + cfg: dict[str, Any], + template_id: str, + template_hash_map: Optional[dict[str, Optional[str]]], + ) -> bool: + """ + Check if a config entry is valid based on template hash validation. 
+ + Args: + cfg: Configuration dictionary that may contain a template_hash field + template_id: The template identifier + template_hash_map: Optional mapping from template_uid to src_hash for validation + + Returns: + True if the config is valid and should be kept, False if it should be filtered out + """ + # If hash checking is disabled or no hash map provided, keep the config + if not config.lookup_table.check_src_hash or not template_hash_map: + return True + + template_hash = template_hash_map.get(template_id) + config_hash = cfg.get("template_hash") + + # Both hashes present - validate they match + if template_hash is not None and config_hash is not None: + if config_hash != template_hash: + log.warning( + "Hash validation failed for template '%s': config_hash='%s' != template_hash='%s'. " + "Template code may have changed. Filtering out config: %s", + template_id, + config_hash, + template_hash, + {k: v for k, v in cfg.items() if k != "template_hash"}, + ) + return False + else: + log.debug( + "Hash validation passed for template '%s': hash='%s'", + template_id, + template_hash, + ) + return True + # Config has no hash - keep it + elif config_hash is None: + log.debug( + "Config for template '%s' has no hash - keeping it (template_hash='%s')", + template_id, + template_hash, + ) + return True + # Template has no hash - keep config + else: + log.debug( + "Template '%s' has no src_hash - keeping config with hash '%s'", + template_id, + config_hash, + ) + return True + + def lookup_template_configs( + self, + kernel_inputs: KernelInputs, + op_name: str, + template_uids: list[str], + template_hash_map: Optional[dict[str, Optional[str]]] = None, + ) -> dict[str, list[dict[str, Any]]]: + """ + Unified function to look up template configurations for multiple templates. + Override this method to customize lookup logic. 
+ + Args: + kernel_inputs: KernelInputs object containing input nodes and scalars + op_name: Operation name (e.g., "mm", "addmm") + template_uids: List of template identifiers (e.g., ["mm", "tma", "decompose_k"]) + template_hash_map: Optional mapping from template_uid to src_hash for validation + + Returns: + {}: No lookup table in use, or no matches found for any template + {"template_uid1": [config1, config2], ...}: Matches found, filtered configurations + """ + lookup_table = self._get_lookup_table() + if not lookup_table: + log.debug("Lookup table: no table configured or CUDA unavailable") + return {} + + # Try both key variants: device-specific first, then device-agnostic + # If both exist, device-specific takes priority + device_key, device_agnostic_key = self.make_lookup_key_variants( + kernel_inputs, op_name + ) + + config_list = [] + + for key_type, key in [ + ("device-specific", device_key), + ("device-agnostic", device_agnostic_key), + ]: + if key is not None: + config_list = lookup_table.get(key, []) + if config_list: + log.debug( + "Lookup table: found %d configs using %s key '%s' for %s", + len(config_list), + key_type, + key, + op_name, + ) + break + else: + log.debug( + "Lookup table: no match for %s (tried keys: %s, %s) (table has %d keys)", + op_name, + device_key, + device_agnostic_key, + len(lookup_table), + ) + return {} + + log.debug( + "Lookup table: found %d configs for %s templates %s", + len(config_list), + op_name, + template_uids, + ) + # Group configs by template_id + configs_by_template: dict[str, list[dict[str, Any]]] = {} + for cfg in config_list: + if not isinstance(cfg, dict): + raise ValueError( + f"Config for {op_name} operation is not a dictionary: {cfg}" + ) + if "template_id" not in cfg: + raise ValueError( + f"Config for {op_name} operation missing required 'template_id' field: {cfg}" + ) + + template_id = cfg["template_id"] + if template_id in template_uids: + if template_id not in configs_by_template: + 
configs_by_template[template_id] = [] + configs_by_template[template_id].append(cfg) + + # Check template hashes and clean up template_id field + result = {} + for template_id, matching_configs in configs_by_template.items(): + filtered_configs = [] + for cfg in matching_configs: + # Check template hash using helper function + if not self._entry_is_valid(cfg, template_id, template_hash_map): + continue + + # Return a copy of the config, as we don't want to modify the original + cconfig = copy.deepcopy(cfg) + # Lastly, we have to throw out the template_id, as it's not a valid kwarg + # and just used to identify which template the entry belongs to + del cconfig["template_id"] + # Similarly, the template_hash is not a valid kwarg + cconfig.pop("template_hash", None) + filtered_configs.append(cconfig) + + if filtered_configs: + result[template_id] = filtered_configs + + return result + + def _finalize_template_configs( + self, + template_choices: dict[str, Generator[KernelTemplateChoice, None, None]], + kernel_inputs: KernelInputs, + templates: list[Union[KernelTemplate, ExternKernelChoice]], + op_name: str, + kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None, + ) -> list[KernelTemplateChoice]: + """Check lookup table for hits, use those if found, otherwise fall back to parent.""" + # 1. Collect template src_hashes for validation + template_uids = [template.uid for template in templates] + template_hash_map = {} + for template in templates: + src_hash = getattr(template, "src_hash", None) + template_hash_map[template.uid] = src_hash + + log.debug( + "Choices: attempting lookup for %s with %d templates", + op_name, + len(template_uids), + ) + + # 2. Single batch lookup for all templates + lookup_results = self.lookup_template_configs( + kernel_inputs, op_name, template_uids, template_hash_map + ) + + # 3. 
Early exit if no lookup table or no matches + if not lookup_results: # Empty dict + log.info("LookupChoices: lookup miss for %s, using fallback", op_name) + return self._fallback( + template_choices, + kernel_inputs, + templates, + op_name, + kwarg_overrides, + ) + + log.info( + "LookupChoices: lookup hit for %s - found %d/%d templates: %s", + op_name, + len(lookup_results), + len(template_uids), + list(lookup_results.keys()), + ) + + # 4. Create KTCs only for templates with lookup entries + return self._create_lookup_choices( + lookup_results, templates, kernel_inputs, op_name + ) + + def _fallback( + self, + template_choices: dict[str, Generator[KernelTemplateChoice, None, None]], + kernel_inputs: KernelInputs, + templates: list[Union[KernelTemplate, ExternKernelChoice]], + op_name: str, + kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None, + ) -> list[KernelTemplateChoice]: + """Fallback to parent if no lookup table or no matches.""" + # NOTE: this is broken out, so that subclasses are able to override this + # to handle explicitly the situations where the lookup take had a miss vs + # overriding the entire logic + return super()._finalize_template_configs( + template_choices, + kernel_inputs, + templates, + op_name, + kwarg_overrides, + ) + + def _create_lookup_choices( + self, + lookup_results: dict[str, list[dict[str, Any]]], + templates: list[Union[KernelTemplate, ExternKernelChoice]], + kernel_inputs: KernelInputs, + op_name: str, + ) -> list[KernelTemplateChoice]: + """Create KernelTemplateChoice objects from lookup results using parent's get_ktc method.""" + templates_by_uid = {template.uid: template for template in templates} + lookup_choices: list[KernelTemplateChoice] = [] + + for template_uid, configs in lookup_results.items(): + template = templates_by_uid[template_uid] + + # Use parent's get_ktc method to get a generator, then get the first base KTC + ktc_generator = self.get_ktc(kernel_inputs, template, op_name) + + try: + base_ktc = 
next(ktc_generator) + except StopIteration: + # No configs from heuristic, skip this template + continue + + # For each lookup config, create a KTC with the override kwargs + for c in configs: + lookup_ktc = KernelTemplateChoice( + template=base_ktc.template, + # use the ones from the lookup table + params=DictKernelTemplateParams(c), + extra_kwargs=base_ktc.extra_kwargs, + layout=base_ktc.layout, + inputs=base_ktc.inputs, + ) + lookup_choices.append(lookup_ktc) + + return lookup_choices