mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Include post grad gm and fx runnable in cache artifacts for tlparse (#151469)
Fixed #151462 Pull Request resolved: https://github.com/pytorch/pytorch/pull/151469 Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
ee3366dbb2
commit
ef64beb232
|
|
@ -14,7 +14,7 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42490000000,0.025
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
add_loop_inductor_gpu,compile_time_instruction_count,25120000000,0.015
|
add_loop_inductor_gpu,compile_time_instruction_count,25505620920,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -38,7 +38,7 @@ update_hint_regression,compile_time_instruction_count,1608000000,0.02
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
float_args,compile_time_instruction_count,421822993,0.015
|
float_args,compile_time_instruction_count,436306379,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
|
@ -902,6 +902,8 @@ def forward(self, x, y):
|
||||||
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||||
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||||
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||||
|
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||||
|
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||||
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||||
{"artifact": {"name": "fx_graph_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
{"artifact": {"name": "fx_graph_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||||
{"artifact": {"name": "aotautograd_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
{"artifact": {"name": "aotautograd_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
|
||||||
|
|
|
||||||
|
|
@ -1125,6 +1125,18 @@ class FxGraphCache:
|
||||||
output_code_log.debug("Output code: \n%s", code)
|
output_code_log.debug("Output code: \n%s", code)
|
||||||
output_code_log.debug("Output code written to: %s", artifact_path)
|
output_code_log.debug("Output code written to: %s", artifact_path)
|
||||||
# On cache hit, use artifact path as filename
|
# On cache hit, use artifact path as filename
|
||||||
|
trace_structured(
|
||||||
|
"artifact",
|
||||||
|
metadata_fn=lambda: {
|
||||||
|
"name": "fx_graph_runnable",
|
||||||
|
"encoding": "string",
|
||||||
|
},
|
||||||
|
payload_fn=lambda: graph.runnable_graph_str,
|
||||||
|
)
|
||||||
|
trace_structured(
|
||||||
|
"inductor_post_grad_graph",
|
||||||
|
payload_fn=lambda: graph.inductor_post_grad_graph_str,
|
||||||
|
)
|
||||||
trace_structured(
|
trace_structured(
|
||||||
"inductor_output_code",
|
"inductor_output_code",
|
||||||
lambda: {"filename": artifact_path},
|
lambda: {"filename": artifact_path},
|
||||||
|
|
|
||||||
|
|
@ -1061,12 +1061,11 @@ class _InProcessFxCompile(FxCompile):
|
||||||
f"graph {graph_id}",
|
f"graph {graph_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_graph_runnable() -> str:
|
|
||||||
fd = io.StringIO()
|
fd = io.StringIO()
|
||||||
torch._dynamo.repro.after_aot.save_graph_repro(
|
torch._dynamo.repro.after_aot.save_graph_repro(
|
||||||
fd, gm, example_inputs, "inductor", save_dir=None
|
fd, gm, example_inputs, "inductor", save_dir=None
|
||||||
)
|
)
|
||||||
return fd.getvalue()
|
runnable_graph_str = fd.getvalue()
|
||||||
|
|
||||||
trace_structured(
|
trace_structured(
|
||||||
"artifact",
|
"artifact",
|
||||||
|
|
@ -1074,7 +1073,7 @@ class _InProcessFxCompile(FxCompile):
|
||||||
"name": "fx_graph_runnable",
|
"name": "fx_graph_runnable",
|
||||||
"encoding": "string",
|
"encoding": "string",
|
||||||
},
|
},
|
||||||
payload_fn=lambda: log_graph_runnable(),
|
payload_fn=lambda: runnable_graph_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
V.debug.fx_graph(gm, example_inputs)
|
V.debug.fx_graph(gm, example_inputs)
|
||||||
|
|
@ -1134,11 +1133,12 @@ class _InProcessFxCompile(FxCompile):
|
||||||
colored=True,
|
colored=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
inductor_post_grad_graph_str = gm.print_readable(
|
||||||
|
print_output=False, include_stride=True, include_device=True
|
||||||
|
)
|
||||||
trace_structured(
|
trace_structured(
|
||||||
"inductor_post_grad_graph",
|
"inductor_post_grad_graph",
|
||||||
payload_fn=lambda: gm.print_readable(
|
payload_fn=lambda: inductor_post_grad_graph_str,
|
||||||
print_output=False, include_stride=True, include_device=True
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
if config.trace.enabled:
|
if config.trace.enabled:
|
||||||
provenance_tracking_json = (
|
provenance_tracking_json = (
|
||||||
|
|
@ -1400,6 +1400,8 @@ class _InProcessFxCompile(FxCompile):
|
||||||
static_input_idxs,
|
static_input_idxs,
|
||||||
graph_kwargs,
|
graph_kwargs,
|
||||||
inputs_to_check,
|
inputs_to_check,
|
||||||
|
runnable_graph_str,
|
||||||
|
inductor_post_grad_graph_str,
|
||||||
recursively_apply_fns,
|
recursively_apply_fns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -398,6 +398,10 @@ class CompiledFxGraph(OutputCode):
|
||||||
recursively_apply_fns: Optional[Callable[..., Any]]
|
recursively_apply_fns: Optional[Callable[..., Any]]
|
||||||
cache_key: str
|
cache_key: str
|
||||||
source_code: str = dataclasses.field(repr=False) # Do not display source_code
|
source_code: str = dataclasses.field(repr=False) # Do not display source_code
|
||||||
|
runnable_graph_str: str = dataclasses.field(repr=False) # Do not display graph
|
||||||
|
inductor_post_grad_graph_str: str = dataclasses.field(
|
||||||
|
repr=False
|
||||||
|
) # Do not display graph
|
||||||
cache_linemap: Optional[list[tuple[int, str]]]
|
cache_linemap: Optional[list[tuple[int, str]]]
|
||||||
device_types: OrderedSet[str]
|
device_types: OrderedSet[str]
|
||||||
device_idxs: OrderedSet[int]
|
device_idxs: OrderedSet[int]
|
||||||
|
|
@ -439,6 +443,8 @@ class CompiledFxGraph(OutputCode):
|
||||||
static_input_idxs: Sequence[int],
|
static_input_idxs: Sequence[int],
|
||||||
fx_kwargs: _CompileFxKwargs,
|
fx_kwargs: _CompileFxKwargs,
|
||||||
inputs_to_check: Sequence[int],
|
inputs_to_check: Sequence[int],
|
||||||
|
runnable_graph_str: str,
|
||||||
|
inductor_post_grad_graph_str: str,
|
||||||
recursively_apply_fns: Optional[Callable[..., Any]] = None,
|
recursively_apply_fns: Optional[Callable[..., Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.current_callable = current_callable
|
self.current_callable = current_callable
|
||||||
|
|
@ -447,6 +453,8 @@ class CompiledFxGraph(OutputCode):
|
||||||
if graph.cache_path:
|
if graph.cache_path:
|
||||||
with open(graph.cache_path) as f:
|
with open(graph.cache_path) as f:
|
||||||
self.source_code = f.read()
|
self.source_code = f.read()
|
||||||
|
self.runnable_graph_str = runnable_graph_str
|
||||||
|
self.inductor_post_grad_graph_str = inductor_post_grad_graph_str
|
||||||
self.cache_linemap = graph.cache_linemap
|
self.cache_linemap = graph.cache_linemap
|
||||||
# TODO - ordered set
|
# TODO - ordered set
|
||||||
self.device_types = OrderedSet(graph.device_types)
|
self.device_types = OrderedSet(graph.device_types)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user