mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[HigherOrderOp] Simplify design by removing reliance on name match (#104350)
Previously: - we were keeping a list of proxies seen by the current SubgraphTracer. It turns out, fx.Proxy has a .tracer field that we should be able to use instead. - we were using name matching to determine if a freevar was already lifted to being the input of the parent SubgraphTracer. Voz and I have previously expressed concerns about the robsustness of name matching. This PR introduces a simplified design with more invariants: - When doing HigherOrderOp tracing, we may encounter Proxys - Each Proxy object is associated with a SubgraphTracer. - The new invariant is that SubgraphTracer should only construct Nodes using Proxy that come from the SubgraphTracer. This helps us avoid malformed graphs. - If the Proxy object came from another SubgraphTracer, then this means it is a free variable. We need to lift it to being an input of the current SubgraphTracer, which will result in the construction of a new Proxy in the current SubgraphTracer. This new Proxy should be used whenever the old Proxy is seen by the current SubgraphTracer. Test Plan: - existing tests + some new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/104350 Approved by: https://github.com/ydwu4, https://github.com/voznesenskym
This commit is contained in:
parent
69c4314945
commit
adf1405909
|
|
@ -1,6 +1,7 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -320,6 +321,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
|||
after = compiled_model(*args, **kwargs)
|
||||
self.assertEqual(before, after)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_function_with_bound_free_variable(self):
|
||||
class LowerBound(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -240,6 +240,20 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
|||
y = torch.randn(3, 3)
|
||||
self._test_wrap_simple(f, (x, y), 3)
|
||||
|
||||
def test_same_freevar_twice(self):
|
||||
free = torch.randn(3)
|
||||
|
||||
def g(x):
|
||||
y = free.sin()
|
||||
z = free.cos()
|
||||
return y, z
|
||||
|
||||
def f(x):
|
||||
return wrap(g, x)
|
||||
|
||||
x = torch.randn(3)
|
||||
self._test_wrap_simple(f, (x,), 3, 3)
|
||||
|
||||
def test_capture_value_created_in_subgraph(self):
|
||||
backend = EagerAndRecordGraphs()
|
||||
cnt = CompileCounterWithBackend(backend)
|
||||
|
|
|
|||
|
|
@ -1041,16 +1041,19 @@ class SubgraphTracer(fx.Tracer):
|
|||
|
||||
# SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design]
|
||||
self.parent = parent
|
||||
# A list of proxies that exist in the graph being traced. We use this
|
||||
# list to determine that, when tracing the body function of a HigherOrderOperator,
|
||||
# if a new proxy is actually a free variable.
|
||||
self.seen_proxies = set({})
|
||||
# A list of previously free variables that we lifted to being inputs of
|
||||
# the graph. If we are tracing a HigherOrderOperator's body_fn, then we
|
||||
# need to keep track of this so we can rewrite the HigherOrderOperator
|
||||
# call using the traced body_fn. This is a OrderedDict (instead of set)
|
||||
# so that we can maintain the order of args for the HigherOrderOperator
|
||||
# call. The values are None.
|
||||
# A dict mapping previously free variables (Proxy objects)
|
||||
# to new Proxy objects that wrap inputs to this subgraph.
|
||||
#
|
||||
# This dict serves two purposes:
|
||||
# - Proxies are associatd with VariableTrackers. If we see
|
||||
# the same VariableTracker twice (and it is a free variable),
|
||||
# then we want to use the same Proxy in the current subgraph to
|
||||
# record the tracing.
|
||||
# - If we are tracing a HigherOrderOperator's body_fn, then we
|
||||
# need to keep track of what free variables were lifted so we can
|
||||
# rewrite the HigherOrderOperator call using the traced body_fn.
|
||||
# This is a OrderedDict so that we can
|
||||
# maintain the order of args for the HigherOrderOperator call.
|
||||
self.lifted_freevars = collections.OrderedDict()
|
||||
|
||||
def create_proxy(
|
||||
|
|
@ -1102,14 +1105,12 @@ class SubgraphTracer(fx.Tracer):
|
|||
for arg in flat_args:
|
||||
if not isinstance(arg, torch.fx.Proxy):
|
||||
new_args.append(arg)
|
||||
elif arg in self.seen_proxies:
|
||||
elif arg.tracer == self:
|
||||
new_args.append(arg)
|
||||
elif not hasattr(arg, "node"):
|
||||
new_args.append(arg)
|
||||
elif "saved_tensor_marked" in arg.node.meta:
|
||||
new_args.append(arg)
|
||||
elif arg.node.name in self.input_name_to_proxy:
|
||||
new_args.append(self.input_name_to_proxy[arg.node.name])
|
||||
else:
|
||||
# Create a new input for this arg, and replace the current arg
|
||||
# with the new arg
|
||||
|
|
@ -1151,7 +1152,6 @@ class SubgraphTracer(fx.Tracer):
|
|||
msgs = traceback.StackSummary.from_list(frame_summaries).format() # type: ignore[arg-type]
|
||||
rv.node.stack_trace = "".join(msgs)
|
||||
|
||||
self.seen_proxies.add(rv)
|
||||
return rv
|
||||
|
||||
def create_node(self, *args, **kwargs):
|
||||
|
|
@ -1211,25 +1211,22 @@ class SubgraphTracer(fx.Tracer):
|
|||
self.input_name_to_proxy[name] = proxy
|
||||
return proxy
|
||||
|
||||
def is_name_bound(self, name):
|
||||
if name in self.input_name_to_proxy:
|
||||
return True
|
||||
for proxy in self.seen_proxies:
|
||||
if proxy.node.name == name:
|
||||
return True
|
||||
return False
|
||||
|
||||
# See NOTE: [Nested SubgraphTracer and free_variable handling] for more details
|
||||
def lift_tracked_freevar_to_input(self, proxy):
|
||||
# You're doing something wrong if we are the root SubgraphTracer because
|
||||
# Dynamo adds tensors to graph inputs before creating a proxy for them.
|
||||
assert self.parent is not None or not self.is_name_bound(
|
||||
proxy.node.name
|
||||
), "lift_tracked_freevar_to_input on root SubgraphTracer should only be called with non-free variables."
|
||||
assert (
|
||||
self.parent is not None
|
||||
), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer"
|
||||
# Proxys are associated with VariableTracker.
|
||||
# It is possible that we've already lifted the Proxy to be an input.
|
||||
# If that is the case, just return the already lifted Proxy.
|
||||
if proxy in self.lifted_freevars:
|
||||
return self.lifted_freevars[proxy]
|
||||
new_proxy = self.create_graph_input(proxy.node.name)
|
||||
new_proxy.node.meta["example_value"] = proxy.node.meta["example_value"]
|
||||
self.lifted_freevars[proxy] = None
|
||||
if self.parent is not None and not self.parent.is_name_bound(proxy.node.name):
|
||||
self.lifted_freevars[proxy] = new_proxy
|
||||
if self.parent is not None and proxy.tracer != self.parent:
|
||||
self.parent.lift_tracked_freevar_to_input(proxy)
|
||||
return new_proxy
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user