Log Full Knapsack Problem Information (#140757)

Summary: When AOT_PARTITIONER_DEBUG is set to 1 and debug logging is turned on, we can now log the full input and output for each knapsack problem.

Differential Revision: D65633086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140757
Approved by: https://github.com/jansel
This commit is contained in:
Basil Wong 2024-11-18 20:36:30 +00:00 committed by PyTorch MergeBot
parent 408ad45014
commit 00c829876c

View File

@ -1703,6 +1703,44 @@ def choose_saved_values_set(
node_info,
all_recomputable_banned_nodes,
)
if AOT_PARTITIONER_DEBUG:
max_runtime = max(
runtimes_banned_nodes
) # For normalizing runtimes in logs
input_summary = [
f"\n\t\t\t{index}, {memory}, {runtime / max_runtime}, {node.op}, {node.target}, {node.meta}, {node.args}"
for index, (memory, runtime, node) in enumerate(
zip(
memories_banned_nodes,
runtimes_banned_nodes,
all_recomputable_banned_nodes,
)
)
]
joint_graph_nodes = [node.name for node in joint_graph.nodes]
joint_graph_edges = [
(inp.name, node.name)
for node in joint_graph.nodes
for inp in node.all_input_nodes
]
knapsack_summary = f"""
Activation Checkpointing - Knapsack Problem Summary:
Input:
Solver: {config.activation_memory_budget_solver}
Max Memory: {max(config.activation_memory_budget, 0)}
Graph Nodes: {joint_graph_nodes}
Graph Edges: {joint_graph_edges}
(Index, Memory, Runtime, Node.Op, Node.Target, Metadata): {"".join(input_summary)}
Output:
Expected Runtime: {expected_runtime}
Saved Nodes: {saved_node_idxs}
Recomputable Nodes: {recomputable_node_idxs}
"""
torch._logging.trace_structured(
name="artifact",
payload_fn=lambda: knapsack_summary,
)
log.info(knapsack_summary)
dont_ban = set()
for idx in recomputable_node_idxs:
# if idx in all_recomputable_banned_nodes: