Stop proxy-ing autograd.Function.ctx into the graph (#152621)

We originally proxied the ctx because that's how the older
autograd.Function x Dynamo interaction worked, but we've since adopted
newer designs that no longer need the autograd.Function.ctx proxied
into the graph.

We still need an fx.Proxy for the autograd.Function.ctx object;
whenever we do, I create one under discard_graph_changes so that it is
not recorded into the graph.
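
For illustration only (not part of this PR), here is a minimal sketch of the
kind of user code that hits this path; the MyRelu function and the
printing_backend helper are made up for the example. Dynamo traces
MyRelu.apply into torch.ops.higher_order.autograd_function_apply, and after
this change the captured graph should no longer contain a
torch.autograd.function.FunctionCtx node.

import torch

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # ctx is the FunctionCtx object; Dynamo still models it, but it no
        # longer shows up as a node in the captured graph.
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * (x > 0).to(grad_out.dtype)

def printing_backend(gm, example_inputs):
    # Custom torch.compile backend: print the FX graph Dynamo captured,
    # then run it eagerly.
    gm.print_readable()
    return gm.forward

@torch.compile(backend=printing_backend, fullgraph=True)
def f(x):
    return MyRelu.apply(x).sum()

x = torch.randn(4, requires_grad=True)
f(x).backward()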

Test Plan:
- existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152621
Approved by: https://github.com/oulgen
rzou 2025-05-06 11:56:37 -07:00 committed by PyTorch MergeBot
parent 22c31046d4
commit 2926dd4d8e
7 changed files with 13 additions and 18 deletions

View File

@@ -598,7 +598,6 @@ class GraphModule(torch.nn.Module):
l_weird_b = L_weird_b
l_weird_c = L_weird_c
-function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
autograd_function_apply: "f32[]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None
@@ -1120,7 +1119,6 @@ class GraphModule(torch.nn.Module):
l_x_ = L_x_
l_weight_ = L_weight_
-function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
autograd_function_apply: "f32[5, 4]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_weight_, args_tensor_mask = [True, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_weight_ = None
@@ -1305,7 +1303,6 @@ class GraphModule(torch.nn.Module):
l_x_ = L_x_
l_y_ = L_y_
-function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, args_tensor_mask = [True, True], non_differentiable_idx = [1]); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None
@@ -1474,7 +1471,7 @@ class GraphModule(torch.nn.Module):
self.assertEqual(out, x + 1)
self.assertEqual(x.grad.shape, shape)
self.assertEqual(cnt.frame_count, 1)
-self.assertEqual(cnt.op_count, 2)
+self.assertEqual(cnt.op_count, 1)
@requires_cuda
def test_triton_kernel_basic(self):

View File

@@ -1246,13 +1246,13 @@ class ReproTests(torch._dynamo.test_case.TestCase):
with torch.no_grad():
cnt = self._reformer(nopython=True)
self.assertEqual(cnt.frame_count, 1)
-self.assertEqual(cnt.op_count, 11)
+self.assertEqual(cnt.op_count, 10)
def test_reformer_train(self):
with torch.enable_grad():
cnt = self._reformer(nopython=False)
expected_op_count = (
"""11""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5"""
"""10""" if torch._dynamo.config.inline_inbuilt_nn_modules else """4"""
)
self.assertExpectedInline(cnt.frame_count, """1""")
@@ -3708,7 +3708,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
expected = fn(*inputs1)
actual = fn_opt(*inputs2)
self.assertTrue(same(actual, expected))
-self.assertEqual(cnt.op_count, 2)
+self.assertEqual(cnt.op_count, 1)
self.assertEqual(cnt.frame_count, 1)
cnt.clear()
counters.clear()

View File

@@ -7197,7 +7197,6 @@ FAILING_CACHE_TESTS = (
# BypassAOTAutogradCache: unsupported nodes
"test_backward_mutation_data", # Custom Autograd Function
"test_backward_mutation_metadata", # Custom Autograd Function
"test_custom_autograd", # Custom Autograd Function
"test_input_output_aliase_custom_autograd_function",
)

View File

@@ -1504,7 +1504,6 @@ class GraphModule(torch.nn.Module):
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8, 8]"):
-function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
autograd_function_apply: "f32[8, 8]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, args_tensor_mask = [True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = None

View File

@@ -2848,6 +2848,15 @@ class AutogradFunctionApplyVariable(VariableTracker):
)
ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
+with discard_graph_changes(tx):
+    # A little hacky, but we need a dummy ctx proxy for speculate_subgraph.
+    # We should clean this up at some point.
+    proxy = tx.output.create_proxy(
+        "call_function", torch.autograd.function.FunctionCtx, (), {}
+    )
+    set_example_value(proxy.node, ctx.value)
+    ctx.proxy = proxy
if isinstance(self.fwd_graph, types.FunctionType):
fwd_fn = UserFunctionVariable(self.fwd_graph)
fwd_args = [ctx, *args]

View File

@@ -53,7 +53,6 @@ from ..utils import (
istype,
list_methods,
proxy_args_kwargs,
-set_example_value,
tuple_methods,
)
from .base import VariableTracker
@@ -877,7 +876,6 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
value,
value_type=None,
inference=False,
-proxy=None,
saved_tensors=None,
needs_input_grad=None,
non_differentiable=None,
@@ -885,7 +883,6 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
) -> None:
super().__init__(value=value, value_type=value_type, **kwargs)
self.inference = inference
-self.proxy = proxy
self.saved_tensors = saved_tensors
self.needs_input_grad = needs_input_grad
self.non_differentiable = non_differentiable
@@ -898,23 +895,17 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
isinstance(x, variables.TensorVariable) and x.requires_grad
for x in args
)
-proxy = tx.output.create_proxy(
-    "call_function", torch.autograd.function.FunctionCtx, (), {}
-)
out = tx.output.side_effects.track_object_new(
None,
torch.autograd.function.FunctionCtx,
functools.partial(
AutogradFunctionContextVariable,
inference=True,
-proxy=proxy,
saved_tensors=SavedTensorBox(),
needs_input_grad=needs_input_grad,
),
{},
)
-set_example_value(proxy.node, out.value)
return out
def as_proxy(self):