[FX] Fix uses not updating when erasing a node (#47720)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47720

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D24875880

Pulled By: jamesr66a

fbshipit-source-id: aae9ffd10f8085b599e7923152287c6e6950ff49
This commit is contained in:
James Reed 2020-11-11 10:54:01 -08:00 committed by Facebook GitHub Bot
parent d1351c66a8
commit dbfee42a7d
2 changed files with 22 additions and 0 deletions

View File

@ -700,6 +700,19 @@ class TestFX(JitTestCase):
ref = torch.sin(mod.linear(input) + mod.bias)
self.assertEqual(r, ref)
def test_remove_uses(self):
g : torch.fx.Graph = Graph()
x : torch.fx.Node = g.placeholder('x')
relu : torch.fx.Node = g.call_function(torch.relu, (x,))
neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
g.output(neg)
neg.replace_all_uses_with(relu)
g.erase_node(neg)
self.assertTrue(neg not in relu.users)
def test_construct_root_dict(self):
graph : torch.fx.Graph = torch.fx.Graph()
a : torch.fx.Node = graph.create_node('placeholder', 'x')

View File

@ -261,6 +261,15 @@ class Graph:
to_erase._erased = True # iterators may retain handles to erased nodes
self._len -= 1
# Null out this Node's argument nodes so that the Nodes referred to
# can update their `users` accordingly
new_args = map_arg(to_erase.args, lambda n: None)
assert isinstance(new_args, tuple)
to_erase.args = new_args
new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
assert isinstance(new_kwargs, dict)
to_erase.kwargs = new_kwargs
def inserting_before(self, n: Optional[Node] = None):
"""Set the point at which create_node and companion methods will insert into the graph.
When used within a 'with' statement, this will temporary set the insert point and