Move prologue_supported_inputs computations to def_kernel (#150869)

This avoids replaying load_input on a cache hit in the generated code cache.
The idea is that if a template has prologue_loads_all_inputs = True, then all of
its inputs are loaded by the prologue, and hence load_input does not need to be
replayed; prologue_supported_inputs can instead be populated directly in def_kernel.
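
A minimal sketch of the mechanism (hypothetical, simplified class; the real
logic lives in TritonTemplateKernel.def_kernel and load_input):

    # Simplified model of the change, not the actual Inductor code. With the
    # flag set, def_kernel registers every input as prologue-supported up
    # front, so a cache hit only replays def_kernel, never load_input.
    class KernelSketch:
        def __init__(self, prologue_loads_all_inputs: bool = False) -> None:
            self.prologue_loads_all_inputs = prologue_loads_all_inputs
            self.prologue_supported_inputs: set[str] = set()

        def def_kernel(self, *input_names: str) -> None:
            if self.prologue_loads_all_inputs:
                self.prologue_supported_inputs.update(input_names)

        def load_input(self, input_name: str) -> None:
            # Without the flag, this call has to be replayed on every cache
            # hit just to repopulate prologue_supported_inputs.
            if not self.prologue_loads_all_inputs:
                self.prologue_supported_inputs.add(input_name)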

Effect on the current benchmark, from a local run on a dev server:
18549985383 -> 15072230073 (~18.7% reduction)
25697270062 -> 20738613297 (~19.3% reduction)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150869
Approved by: https://github.com/eellison
Authored by Laith Sakka on 2025-05-21 15:55:57 -07:00; committed by PyTorch MergeBot
parent 4421aee558
commit 4bcff4af99
3 changed files with 33 additions and 17 deletions


@@ -1188,7 +1188,7 @@ class TestMaxAutotune(TestCase):
         cache_key, events = get_cache_key_and_events()
         if not TEST_WITH_ROCM:
-            self.assertEqual(
+            self.assertExpectedInline(
                 remove_white_space(cache_key),
                 remove_white_space(
                     """
@@ -1204,13 +1204,7 @@ class TestMaxAutotune(TestCase):
             self.assertEqual(
                 remove_white_space(events),
-                remove_white_space(
-                    """[
-                    ('def_kernel', ['A', 'B'], {}),
-                    ('load_input', ['A', 'a', ('idx_m', 'idx_n')], {'mask': 'a_mask', 'indent_width': 8}),
-                    ('load_input', ['B', 'b', ('idx_m', 'idx_n')], {'mask': 'b_mask', 'indent_width': 8})]
-                    """
-                ),
+                remove_white_space("""[('def_kernel', ['A', 'B'], {})]"""),
             )
             # Test symbolic shapes with different symbols. Will cache miss due to different symbols in inputs.
@@ -1232,7 +1226,7 @@ class TestMaxAutotune(TestCase):
         cache_key, events = get_cache_key_and_events()
         if not TEST_WITH_ROCM:
-            self.assertEqual(
+            self.assertExpectedInline(
                 remove_white_space(cache_key),
                 remove_white_space(
                     """{'input_nodes': ["[[s77, s17], [s17, 1], torch.float32, device(type='cuda', index=0), 0]",
@@ -1245,16 +1239,21 @@ class TestMaxAutotune(TestCase):
                 ),
             )
-            self.assertEqual(
+            self.assertExpectedInline(
                 remove_white_space(events),
                 remove_white_space(
+                    """[('def_kernel',['A','B'],{}),('size',['A',0],{}),('size',['B',1],{}),('size',['A',1],{})]"""
+                ),
+            )
+            self.assertExpectedInline(
+                remove_white_space(events),
+                remove_white_space(
                     """[
-                    ('def_kernel', ['A', 'B'], {}),
-                    ('size', ['A', 0], {}), ('size', ['B', 1], {}),
-                    ('size', ['A', 1], {}),
-                    ('load_input', ['A', 'a', ('idx_m', 'idx_n')], {'mask': 'a_mask', 'indent_width': 8}),
-                    ('load_input', ['B', 'b', ('idx_m', 'idx_n')], {'mask': 'b_mask', 'indent_width': 8})]
-                    """
+                    ('def_kernel', ['A', 'B'], {}),
+                    ('size', ['A', 0], {}),
+                    ('size', ['B', 1], {}),
+                    ('size', ['A', 1], {})]
+                    """
                 ),
             )


@@ -230,6 +230,7 @@ mm_template = TritonTemplate(
     """
     ),
     cache_codegen_enabled_for_template=True,
+    prologue_loads_all_inputs=True,
 )
 
 persistent_tma_mm_template = TritonTemplate(

@@ -311,6 +311,7 @@ class TritonTemplateKernel(TritonKernel):
         epilogue_fn=identity,
         subgraphs: Optional[list[ir.ComputedBuffer]] = None,
         workspace_arg: Optional[WorkspaceArg] = None,
+        prologue_loads_all_inputs=False,
     ) -> None:
         numel = sympy_product(output_node.get_size())
         super().__init__(
@@ -387,6 +388,10 @@ class TritonTemplateKernel(TritonKernel):
         # Update each time an input is marked frozen, used to replay the freezing of inputs on a cache hit.
         self.frozen_layouts_cnt = 0
+        # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel
+        # by adding all inputs.
+        self.prologue_loads_all_inputs = prologue_loads_all_inputs
+
     def input_dependent_preserved_state(self) -> str:
         # Not adding self.args.output_buffers on purpose. But we do not need to reproduce it on a cache hit.
         # (never accessed).
@@ -428,6 +433,7 @@ class TritonTemplateKernel(TritonKernel):
             key.name: getattr(self, key.name)
             for key in dataclasses.fields(SubgraphInfo)
         }
+        assert body_name in self.subgraph_bodies, body_name
         subgraph = self.subgraph_bodies[body_name]
@@ -585,10 +591,13 @@ class TritonTemplateKernel(TritonKernel):
         # The args may be duplicated, so renaming must be after args are de-duplicated.
         for name in argnames:
             input_node = self.named_input_nodes[name]
+            if self.prologue_loads_all_inputs:
+                self.prologue_supported_inputs.add(input_node.get_name())
+
             if input_node.get_name() in V.graph.removed_buffers:
                 continue
             if input_node.get_name() in self.prologue_fused_inputs:
                 continue
             arg_name = self.args.input_buffers[input_node.get_name()]
             if input_node.get_layout().offset == 0:
                 renames.writeline(f"{name} = {arg_name}")
@@ -756,7 +765,9 @@ class TritonTemplateKernel(TritonKernel):
         """
         input_node = self.named_input_nodes[input_name]
-        self.prologue_supported_inputs.add(input_node.get_name())
+        if not self.prologue_loads_all_inputs:
+            self.prologue_supported_inputs.add(input_node.get_name())
+
         tilings = (sympy_product(input_node.get_size()), sympy.Integer(1))
         groups = {
             "x": tilings[0],
@@ -1261,6 +1272,7 @@ class TritonTemplate(KernelTemplate):
         source: str,
         debug=False,
         cache_codegen_enabled_for_template=False,
+        prologue_loads_all_inputs=False,
     ) -> None:
         super().__init__(name)
         self.grid = grid
@@ -1271,6 +1283,9 @@ class TritonTemplate(KernelTemplate):
         self._cache_codegen_enabled_for_template = cache_codegen_enabled_for_template
         self._generated_code_cache: GeneratedCodeCache = GeneratedCodeCache()
         clear_on_fresh_inductor_cache(self._generated_code_cache)
+        # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel
+        # by adding all inputs.
+        self.prologue_loads_all_inputs = prologue_loads_all_inputs
         # When this flag is on, we ensure that the cached results and the generated result if cache
         # was not used are the same.
@@ -1370,6 +1385,7 @@ class TritonTemplate(KernelTemplate):
             "suffix_args": suffix_args,
             "epilogue_fn": epilogue_fn,
             "subgraphs": subgraphs,
+            "prologue_loads_all_inputs": self.prologue_loads_all_inputs,
         }
         if HAS_WARP_SPEC: