Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Stop proxy-ing autograd.Function.ctx into the graph (#152621)
The reason we did this before is that it's how our older autograd.Function x Dynamo interaction worked, but we've since adopted newer designs that don't actually need autograd.Function.ctx proxied into the graph. We still need an fx.Proxy for the autograd.Function.ctx object, so whenever one is needed we create it under discard_graph_changes so it never lands in the graph.

Test Plan:
- existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152621
Approved by: https://github.com/oulgen
This commit is contained in:
parent 22c31046d4
commit 2926dd4d8e
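To make the mechanism concrete before the diff, here is a minimal, self-contained sketch of the pattern the commit message describes. The names ToyGraph and make_dummy_ctx_proxy are illustrative stand-ins, and discard_graph_changes below is a toy version of the Dynamo context manager referenced in the diff, not the real API: a dummy ctx proxy is created inside a scope whose graph mutations are thrown away, so the Python handle exists for bookkeeping but no FunctionCtx node is recorded in the emitted graph.

# Illustrative sketch only -- not Dynamo's implementation.
from contextlib import contextmanager


class ToyGraph:
    """Stand-in for an FX graph: just an ordered list of node descriptions."""

    def __init__(self):
        self.nodes = []

    def create_node(self, description):
        self.nodes.append(description)
        return description


@contextmanager
def discard_graph_changes(graph):
    # Snapshot the node list, let the caller create temporary nodes, then
    # restore the snapshot so those nodes never appear in the emitted graph.
    snapshot = list(graph.nodes)
    try:
        yield
    finally:
        graph.nodes = snapshot


def make_dummy_ctx_proxy(graph):
    # The ctx "proxy" is created inside the discard scope: callers get a handle
    # to pass around (e.g. into subgraph speculation), but the node is dropped.
    with discard_graph_changes(graph):
        proxy = graph.create_node("function_ctx = FunctionCtx()")
    return proxy


if __name__ == "__main__":
    g = ToyGraph()
    g.create_node("l_x_ = L_x_")
    ctx_proxy = make_dummy_ctx_proxy(g)
    g.create_node("autograd_function_apply(fwd_body_0, bwd_body_0, l_x_)")
    print(ctx_proxy)  # the handle still exists
    print(g.nodes)    # no function_ctx node was kept in the graph

The same shape appears in the AutogradFunctionApplyVariable hunk below: the FunctionCtx proxy is created under discard_graph_changes(tx), its example value is set, and it is attached to the ctx variable, while the traced graph no longer records the FunctionCtx call. The test hunks simply drop the function_ctx line from the expected graph dumps and lower the op counts accordingly.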
@@ -598,7 +598,6 @@ class GraphModule(torch.nn.Module):
         l_weird_b = L_weird_b
         l_weird_c = L_weird_c

-        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
         fwd_body_0 = self.fwd_body_0
         bwd_body_0 = self.bwd_body_0
         autograd_function_apply: "f32[]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None
@@ -1120,7 +1119,6 @@ class GraphModule(torch.nn.Module):
         l_x_ = L_x_
         l_weight_ = L_weight_

-        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
         fwd_body_0 = self.fwd_body_0
         bwd_body_0 = self.bwd_body_0
         autograd_function_apply: "f32[5, 4]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_weight_, args_tensor_mask = [True, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_weight_ = None
@@ -1305,7 +1303,6 @@ class GraphModule(torch.nn.Module):
         l_x_ = L_x_
         l_y_ = L_y_

-        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
         fwd_body_0 = self.fwd_body_0
         bwd_body_0 = self.bwd_body_0
         autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, args_tensor_mask = [True, True], non_differentiable_idx = [1]); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None
@@ -1474,7 +1471,7 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(out, x + 1)
         self.assertEqual(x.grad.shape, shape)
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(cnt.op_count, 2)
+        self.assertEqual(cnt.op_count, 1)

     @requires_cuda
     def test_triton_kernel_basic(self):
@@ -1246,13 +1246,13 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         with torch.no_grad():
             cnt = self._reformer(nopython=True)
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(cnt.op_count, 11)
+        self.assertEqual(cnt.op_count, 10)

     def test_reformer_train(self):
         with torch.enable_grad():
             cnt = self._reformer(nopython=False)
         expected_op_count = (
-            """11""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5"""
+            """10""" if torch._dynamo.config.inline_inbuilt_nn_modules else """4"""
         )

         self.assertExpectedInline(cnt.frame_count, """1""")
@@ -3708,7 +3708,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         expected = fn(*inputs1)
         actual = fn_opt(*inputs2)
         self.assertTrue(same(actual, expected))
-        self.assertEqual(cnt.op_count, 2)
+        self.assertEqual(cnt.op_count, 1)
         self.assertEqual(cnt.frame_count, 1)
         cnt.clear()
         counters.clear()
@@ -7197,7 +7197,6 @@ FAILING_CACHE_TESTS = (
    # BypassAOTAutogradCache: unsupported nodes
    "test_backward_mutation_data", # Custom Autograd Function
    "test_backward_mutation_metadata", # Custom Autograd Function
-   "test_custom_autograd", # Custom Autograd Function
    "test_input_output_aliase_custom_autograd_function",
)

@@ -1504,7 +1504,6 @@ class GraphModule(torch.nn.Module):

     class subgraph_0(torch.nn.Module):
         def forward(self, l_x_: "f32[8, 8]"):
-            function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
             fwd_body_0 = self.fwd_body_0
             bwd_body_0 = self.bwd_body_0
             autograd_function_apply: "f32[8, 8]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, args_tensor_mask = [True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = None
@@ -2848,6 +2848,15 @@ class AutogradFunctionApplyVariable(VariableTracker):
         )

         ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
+        with discard_graph_changes(tx):
+            # A little hacky, but we need a dummy ctx proxy for speculate_subgraph.
+            # We should clean this up at some point.
+            proxy = tx.output.create_proxy(
+                "call_function", torch.autograd.function.FunctionCtx, (), {}
+            )
+            set_example_value(proxy.node, ctx.value)
+            ctx.proxy = proxy
+
         if isinstance(self.fwd_graph, types.FunctionType):
             fwd_fn = UserFunctionVariable(self.fwd_graph)
             fwd_args = [ctx, *args]
@@ -53,7 +53,6 @@ from ..utils import (
     istype,
     list_methods,
     proxy_args_kwargs,
-    set_example_value,
     tuple_methods,
 )
 from .base import VariableTracker
@@ -877,7 +876,6 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
         value,
         value_type=None,
         inference=False,
-        proxy=None,
         saved_tensors=None,
         needs_input_grad=None,
         non_differentiable=None,
@@ -885,7 +883,6 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
     ) -> None:
         super().__init__(value=value, value_type=value_type, **kwargs)
         self.inference = inference
-        self.proxy = proxy
         self.saved_tensors = saved_tensors
         self.needs_input_grad = needs_input_grad
         self.non_differentiable = non_differentiable
@@ -898,23 +895,17 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
             isinstance(x, variables.TensorVariable) and x.requires_grad
             for x in args
         )
-        proxy = tx.output.create_proxy(
-            "call_function", torch.autograd.function.FunctionCtx, (), {}
-        )
         out = tx.output.side_effects.track_object_new(
             None,
             torch.autograd.function.FunctionCtx,
             functools.partial(
                 AutogradFunctionContextVariable,
                 inference=True,
-                proxy=proxy,
                 saved_tensors=SavedTensorBox(),
                 needs_input_grad=needs_input_grad,
             ),
             {},
         )
-        set_example_value(proxy.node, out.value)
-
         return out

     def as_proxy(self):