[HigherOrderOp] Simplify design by removing reliance on name match (#104350)

Previously:
- we were keeping a set of proxies seen by the current SubgraphTracer.
It turns out that fx.Proxy has a .tracer field that we can use instead
(see the sketch below).
- we were using name matching to determine if a freevar had already been
lifted to be an input of the parent SubgraphTracer. Voz and I have
previously expressed concerns about the robustness of name matching.
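
For illustration, a minimal sketch (not part of this PR) of the ownership
information that fx.Proxy already carries; the manual Tracer/Graph wiring is
contrived purely for the example:

    import torch.fx as fx

    tracer = fx.Tracer()
    tracer.graph = fx.Graph()  # manual wiring, for illustration only
    node = tracer.graph.placeholder("x")
    proxy = fx.Proxy(node, tracer)
    # The Proxy records which tracer created it, so no separate
    # "seen proxies" bookkeeping or name matching is needed:
    assert proxy.tracer is tracer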

This PR introduces a simplified design with more invariants:
- When doing HigherOrderOp tracing, we may encounter Proxys.
- Each Proxy object is associated with exactly one SubgraphTracer.
- The new invariant is that a SubgraphTracer should only construct Nodes
out of Proxys that come from that same SubgraphTracer. This helps us avoid
malformed graphs.
- If a Proxy object came from another SubgraphTracer, then it is a free
variable. We need to lift it to be an input of the current
SubgraphTracer, which results in the construction of a new Proxy in the
current SubgraphTracer. This new Proxy is used whenever the old Proxy is
seen by the current SubgraphTracer (see the sketch after this list).
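
As a condensed sketch of how that invariant plays out (maybe_lift is a
hypothetical helper name, not the real API; the actual logic is in the
SubgraphTracer diff below):

    import torch.fx

    # Hypothetical method on a SubgraphTracer-like class.
    def maybe_lift(self, arg):
        if not isinstance(arg, torch.fx.Proxy):
            return arg  # non-Proxy args pass through unchanged
        if arg.tracer is self:
            return arg  # Proxy already belongs to this SubgraphTracer
        # Free variable: if it was lifted before, reuse the same input Proxy
        # so repeated uses of the freevar share one graph input.
        if arg in self.lifted_freevars:
            return self.lifted_freevars[arg]
        # Otherwise lift it to a new graph input of this SubgraphTracer.
        new_proxy = self.create_graph_input(arg.node.name)
        self.lifted_freevars[arg] = new_proxy
        return new_proxy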

Test Plan:
- existing tests + some new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104350
Approved by: https://github.com/ydwu4, https://github.com/voznesenskym
Author: Richard Zou
Date: 2023-07-05 18:04:15 -07:00
Committed by: PyTorch MergeBot
Commit: adf1405909 (parent: 69c4314945)
3 changed files with 40 additions and 27 deletions

@@ -1,6 +1,7 @@
 # Owner(s): ["module: dynamo"]
 import math
+import unittest
 import torch
@@ -320,6 +321,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
         after = compiled_model(*args, **kwargs)
         self.assertEqual(before, after)

+    @unittest.expectedFailure
     def test_function_with_bound_free_variable(self):
         class LowerBound(torch.autograd.Function):
             @staticmethod

@@ -240,6 +240,20 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
         y = torch.randn(3, 3)
         self._test_wrap_simple(f, (x, y), 3)

+    def test_same_freevar_twice(self):
+        free = torch.randn(3)
+
+        def g(x):
+            y = free.sin()
+            z = free.cos()
+            return y, z
+
+        def f(x):
+            return wrap(g, x)
+
+        x = torch.randn(3)
+        self._test_wrap_simple(f, (x,), 3, 3)
+
     def test_capture_value_created_in_subgraph(self):
         backend = EagerAndRecordGraphs()
         cnt = CompileCounterWithBackend(backend)

@@ -1041,16 +1041,19 @@ class SubgraphTracer(fx.Tracer):
         # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design]
         self.parent = parent
-        # A list of proxies that exist in the graph being traced. We use this
-        # list to determine, when tracing the body function of a
-        # HigherOrderOperator, if a new proxy is actually a free variable.
-        self.seen_proxies = set({})
-        # A list of previously free variables that we lifted to being inputs of
-        # the graph. If we are tracing a HigherOrderOperator's body_fn, then we
-        # need to keep track of this so we can rewrite the HigherOrderOperator
-        # call using the traced body_fn. This is an OrderedDict (instead of a
-        # set) so that we can maintain the order of args for the
-        # HigherOrderOperator call. The values are None.
+        # A dict mapping previously free variables (Proxy objects)
+        # to new Proxy objects that wrap inputs to this subgraph.
+        #
+        # This dict serves two purposes:
+        # - Proxies are associated with VariableTrackers. If we see
+        #   the same VariableTracker twice (and it is a free variable),
+        #   then we want to use the same Proxy in the current subgraph to
+        #   record the tracing.
+        # - If we are tracing a HigherOrderOperator's body_fn, then we
+        #   need to keep track of what free variables were lifted so we can
+        #   rewrite the HigherOrderOperator call using the traced body_fn.
+        #   This is an OrderedDict so that we can maintain the order of args
+        #   for the HigherOrderOperator call.
         self.lifted_freevars = collections.OrderedDict()

     def create_proxy(
@@ -1102,14 +1105,12 @@ class SubgraphTracer(fx.Tracer):
             for arg in flat_args:
                 if not isinstance(arg, torch.fx.Proxy):
                     new_args.append(arg)
-                elif arg in self.seen_proxies:
+                elif arg.tracer == self:
                     new_args.append(arg)
                 elif not hasattr(arg, "node"):
                     new_args.append(arg)
                 elif "saved_tensor_marked" in arg.node.meta:
                     new_args.append(arg)
-                elif arg.node.name in self.input_name_to_proxy:
-                    new_args.append(self.input_name_to_proxy[arg.node.name])
                 else:
                     # Create a new input for this arg, and replace the current arg
                     # with the new arg
@@ -1151,7 +1152,6 @@ class SubgraphTracer(fx.Tracer):
             msgs = traceback.StackSummary.from_list(frame_summaries).format()  # type: ignore[arg-type]
             rv.node.stack_trace = "".join(msgs)

-        self.seen_proxies.add(rv)
         return rv

     def create_node(self, *args, **kwargs):
@@ -1211,25 +1211,22 @@ class SubgraphTracer(fx.Tracer):
         self.input_name_to_proxy[name] = proxy
         return proxy

-    def is_name_bound(self, name):
-        if name in self.input_name_to_proxy:
-            return True
-        for proxy in self.seen_proxies:
-            if proxy.node.name == name:
-                return True
-        return False
-
     # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details
     def lift_tracked_freevar_to_input(self, proxy):
         # You're doing something wrong if we are the root SubgraphTracer because
         # Dynamo adds tensors to graph inputs before creating a proxy for them.
-        assert self.parent is not None or not self.is_name_bound(
-            proxy.node.name
-        ), "lift_tracked_freevar_to_input on root SubgraphTracer should only be called with non-free variables."
+        assert (
+            self.parent is not None
+        ), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer"
+        # Proxys are associated with VariableTrackers.
+        # It is possible that we've already lifted the Proxy to be an input.
+        # If that is the case, just return the already-lifted Proxy.
+        if proxy in self.lifted_freevars:
+            return self.lifted_freevars[proxy]
         new_proxy = self.create_graph_input(proxy.node.name)
         new_proxy.node.meta["example_value"] = proxy.node.meta["example_value"]
-        self.lifted_freevars[proxy] = None
-        if self.parent is not None and not self.parent.is_name_bound(proxy.node.name):
+        self.lifted_freevars[proxy] = new_proxy
+        if self.parent is not None and proxy.tracer != self.parent:
             self.parent.lift_tracked_freevar_to_input(proxy)
         return new_proxy