diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 2aafe37be87..b774f297167 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -1547,8 +1547,8 @@ def forward(self, arg0_1):
         self.assertExpectedInline(
             fw_graph.code.strip(),
             """\
-def forward(self, primals_1):
-    view = torch.ops.aten.view.default(primals_1, [-1]); primals_1 = None
+def forward(self, arg0_1):
+    view = torch.ops.aten.view.default(arg0_1, [-1]); arg0_1 = None
     return (view,)""",
         )

@@ -2199,12 +2199,11 @@ def forward(self, primals_1):
         self.assertExpectedInline(
             fw_graph.code.strip(),
             """\
-def forward(self, primals_1, primals_2):
-    view = torch.ops.aten.view.default(primals_1, [3, 3]); primals_1 = None
-    t = torch.ops.aten.t.default(view); view = None
-    view_1 = torch.ops.aten.view.default(primals_2, [3, 3]); primals_2 = None
-    view_2 = torch.ops.aten.view.default(t, [3, 3])
-    return (t, view_1, view_2)""",
+def forward(self, arg0_1, arg1_1):
+    t = torch.ops.aten.t.default(arg0_1); arg0_1 = None
+    view = torch.ops.aten.view.default(arg1_1, [3, 3]); arg1_1 = None
+    view_1 = torch.ops.aten.view.default(t, [3, 3])
+    return (t, view, view_1)""",
         )

     def test_view_detach(self):
@@ -5744,6 +5743,59 @@ def forward(self, tangents_1, tangents_2):
         self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
         self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)

+    def test_aot_dispatch_output_requires_grad_in_no_grad(self):
+        def fn(x):
+            out1 = x.sin()
+            with torch.enable_grad():
+                out2 = x.cos()
+            return out1, out2
+
+        inp_fns = [
+            lambda: torch.ones(10, requires_grad=True),
+            lambda: torch.ones(10, requires_grad=False),
+        ]
+
+        compiled_f = aot_function(fn, nop)
+        for inp_fn in inp_fns:
+            with torch.no_grad():
+                ref_x = inp_fn()
+                ref_out = fn(ref_x)
+                x = inp_fn()
+                out = compiled_f(x)
+                for r, o in zip(ref_out, out):
+                    self.assertEqual(r.requires_grad, o.requires_grad)
+                if ref_x.requires_grad:
+                    with torch.enable_grad():
+                        (ref_out[0] + ref_out[1]).sum().backward()
+                        (out[0] + out[1]).sum().backward()
+                    self.assertEqual(ref_x.grad, x.grad)
+                    assert torch.allclose(ref_x.grad, x.grad, atol=1e-3, rtol=1e-3)
+
+    def test_aot_dispatch_output_requires_grad_in_no_grad_views(self):
+        # view-type ops preserve requires_grad even in no_grad.
+        def fn(x):
+            return x.view(-1), x.sin()
+
+        inference_graph_cell = [None]
+        inference_compiler = make_boxed_compiler(
+            partial(extract_graph, graph_cell=inference_graph_cell)
+        )
+        compiled_fn = aot_function(fn, nop, inference_compiler=inference_compiler)
+
+        inp_x0 = torch.ones(2, 3, requires_grad=True)
+        # Cloning in no_grad would produce requires_grad=False tensors, so keep the clones outside of no_grad.
+        ref_x0 = inp_x0.clone()
+        x0 = inp_x0.clone()
+        with torch.no_grad():
+            ref_out1, ref_out2 = fn(ref_x0)
+
+            out1, out2 = compiled_fn(x0)
+            # Assert that we executed the inference graph.
+            self.assertTrue(inference_graph_cell[0] is not None)
+
+            self.assertEqual(ref_out1.requires_grad, out1.requires_grad)
+            self.assertEqual(ref_out2.requires_grad, out2.requires_grad)
+

 class TestAOTModuleSimplified(AOTTestCase):
     def test_aot_module_simplified(self):
diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py
index 9c93e5d5bec..84de0f4d1e4 100644
--- a/torch/_functorch/_aot_autograd/runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -310,7 +310,13 @@ def _create_runtime_wrapper(
         for idx in indices_of_inps_to_detach:
             if isinstance(args_[idx], torch.Tensor):
                 args_[idx] = args_[idx].detach()
-        with torch.autograd._force_original_view_tracking(True):
+
+        # It's possible for trace_joint to run inside a user-specified no_grad() region
+        # when a nested enable_grad() forces some outputs to require gradients.
+        # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution.
+        with torch.autograd._force_original_view_tracking(
+            True
+        ), torch.enable_grad():
             all_outs = call_func_at_runtime_with_args(
                 compiled_fn, args_, disable_amp=disable_amp, steal_args=True
             )
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 22187e1700d..a1a9c8af48a 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -572,9 +572,8 @@ def create_aot_dispatcher_function(

     fake_flat_args = process_inputs(flat_args)

-    needs_autograd = (
-        any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))
-        and torch.is_grad_enabled()
+    needs_autograd = any(
+        x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)
     )

     with enable_python_dispatcher():
@@ -600,7 +599,17 @@
        )

        output_and_mutation_safe = not any(
-            x.requires_grad for x in fw_metadata.output_info
+            x.requires_grad
+            # view-type operations preserve requires_grad even in no_grad.
+            # Do not count aliases of inputs with requires_grad as a reason to make a training graph,
+            # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime,
+            # setting their grad_fn properly.
+            and not (
+                x.output_type
+                in (OutputType.alias_of_input, OutputType.is_input)
+                and fw_metadata.input_info[x.base_idx].requires_grad
+            )
+            for x in fw_metadata.output_info
        ) and not any(
            x.requires_grad and x.mutates_data
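For reference, a minimal eager-mode sketch (not part of the patch; plain torch, no AOTAutograd involved) of the two behaviors the new tests pin down: a nested enable_grad() inside a no_grad() region produces outputs that require grad, and a view of a requires_grad tensor keeps requires_grad even under no_grad(). The compiled path must reproduce both, which is why needs_autograd no longer consults torch.is_grad_enabled() and compiled_fn now runs under enable_grad().

import torch

x = torch.ones(10, requires_grad=True)

with torch.no_grad():
    out1 = x.sin()        # grad mode off: no grad_fn, requires_grad=False
    with torch.enable_grad():
        out2 = x.cos()    # nested enable_grad: tracked, requires_grad=True
    view = x.view(-1)     # view of a requires_grad tensor: requires_grad stays True

print(out1.requires_grad, out2.requires_grad, view.requires_grad)  # False True True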