Include post grad gm and fx runnable in cache artifacts for tlparse (#151469)

Fixed #151462

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151469
Approved by: https://github.com/bdhirsh
Author: Oguz Ulgen · 2025-04-16 22:19:26 -07:00 · committed by PyTorch MergeBot
parent ee3366dbb2 · commit ef64beb232
5 changed files with 36 additions and 12 deletions
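
For context, these artifacts flow through PyTorch's structured-trace pipeline and are rendered by tlparse. Below is a minimal sketch of capturing a trace that exercises the cache-hit path, assuming tlparse is installed (`pip install tlparse`); the script and directory names are illustrative:

```python
# Sketch: capture structured-trace artifacts for tlparse.
# Run twice so the second process can hit FxGraphCache / AOTAutogradCache:
#   TORCH_TRACE=/tmp/trace_dir python repro.py
#   TORCH_TRACE=/tmp/trace_dir python repro.py
#   tlparse <log file written under /tmp/trace_dir>
import torch

@torch.compile
def f(x, y):
    return (x + y).relu()

f(torch.randn(8), torch.randn(8))
```

Before this change, the `fx_graph_runnable` and `inductor_post_grad_graph` artifacts were emitted only when a graph was actually compiled; the diff below also records them when `FxGraphCache` serves a hit.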

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

@@ -14,7 +14,7 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42490000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,25120000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,25505620920,0.015
@@ -38,7 +38,7 @@ update_hint_regression,compile_time_instruction_count,1608000000,0.02
-float_args,compile_time_instruction_count,421822993,0.015
+float_args,compile_time_instruction_count,436306379,0.015


test/dynamo/test_structured_trace.py

@@ -902,6 +902,8 @@ def forward(self, x, y):
 {"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
+{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
+{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"artifact": {"name": "fx_graph_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"artifact": {"name": "aotautograd_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}

torch/_inductor/codecache.py

@@ -1125,6 +1125,18 @@ class FxGraphCache:
         output_code_log.debug("Output code: \n%s", code)
         output_code_log.debug("Output code written to: %s", artifact_path)
         # On cache hit, use artifact path as filename
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "fx_graph_runnable",
+                "encoding": "string",
+            },
+            payload_fn=lambda: graph.runnable_graph_str,
+        )
+        trace_structured(
+            "inductor_post_grad_graph",
+            payload_fn=lambda: graph.inductor_post_grad_graph_str,
+        )
         trace_structured(
             "inductor_output_code",
             lambda: {"filename": artifact_path},
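
The `metadata_fn`/`payload_fn` arguments are thunks, so the payload string is only materialized when structured tracing is actually enabled. A standalone sketch of the same pattern (the artifact name and payload here are made up for illustration):

```python
from torch._logging import trace_structured

trace_structured(
    "artifact",
    metadata_fn=lambda: {"name": "my_artifact", "encoding": "string"},  # hypothetical name
    payload_fn=lambda: "expensive-to-build payload",  # only evaluated if tracing is on
)
```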

torch/_inductor/compile_fx.py

@@ -1061,12 +1061,11 @@ class _InProcessFxCompile(FxCompile):
                     f"graph {graph_id}",
                 )
 
-                def log_graph_runnable() -> str:
-                    fd = io.StringIO()
-                    torch._dynamo.repro.after_aot.save_graph_repro(
-                        fd, gm, example_inputs, "inductor", save_dir=None
-                    )
-                    return fd.getvalue()
+                fd = io.StringIO()
+                torch._dynamo.repro.after_aot.save_graph_repro(
+                    fd, gm, example_inputs, "inductor", save_dir=None
+                )
+                runnable_graph_str = fd.getvalue()
 
                 trace_structured(
                     "artifact",
@@ -1074,7 +1073,7 @@ class _InProcessFxCompile(FxCompile):
                         "name": "fx_graph_runnable",
                         "encoding": "string",
                     },
-                    payload_fn=lambda: log_graph_runnable(),
+                    payload_fn=lambda: runnable_graph_str,
                 )
 
                 V.debug.fx_graph(gm, example_inputs)
@@ -1134,11 +1133,12 @@ class _InProcessFxCompile(FxCompile):
                         colored=True,
                     ),
                 )
+                inductor_post_grad_graph_str = gm.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                )
                 trace_structured(
                     "inductor_post_grad_graph",
-                    payload_fn=lambda: gm.print_readable(
-                        print_output=False, include_stride=True, include_device=True
-                    ),
+                    payload_fn=lambda: inductor_post_grad_graph_str,
                 )
                 if config.trace.enabled:
                     provenance_tracking_json = (
@@ -1400,6 +1400,8 @@ class _InProcessFxCompile(FxCompile):
                 static_input_idxs,
                 graph_kwargs,
                 inputs_to_check,
+                runnable_graph_str,
+                inductor_post_grad_graph_str,
                 recursively_apply_fns,
             )
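
This refactor hoists the repro-string construction out of the logging closure so the result is computed once and can be threaded into `CompiledFxGraph` for caching. A self-contained sketch of that step, mirroring the code above (the helper name is mine; `gm` and `example_inputs` are placeholders):

```python
import io

import torch
import torch._dynamo.repro.after_aot


def graph_repro_str(gm: torch.fx.GraphModule, example_inputs) -> str:
    # Serialize a runnable repro of the FX graph, as the compile path above does.
    fd = io.StringIO()
    torch._dynamo.repro.after_aot.save_graph_repro(
        fd, gm, example_inputs, "inductor", save_dir=None
    )
    return fd.getvalue()
```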

torch/_inductor/output_code.py

@@ -398,6 +398,10 @@ class CompiledFxGraph(OutputCode):
     recursively_apply_fns: Optional[Callable[..., Any]]
     cache_key: str
     source_code: str = dataclasses.field(repr=False)  # Do not display source_code
+    runnable_graph_str: str = dataclasses.field(repr=False)  # Do not display graph
+    inductor_post_grad_graph_str: str = dataclasses.field(
+        repr=False
+    )  # Do not display graph
     cache_linemap: Optional[list[tuple[int, str]]]
     device_types: OrderedSet[str]
     device_idxs: OrderedSet[int]
@@ -439,6 +443,8 @@ class CompiledFxGraph(OutputCode):
         static_input_idxs: Sequence[int],
         fx_kwargs: _CompileFxKwargs,
         inputs_to_check: Sequence[int],
+        runnable_graph_str: str,
+        inductor_post_grad_graph_str: str,
         recursively_apply_fns: Optional[Callable[..., Any]] = None,
     ) -> None:
         self.current_callable = current_callable
@@ -447,6 +453,8 @@ class CompiledFxGraph(OutputCode):
         if graph.cache_path:
             with open(graph.cache_path) as f:
                 self.source_code = f.read()
+        self.runnable_graph_str = runnable_graph_str
+        self.inductor_post_grad_graph_str = inductor_post_grad_graph_str
         self.cache_linemap = graph.cache_linemap
         # TODO - ordered set
         self.device_types = OrderedSet(graph.device_types)
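
One note on the new field declarations: `repr=False` keeps these potentially very large graph dumps out of the generated dataclass `repr`, matching how `source_code` is already handled. A minimal illustration (class and field values here are illustrative):

```python
import dataclasses


@dataclasses.dataclass
class Cached:
    cache_key: str
    runnable_graph_str: str = dataclasses.field(repr=False)


print(Cached("abc123", "def forward(self, x): ..."))
# -> Cached(cache_key='abc123')   (the graph dump is omitted from the repr)
```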