From 0b11da0ccb2bc4ba91bd5159bcb429ceed9a783e Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Fri, 21 Jul 2023 17:03:02 -0700
Subject: [PATCH] [partitioners][ac][dynamic] Fix output signature of fwd with
 symints (#105771)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105771
Approved by: https://github.com/Chillee
---
 test/dynamo/test_activation_checkpointing.py | 30 +++++++++++++++++++
 torch/_functorch/partitioners.py             | 31 +++++++++++++-------
 2 files changed, 50 insertions(+), 11 deletions(-)

diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py
index 4ea3aeefeef..98a5d6e6b38 100644
--- a/test/dynamo/test_activation_checkpointing.py
+++ b/test/dynamo/test_activation_checkpointing.py
@@ -331,6 +331,36 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
         body_function = getattr(cnt.graphs[0], wrap_node.args[0].name)
         self.assertEqual(op_count(body_function), 2)
 
+    @requires_cuda()
+    def test_symints_location(self):
+        def gn(x, y):
+            return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))
+
+        def fn(x, y):
+            return torch.utils.checkpoint.checkpoint(gn, x, y)
+
+        backend = "aot_eager"
+        cnt = CompileCounterWithBackend(backend)
+        opt_fn = torch.compile(fn, backend=cnt)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        y = torch.randn(4, 4, requires_grad=True)
+        args = (x, y)
+        expected = fn(*args)
+        result = opt_fn(*args)
+
+        x = torch.randn(5, 5, requires_grad=True)
+        y = torch.randn(5, 5, requires_grad=True)
+        args = (x, y)
+        expected = fn(*args)
+        result = opt_fn(*args)
+
+        self.assertEqual(result.shape, expected.shape)
+        self.assertEqual(cnt.frame_count, 2)
+        self.assertEqual(len(cnt.graphs), 2)
+        wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
+        self.assertEqual(len(wrap_node.args), 3)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index b774c80a61d..f8dbae59e14 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -130,7 +130,7 @@ def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs):
     return fwd_outputs, bwd_outputs
 
 
-def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes=(), *, num_fwd_outputs):
+def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes, *, num_fwd_outputs):
     fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
     primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
     tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes))
@@ -199,9 +199,11 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_s
         saved_symbols |= new_symbols
 
-    # Update saved_sym_nodes that are now reordered to have all bindings
-    # at front
-    saved_sym_nodes = saved_sym_nodes_binding + saved_sym_nodes_derived
+    # Update saved_sym_nodes, which is now reordered to have all bindings at
+    # the front. This can also be used later on to figure out the position of
+    # saved sym nodes in the output of the fwd graph.
+    saved_sym_nodes.clear()
+    saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)
 
     # Now, we re-generate the fwd/bwd graphs.
     # NB: This might increase compilation time, but I doubt it matters
@@ -480,7 +482,7 @@ def reordering_to_mimic_autograd_engine(gm):
     return new_gm
 
 
-def functionalize_rng_ops(joint_module, fw_module, bw_module):
+def functionalize_rng_ops(joint_module, fw_module, bw_module, num_sym_nodes):
     # During user-driven activation checkpointing, we have to ensure that a rng
     # op in fwd yields the same output as the recomputed rng op in the bwd. To
     # do this, we use functionalize wrappers to wrap the random ops and share
@@ -491,7 +493,9 @@ def functionalize_rng_ops(joint_module, fw_module, bw_module):
     # Step 2 - Modify the fwd pass such that
     # 1) Replace rand with run_and_save_rng_state wrapper
     # 2) Replace the users of the original op with the output[1] of this op.
-    # 3) Collect all the rng_state - output[0] of each op, and make them output nodes.
+    # 3) Collect all the rng_state - output[0] of each op, and make them
+    #    output nodes. Special care needs to be taken here because fwd outputs
+    #    have symints at the very end.
     # Step 3 - Modify the bwd pass such that
     # 1) Add the input nodes just before the tangents for the stashed rng states
     # 2) Replace rand with run_with_rng_state wrappers
@@ -574,11 +578,15 @@ def functionalize_rng_ops(joint_module, fw_module, bw_module):
             bw_graph.erase_node(bw_node)
 
-    # Add the rng states in the output of the fwd graph
-    fw_output = [node for node in fw_module.graph.nodes if node.op == "output"][0]
-    outputs = fw_output.args[0] + fw_rng_state_outputs
+    # Add the rng states in the output of the fwd graph. AOT Autograd assumes
+    # that symints are at the end of the forward graph outputs, so insert the
+    # new rng states just before them.
+    fw_output_node = [node for node in fw_module.graph.nodes if node.op == "output"][0]
+    fw_outputs = fw_output_node.args[0]
+    sym_node_start_idx = len(fw_outputs) - num_sym_nodes
+    outputs = fw_outputs[:sym_node_start_idx] + fw_rng_state_outputs + fw_outputs[sym_node_start_idx:]
     fw_module.graph.output(outputs)
-    fw_module.graph.erase_node(fw_output)
+    fw_module.graph.erase_node(fw_output_node)
     fw_module.recompile()
     bw_module.recompile()
     return fw_module, bw_module
@@ -866,13 +874,14 @@ def min_cut_rematerialization_partition(
     # save_for_backward on tensors and stashes symints in autograd .ctx
     saved_sym_nodes = list(filter(lambda n: is_sym_node(n), saved_values))
     saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))
+    # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols
     fw_module, bw_module = _extract_fwd_bwd_modules(
         joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)
 
     if graph_has_recomputable_ops:
         if graph_has_recomputable_rng_ops:
             fw_module, bw_module = functionalize_rng_ops(
-                joint_module, fw_module, bw_module
+                joint_module, fw_module, bw_module, len(saved_sym_nodes)
             )
         bw_module = reordering_to_mimic_autograd_engine(bw_module)
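
Note on the invariant this patch maintains: AOT Autograd expects the
forward graph to return its outputs with all saved symints grouped at the
very end, so any new outputs (here, the stashed rng states) have to be
spliced in just before that symint tail rather than appended after it.
The sketch below illustrates the index arithmetic from the
functionalize_rng_ops hunk using plain Python lists in place of fx.Node
outputs; splice_rng_states and the example values are illustrative
stand-ins, not code from partitioners.py.

    def splice_rng_states(fw_outputs, rng_states, num_sym_nodes):
        # The saved symints occupy the last num_sym_nodes slots of the fwd
        # outputs, so new outputs must land in front of them, not after.
        sym_node_start_idx = len(fw_outputs) - num_sym_nodes
        return (
            fw_outputs[:sym_node_start_idx]
            + rng_states
            + fw_outputs[sym_node_start_idx:]
        )

    # The two saved symints ("s0", "s1") stay at the tail after the splice.
    assert splice_rng_states(
        ["out", "saved_act", "s0", "s1"], ["rng_state0"], num_sym_nodes=2
    ) == ["out", "saved_act", "rng_state0", "s0", "s1"]

This mirrors why len(saved_sym_nodes) is threaded through to
functionalize_rng_ops: after _extract_fwd_bwd_modules mutates
saved_sym_nodes in place, its length tells the rng functionalization pass
where the symint tail of the fwd output begins.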