Move prologue_supported_inputs computations to def_kernel (#150869)
This avoids replaying load_input on a cache hit in the generated-code cache (GeneratedCodeCache). The idea is that if a template has prologue_loads_all_inputs = True, all of its inputs are loaded by the prologue, so prologue_supported_inputs can be populated up front in def_kernel and there is no need to record and replay load_input.
Effect on the current benchmark, from a local run on a dev server: 18549985383 -> 15072230073 and 25697270062 -> 20738613297.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150869
Approved by: https://github.com/eellison
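A minimal sketch of the idea, using simplified stand-ins (the class and hook names below only mirror the def_kernel/load_input hooks touched by this PR; they are not the actual Inductor TritonTemplateKernel):

# Sketch only: with prologue_loads_all_inputs=True, def_kernel registers every
# input as prologue-supported, so load_input no longer mutates state that has
# to be replayed on a generated-code-cache hit.
class KernelSketch:
    def __init__(self, prologue_loads_all_inputs: bool = False) -> None:
        self.prologue_loads_all_inputs = prologue_loads_all_inputs
        self.prologue_supported_inputs: set[str] = set()

    def def_kernel(self, argnames: list[str], named_input_nodes: dict[str, str]) -> None:
        # With the flag on, this state is rebuilt by replaying def_kernel alone.
        if self.prologue_loads_all_inputs:
            for name in argnames:
                self.prologue_supported_inputs.add(named_input_nodes[name])

    def load_input(self, input_name: str, named_input_nodes: dict[str, str]) -> None:
        # Only templates that do not load all inputs still need load_input
        # to contribute to prologue_supported_inputs (and thus to be replayed).
        if not self.prologue_loads_all_inputs:
            self.prologue_supported_inputs.add(named_input_nodes[input_name])

Because mm_template sets prologue_loads_all_inputs=True, the cached event lists in the tests below shrink from def_kernel plus load_input entries to def_kernel (and size) entries only.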
parent 4421aee558
commit 4bcff4af99
@@ -1188,7 +1188,7 @@ class TestMaxAutotune(TestCase):
        cache_key, events = get_cache_key_and_events()

        if not TEST_WITH_ROCM:
            self.assertEqual(
            self.assertExpectedInline(
                remove_white_space(cache_key),
                remove_white_space(
                    """
@@ -1204,13 +1204,7 @@ class TestMaxAutotune(TestCase):
            self.assertEqual(
                remove_white_space(events),
                remove_white_space(
                    """[
                    ('def_kernel', ['A', 'B'], {}),
                    ('load_input', ['A', 'a', ('idx_m', 'idx_n')], {'mask': 'a_mask', 'indent_width': 8}),
                    ('load_input', ['B', 'b', ('idx_m', 'idx_n')], {'mask': 'b_mask', 'indent_width': 8})]
                    """
                ),
                remove_white_space("""[('def_kernel', ['A', 'B'], {})]"""),
            )

            # Test symbolic shapes with different symbols. Will cache miss due to different symbols in inputs.
@@ -1232,7 +1226,7 @@ class TestMaxAutotune(TestCase):
        cache_key, events = get_cache_key_and_events()

        if not TEST_WITH_ROCM:
            self.assertEqual(
            self.assertExpectedInline(
                remove_white_space(cache_key),
                remove_white_space(
                    """{'input_nodes': ["[[s77, s17], [s17, 1], torch.float32, device(type='cuda', index=0), 0]",
@@ -1245,16 +1239,21 @@ class TestMaxAutotune(TestCase):
                ),
            )

            self.assertEqual(
            self.assertExpectedInline(
                remove_white_space(events),
                remove_white_space(
                    """[('def_kernel',['A','B'],{}),('size',['A',0],{}),('size',['B',1],{}),('size',['A',1],{})]"""
                ),
            )
            self.assertExpectedInline(
                remove_white_space(events),
                remove_white_space(
                    """[
                    ('def_kernel', ['A', 'B'], {}),
                    ('size', ['A', 0], {}), ('size', ['B', 1], {}),
                    ('size', ['A', 1], {}),
                    ('load_input', ['A', 'a', ('idx_m', 'idx_n')], {'mask': 'a_mask', 'indent_width': 8}),
                    ('load_input', ['B', 'b', ('idx_m', 'idx_n')], {'mask': 'b_mask', 'indent_width': 8})]
                    """
                    ('def_kernel', ['A', 'B'], {}),
                    ('size', ['A', 0], {}),
                    ('size', ['B', 1], {}),
                    ('size', ['A', 1], {})]
                    """
                ),
            )
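For context, the event strings compared above are recorded hook calls that get replayed to rebuild kernel state on a cache hit, which is why dropping load_input entries saves work. A rough sketch of that record/replay pattern, using a hypothetical EventRecorder rather than Inductor's actual GeneratedCodeCache:

# Sketch only: events are stored as (hook_name, args, kwargs) tuples, matching
# the shape of the expected strings in the test, e.g. ('def_kernel', ['A', 'B'], {}).
class EventRecorder:
    def __init__(self) -> None:
        self.events: list[tuple[str, list, dict]] = []

    def record(self, hook_name: str, args: list, kwargs: dict) -> None:
        self.events.append((hook_name, list(args), dict(kwargs)))

    def replay(self, kernel) -> None:
        # On a cache hit, only the recorded hooks are re-run against the new
        # kernel object; fewer recorded events means a cheaper replay.
        for hook_name, args, kwargs in self.events:
            getattr(kernel, hook_name)(*args, **kwargs)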
@@ -230,6 +230,7 @@ mm_template = TritonTemplate(
    """
    ),
    cache_codegen_enabled_for_template=True,
    prologue_loads_all_inputs=True,
)

persistent_tma_mm_template = TritonTemplate(
@@ -311,6 +311,7 @@ class TritonTemplateKernel(TritonKernel):
        epilogue_fn=identity,
        subgraphs: Optional[list[ir.ComputedBuffer]] = None,
        workspace_arg: Optional[WorkspaceArg] = None,
        prologue_loads_all_inputs=False,
    ) -> None:
        numel = sympy_product(output_node.get_size())
        super().__init__(
@@ -387,6 +388,10 @@ class TritonTemplateKernel(TritonKernel):
        # Update each time an input is marked frozen, used to replay the freezing of inputs on a cache hit.
        self.frozen_layouts_cnt = 0

        # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel
        # by adding all inputs.
        self.prologue_loads_all_inputs = prologue_loads_all_inputs

    def input_dependent_preserved_state(self) -> str:
        # Not adding self.args.output_buffers on purpose. But we do not need to reproduce it on a cache hit.
        # (never accessed).
@@ -428,6 +433,7 @@ class TritonTemplateKernel(TritonKernel):
            key.name: getattr(self, key.name)
            for key in dataclasses.fields(SubgraphInfo)
        }

        assert body_name in self.subgraph_bodies, body_name

        subgraph = self.subgraph_bodies[body_name]
@@ -585,10 +591,13 @@ class TritonTemplateKernel(TritonKernel):
        # The args may be duplicated, so renaming must be after args are de-duplicated.
        for name in argnames:
            input_node = self.named_input_nodes[name]
            if self.prologue_loads_all_inputs:
                self.prologue_supported_inputs.add(input_node.get_name())
            if input_node.get_name() in V.graph.removed_buffers:
                continue
            if input_node.get_name() in self.prologue_fused_inputs:
                continue

            arg_name = self.args.input_buffers[input_node.get_name()]
            if input_node.get_layout().offset == 0:
                renames.writeline(f"{name} = {arg_name}")
@@ -756,7 +765,9 @@ class TritonTemplateKernel(TritonKernel):
        """

        input_node = self.named_input_nodes[input_name]
        self.prologue_supported_inputs.add(input_node.get_name())
        if not self.prologue_loads_all_inputs:
            self.prologue_supported_inputs.add(input_node.get_name())

        tilings = (sympy_product(input_node.get_size()), sympy.Integer(1))
        groups = {
            "x": tilings[0],
@@ -1261,6 +1272,7 @@ class TritonTemplate(KernelTemplate):
        source: str,
        debug=False,
        cache_codegen_enabled_for_template=False,
        prologue_loads_all_inputs=False,
    ) -> None:
        super().__init__(name)
        self.grid = grid
@@ -1271,6 +1283,9 @@ class TritonTemplate(KernelTemplate):
        self._cache_codegen_enabled_for_template = cache_codegen_enabled_for_template
        self._generated_code_cache: GeneratedCodeCache = GeneratedCodeCache()
        clear_on_fresh_inductor_cache(self._generated_code_cache)
        # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel
        # by adding all inputs.
        self.prologue_loads_all_inputs = prologue_loads_all_inputs

        # When this flag is on, we ensure that the cached results and the generated result if cache
        # was not used are the same.
@@ -1370,6 +1385,7 @@ class TritonTemplate(KernelTemplate):
            "suffix_args": suffix_args,
            "epilogue_fn": epilogue_fn,
            "subgraphs": subgraphs,
            "prologue_loads_all_inputs": self.prologue_loads_all_inputs,
        }

        if HAS_WARP_SPEC: