From 2926dd4d8e08ff767aec590df5e15b0c008ced3b Mon Sep 17 00:00:00 2001
From: rzou
Date: Tue, 6 May 2025 11:56:37 -0700
Subject: [PATCH] Stop proxy-ing autograd.Function.ctx into the graph (#152621)

The reason we did this before is that this is how our older
autograd.Function x Dynamo interaction worked, but we've since adopted
newer designs that don't actually need the autograd.Function.ctx proxied
into the graph.

We still need an fx.Proxy for the autograd.Function.ctx object, so
whenever we do, I create one via discard_graph_changes.

Test Plan:
- existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152621
Approved by: https://github.com/oulgen
---
 test/dynamo/test_autograd_function.py                | 5 +----
 test/dynamo/test_repros.py                           | 6 +++---
 .../TestAutograd.test_custom_function_saved_tensors  | 0
 test/functorch/test_aotdispatch.py                   | 1 -
 test/higher_order_ops/test_invoke_subgraph.py        | 1 -
 torch/_dynamo/variables/higher_order_ops.py          | 9 +++++++++
 torch/_dynamo/variables/misc.py                      | 9 ---------
 7 files changed, 13 insertions(+), 18 deletions(-)
 delete mode 100644 test/dynamo_expected_failures/TestAutograd.test_custom_function_saved_tensors

diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py
index a7405cf7bac..80aa5c1025f 100644
--- a/test/dynamo/test_autograd_function.py
+++ b/test/dynamo/test_autograd_function.py
@@ -598,7 +598,6 @@ class GraphModule(torch.nn.Module):
         l_weird_b = L_weird_b
         l_weird_c = L_weird_c

-        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
         fwd_body_0 = self.fwd_body_0
         bwd_body_0 = self.bwd_body_0
         autograd_function_apply: "f32[]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None
@@ -1120,7 +1119,6 @@ class GraphModule(torch.nn.Module):
         l_x_ = L_x_
         l_weight_ = L_weight_

-        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
         fwd_body_0 = self.fwd_body_0
         bwd_body_0 = self.bwd_body_0
         autograd_function_apply: "f32[5, 4]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_weight_, args_tensor_mask = [True, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_weight_ = None
@@ -1305,7 +1303,6 @@ class GraphModule(torch.nn.Module):
         l_x_ = L_x_
         l_y_ = L_y_

-        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
         fwd_body_0 = self.fwd_body_0
         bwd_body_0 = self.bwd_body_0
         autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, args_tensor_mask = [True, True], non_differentiable_idx = [1]); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None
@@ -1474,7 +1471,7 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(out, x + 1)
         self.assertEqual(x.grad.shape, shape)
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(cnt.op_count, 2)
+        self.assertEqual(cnt.op_count, 1)

     @requires_cuda
     def test_triton_kernel_basic(self):
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 87cb2b98ee6..4e7147ddf18 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -1246,13 +1246,13 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         with torch.no_grad():
             cnt = self._reformer(nopython=True)
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(cnt.op_count, 11)
+        self.assertEqual(cnt.op_count, 10)

     def test_reformer_train(self):
         with torch.enable_grad():
             cnt = self._reformer(nopython=False)
         expected_op_count = (
-            """11""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5"""
+            """10""" if torch._dynamo.config.inline_inbuilt_nn_modules else """4"""
         )
         self.assertExpectedInline(cnt.frame_count, """1""")
@@ -3708,7 +3708,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         expected = fn(*inputs1)
         actual = fn_opt(*inputs2)
         self.assertTrue(same(actual, expected))
-        self.assertEqual(cnt.op_count, 2)
+        self.assertEqual(cnt.op_count, 1)
         self.assertEqual(cnt.frame_count, 1)
         cnt.clear()
         counters.clear()
diff --git a/test/dynamo_expected_failures/TestAutograd.test_custom_function_saved_tensors b/test/dynamo_expected_failures/TestAutograd.test_custom_function_saved_tensors
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 316dc91b8c9..6c1b4af9cd7 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -7197,7 +7197,6 @@ FAILING_CACHE_TESTS = (
     # BypassAOTAutogradCache: unsupported nodes
     "test_backward_mutation_data",  # Custom Autograd Function
     "test_backward_mutation_metadata",  # Custom Autograd Function
-    "test_custom_autograd",  # Custom Autograd Function
     "test_input_output_aliase_custom_autograd_function",
 )
diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py
index a2bdee6cdbd..c8e5efb715d 100644
--- a/test/higher_order_ops/test_invoke_subgraph.py
+++ b/test/higher_order_ops/test_invoke_subgraph.py
@@ -1504,7 +1504,6 @@ class GraphModule(torch.nn.Module):
     class subgraph_0(torch.nn.Module):
         def forward(self, l_x_: "f32[8, 8]"):

-            function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
             fwd_body_0 = self.fwd_body_0
             bwd_body_0 = self.bwd_body_0
             autograd_function_apply: "f32[8, 8]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, args_tensor_mask = [True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = None
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 1bcd0cf7974..47557124ffa 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -2848,6 +2848,15 @@ class AutogradFunctionApplyVariable(VariableTracker):
             )

         ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
+        with discard_graph_changes(tx):
+            # A little hacky, but we need a dummy ctx proxy for speculate_subgraph.
+            # We should clean this up at some point.
+            proxy = tx.output.create_proxy(
+                "call_function", torch.autograd.function.FunctionCtx, (), {}
+            )
+            set_example_value(proxy.node, ctx.value)
+            ctx.proxy = proxy
+
         if isinstance(self.fwd_graph, types.FunctionType):
             fwd_fn = UserFunctionVariable(self.fwd_graph)
             fwd_args = [ctx, *args]
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index 3a83254f955..9228f4d83cd 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -53,7 +53,6 @@ from ..utils import (
     istype,
     list_methods,
     proxy_args_kwargs,
-    set_example_value,
     tuple_methods,
 )
 from .base import VariableTracker
@@ -877,7 +876,6 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
         value,
         value_type=None,
         inference=False,
-        proxy=None,
         saved_tensors=None,
         needs_input_grad=None,
         non_differentiable=None,
@@ -885,7 +883,6 @@
     ) -> None:
         super().__init__(value=value, value_type=value_type, **kwargs)
         self.inference = inference
-        self.proxy = proxy
         self.saved_tensors = saved_tensors
         self.needs_input_grad = needs_input_grad
         self.non_differentiable = non_differentiable
@@ -898,23 +895,17 @@
             isinstance(x, variables.TensorVariable) and x.requires_grad
             for x in args
         )
-        proxy = tx.output.create_proxy(
-            "call_function", torch.autograd.function.FunctionCtx, (), {}
-        )
         out = tx.output.side_effects.track_object_new(
             None,
             torch.autograd.function.FunctionCtx,
             functools.partial(
                 AutogradFunctionContextVariable,
                 inference=True,
                 saved_tensors=SavedTensorBox(),
                 needs_input_grad=needs_input_grad,
             ),
             {},
         )
-        set_example_value(proxy.node, out.value)
-
         return out

     def as_proxy(self):
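
Illustrative note (not part of the patch): a minimal sketch of the user-visible behavior this change affects. The names MySquare and fn are hypothetical; the updated expected-graph strings in the tests above show the actual capture, where torch.ops.higher_order.autograd_function_apply is emitted without a preceding FunctionCtx() node, which is why the op_count expectations drop by one.

import torch


class MySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # ctx is still available inside the traced fwd body; it just no longer
        # shows up as a FunctionCtx() call in the top-level Dynamo graph.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    return MySquare.apply(x)


if __name__ == "__main__":
    x = torch.randn(4, requires_grad=True)
    fn(x).sum().backward()
    print(x.grad)  # gradient of x * x, i.e. 2 * x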