From a94e507c39df2d2aa8c2ebb70b018e9dda273307 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 30 Jul 2024 12:21:32 -0700 Subject: [PATCH] [aota] Needs autograd if an input requires_grad, agnostic to enable_grad (#128890) Original issue: https://github.com/pytorch/pytorch/issues/114338 Reland of: https://github.com/pytorch/pytorch/pull/128016 Summary from previous PR: We assume only two possible mutually exclusive scenarios: Running compiled region for training (Any of inputs has requires_grad) Produced differentiable outputs should have requires_grad. Running compiled region for inference (None of inputs has requires_grad) All outputs do not have requires_grad. Even if user runs the region under no_grad(), but has an input Tensor with requires_grad - we go Training scenario (1). With current state that means: 1/ needs_autograd should not check torch.is_grad_enabled(), only that any of inputs requires_grad 2/ if needs_autograd => trace_joint (We are in training scenario 1.) => always run compiled region under with.enable_grad() Changes in partitioner? Inference and Training graphs had difference in return container, list/tuple. The changes in partitioner are done to unify and return always tuple. As a result - some changes in test_aotdispatch.py for graph contents list -> tuple. Why was revert? There was a regression of hf_Reformer model on inference. ``` TORCHINDUCTOR_FX_GRAPH_CACHE=0 python benchmarks/dynamo/torchbench.py --performance --inference --bfloat16 --backend inductor --device cuda --only hf_Reformer --cold-start-latency --use-eval-mode ``` Because one of the compiled graphs contained outputs, which are aliases to the inputs that are nn.Parameter(requires_grad=True). Even if inference bencharmsk torchbench runs inside with` torch.no_grad()` - alias (specifically for hf_Reformer - expand) ops preserve requires_grad. As a result we started compiling training graph instead of inference. Fix for view ops: If we have outputs, that are aliases to inputs that requires_grad, those outputs requires grad is not a reason to generate training graph. This is handled in aot_autograd.py, where output_and_mutation_safe are calculated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128890 Approved by: https://github.com/bdhirsh --- test/functorch/test_aotdispatch.py | 68 ++++++++++++++++--- .../_aot_autograd/runtime_wrappers.py | 8 ++- torch/_functorch/aot_autograd.py | 17 +++-- 3 files changed, 80 insertions(+), 13 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 2aafe37be87..b774f297167 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1547,8 +1547,8 @@ def forward(self, arg0_1): self.assertExpectedInline( fw_graph.code.strip(), """\ -def forward(self, primals_1): - view = torch.ops.aten.view.default(primals_1, [-1]); primals_1 = None +def forward(self, arg0_1): + view = torch.ops.aten.view.default(arg0_1, [-1]); arg0_1 = None return (view,)""", ) @@ -2199,12 +2199,11 @@ def forward(self, primals_1): self.assertExpectedInline( fw_graph.code.strip(), """\ -def forward(self, primals_1, primals_2): - view = torch.ops.aten.view.default(primals_1, [3, 3]); primals_1 = None - t = torch.ops.aten.t.default(view); view = None - view_1 = torch.ops.aten.view.default(primals_2, [3, 3]); primals_2 = None - view_2 = torch.ops.aten.view.default(t, [3, 3]) - return (t, view_1, view_2)""", +def forward(self, arg0_1, arg1_1): + t = torch.ops.aten.t.default(arg0_1); arg0_1 = None + view = torch.ops.aten.view.default(arg1_1, [3, 3]); arg1_1 = None + view_1 = torch.ops.aten.view.default(t, [3, 3]) + return (t, view, view_1)""", ) def test_view_detach(self): @@ -5744,6 +5743,59 @@ def forward(self, tangents_1, tangents_2): self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) + def test_aot_dispatch_output_requires_grad_in_no_grad(self): + def fn(x): + out1 = x.sin() + with torch.enable_grad(): + out2 = x.cos() + return out1, out2 + + inp_fns = [ + lambda: torch.ones(10, requires_grad=True), + lambda: torch.ones(10, requires_grad=False), + ] + + compiled_f = aot_function(fn, nop) + for inp_fn in inp_fns: + with torch.no_grad(): + ref_x = inp_fn() + ref_out = fn(ref_x) + x = inp_fn() + out = compiled_f(x) + for r, o in zip(ref_out, out): + self.assertEqual(r.requires_grad, o.requires_grad) + if ref_x.requires_grad: + with torch.enable_grad(): + (ref_out[0] + ref_out[1]).sum().backward() + (out[0] + out[1]).sum().backward() + self.assertEqual(ref_x.grad, x.grad) + assert torch.allclose(ref_x.grad, x.grad, atol=1e-3, rtol=1e-3) + + def test_aot_dispatch_output_requires_grad_in_no_grad_views(self): + # view-type ops preserve requires_grad even in no_grad. + def fn(x): + return x.view(-1), x.sin() + + inference_graph_cell = [None] + inference_compiler = make_boxed_compiler( + partial(extract_graph, graph_cell=inference_graph_cell) + ) + compiled_fn = aot_function(fn, nop, inference_compiler=inference_compiler) + + inp_x0 = torch.ones(2, 3, requires_grad=True) + # Clone in no_grad will make requires_grad=False tensors, keep clone outside of no_grad + ref_x0 = inp_x0.clone() + x0 = inp_x0.clone() + with torch.no_grad(): + ref_out1, ref_out2 = fn(ref_x0) + + out1, out2 = compiled_fn(x0) + # Assert that we executed inference graph + self.assertTrue(inference_graph_cell[0] is not None) + + self.assertEqual(ref_out1.requires_grad, out1.requires_grad) + self.assertEqual(ref_out2.requires_grad, out2.requires_grad) + class TestAOTModuleSimplified(AOTTestCase): def test_aot_module_simplified(self): diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 9c93e5d5bec..84de0f4d1e4 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -310,7 +310,13 @@ def _create_runtime_wrapper( for idx in indices_of_inps_to_detach: if isinstance(args_[idx], torch.Tensor): args_[idx] = args_[idx].detach() - with torch.autograd._force_original_view_tracking(True): + + # It's possible to have trace_joint inside user specified with no_grad() region, + # if there is a nested with enable_grad(), that forces some outputs to require gradients. + # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution. + with torch.autograd._force_original_view_tracking( + True + ), torch.enable_grad(): all_outs = call_func_at_runtime_with_args( compiled_fn, args_, disable_amp=disable_amp, steal_args=True ) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 22187e1700d..a1a9c8af48a 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -572,9 +572,8 @@ def create_aot_dispatcher_function( fake_flat_args = process_inputs(flat_args) - needs_autograd = ( - any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)) - and torch.is_grad_enabled() + needs_autograd = any( + x.requires_grad for x in fake_flat_args if isinstance(x, Tensor) ) with enable_python_dispatcher(): @@ -600,7 +599,17 @@ def create_aot_dispatcher_function( ) output_and_mutation_safe = not any( - x.requires_grad for x in fw_metadata.output_info + x.requires_grad + # view-type operations preserve requires_grad even in no_grad. + # Do not count aliases of inputs with requires_grad as reason to make a training graph, + # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime, + # setting their grad_fn properly. + and not ( + x.output_type + in (OutputType.alias_of_input, OutputType.is_input) + and fw_metadata.input_info[x.base_idx].requires_grad + ) + for x in fw_metadata.output_info ) and not any( x.requires_grad and x.mutates_data