diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 3720900763c..6a892d7fe64 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1703,6 +1703,44 @@ def choose_saved_values_set( node_info, all_recomputable_banned_nodes, ) + if AOT_PARTITIONER_DEBUG: + max_runtime = max( + runtimes_banned_nodes + ) # For normalizing runtimes in logs + input_summary = [ + f"\n\t\t\t{index}, {memory}, {runtime / max_runtime}, {node.op}, {node.target}, {node.meta}, {node.args}" + for index, (memory, runtime, node) in enumerate( + zip( + memories_banned_nodes, + runtimes_banned_nodes, + all_recomputable_banned_nodes, + ) + ) + ] + joint_graph_nodes = [node.name for node in joint_graph.nodes] + joint_graph_edges = [ + (inp.name, node.name) + for node in joint_graph.nodes + for inp in node.all_input_nodes + ] + knapsack_summary = f""" +Activation Checkpointing - Knapsack Problem Summary: + Input: + Solver: {config.activation_memory_budget_solver} + Max Memory: {max(config.activation_memory_budget, 0)} + Graph Nodes: {joint_graph_nodes} + Graph Edges: {joint_graph_edges} + (Index, Memory, Runtime, Node.Op, Node.Target, Metadata): {"".join(input_summary)} + Output: + Expected Runtime: {expected_runtime} + Saved Nodes: {saved_node_idxs} + Recomputable Nodes: {recomputable_node_idxs} + """ + torch._logging.trace_structured( + name="artifact", + payload_fn=lambda: knapsack_summary, + ) + log.info(knapsack_summary) dont_ban = set() for idx in recomputable_node_idxs: # if idx in all_recomputable_banned_nodes: