mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Include post grad gm and fx runnable in cache artifacts for tlparse (#151469)
Fixed #151462 Pull Request resolved: https://github.com/pytorch/pytorch/pull/151469 Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
ee3366dbb2
commit
ef64beb232
|
|
@ -14,7 +14,7 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42490000000,0.025
|
|||
|
||||
|
||||
|
||||
add_loop_inductor_gpu,compile_time_instruction_count,25120000000,0.015
|
||||
add_loop_inductor_gpu,compile_time_instruction_count,25505620920,0.015
|
||||
|
||||
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ update_hint_regression,compile_time_instruction_count,1608000000,0.02
|
|||
|
||||
|
||||
|
||||
float_args,compile_time_instruction_count,421822993,0.015
|
||||
float_args,compile_time_instruction_count,436306379,0.015
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
|
@ -902,6 +902,8 @@ def forward(self, x, y):
|
|||
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "fx_graph_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
{"artifact": {"name": "aotautograd_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||
|
|
|
|||
|
|
@ -1125,6 +1125,18 @@ class FxGraphCache:
|
|||
output_code_log.debug("Output code: \n%s", code)
|
||||
output_code_log.debug("Output code written to: %s", artifact_path)
|
||||
# On cache hit, use artifact path as filename
|
||||
trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
"name": "fx_graph_runnable",
|
||||
"encoding": "string",
|
||||
},
|
||||
payload_fn=lambda: graph.runnable_graph_str,
|
||||
)
|
||||
trace_structured(
|
||||
"inductor_post_grad_graph",
|
||||
payload_fn=lambda: graph.inductor_post_grad_graph_str,
|
||||
)
|
||||
trace_structured(
|
||||
"inductor_output_code",
|
||||
lambda: {"filename": artifact_path},
|
||||
|
|
|
|||
|
|
@ -1061,12 +1061,11 @@ class _InProcessFxCompile(FxCompile):
|
|||
f"graph {graph_id}",
|
||||
)
|
||||
|
||||
def log_graph_runnable() -> str:
|
||||
fd = io.StringIO()
|
||||
torch._dynamo.repro.after_aot.save_graph_repro(
|
||||
fd, gm, example_inputs, "inductor", save_dir=None
|
||||
)
|
||||
return fd.getvalue()
|
||||
runnable_graph_str = fd.getvalue()
|
||||
|
||||
trace_structured(
|
||||
"artifact",
|
||||
|
|
@ -1074,7 +1073,7 @@ class _InProcessFxCompile(FxCompile):
|
|||
"name": "fx_graph_runnable",
|
||||
"encoding": "string",
|
||||
},
|
||||
payload_fn=lambda: log_graph_runnable(),
|
||||
payload_fn=lambda: runnable_graph_str,
|
||||
)
|
||||
|
||||
V.debug.fx_graph(gm, example_inputs)
|
||||
|
|
@ -1134,11 +1133,12 @@ class _InProcessFxCompile(FxCompile):
|
|||
colored=True,
|
||||
),
|
||||
)
|
||||
inductor_post_grad_graph_str = gm.print_readable(
|
||||
print_output=False, include_stride=True, include_device=True
|
||||
)
|
||||
trace_structured(
|
||||
"inductor_post_grad_graph",
|
||||
payload_fn=lambda: gm.print_readable(
|
||||
print_output=False, include_stride=True, include_device=True
|
||||
),
|
||||
payload_fn=lambda: inductor_post_grad_graph_str,
|
||||
)
|
||||
if config.trace.enabled:
|
||||
provenance_tracking_json = (
|
||||
|
|
@ -1400,6 +1400,8 @@ class _InProcessFxCompile(FxCompile):
|
|||
static_input_idxs,
|
||||
graph_kwargs,
|
||||
inputs_to_check,
|
||||
runnable_graph_str,
|
||||
inductor_post_grad_graph_str,
|
||||
recursively_apply_fns,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -398,6 +398,10 @@ class CompiledFxGraph(OutputCode):
|
|||
recursively_apply_fns: Optional[Callable[..., Any]]
|
||||
cache_key: str
|
||||
source_code: str = dataclasses.field(repr=False) # Do not display source_code
|
||||
runnable_graph_str: str = dataclasses.field(repr=False) # Do not display graph
|
||||
inductor_post_grad_graph_str: str = dataclasses.field(
|
||||
repr=False
|
||||
) # Do not display graph
|
||||
cache_linemap: Optional[list[tuple[int, str]]]
|
||||
device_types: OrderedSet[str]
|
||||
device_idxs: OrderedSet[int]
|
||||
|
|
@ -439,6 +443,8 @@ class CompiledFxGraph(OutputCode):
|
|||
static_input_idxs: Sequence[int],
|
||||
fx_kwargs: _CompileFxKwargs,
|
||||
inputs_to_check: Sequence[int],
|
||||
runnable_graph_str: str,
|
||||
inductor_post_grad_graph_str: str,
|
||||
recursively_apply_fns: Optional[Callable[..., Any]] = None,
|
||||
) -> None:
|
||||
self.current_callable = current_callable
|
||||
|
|
@ -447,6 +453,8 @@ class CompiledFxGraph(OutputCode):
|
|||
if graph.cache_path:
|
||||
with open(graph.cache_path) as f:
|
||||
self.source_code = f.read()
|
||||
self.runnable_graph_str = runnable_graph_str
|
||||
self.inductor_post_grad_graph_str = inductor_post_grad_graph_str
|
||||
self.cache_linemap = graph.cache_linemap
|
||||
# TODO - ordered set
|
||||
self.device_types = OrderedSet(graph.device_types)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user