reinplace pass: bugfix for output node replacement (#83845)

Cleaned up some of the arg replacement logic to use tree_map, so it handles FX nodes that have nested containers.

See the added test: when you write a function that returns a list, the `output` node in the FX graph shows up as having `node.args = tuple(immutable_list(...))`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83845
Approved by: https://github.com/ezyang
This commit is contained in:
Brian Hirsh 2022-08-22 08:39:25 -07:00 committed by PyTorch MergeBot
parent 01434c2d20
commit 75ec7b7547
2 changed files with 39 additions and 14 deletions

View File

@ -295,5 +295,32 @@ def forward(self, a__1):
return select_scatter_default
""") # noqa: B950
def test_out_node_updated(self):
def f():
x = torch.zeros(2, 2)
y = x.diagonal()
y_updated = y.add(1)
z = torch.diagonal_scatter(x, y_updated)
# reinplace needs to know to replace output [z] with [x]
return [z]
if not HAS_FUNCTIONALIZATION:
return
f2 = reinplace(make_fx(functionalize(f))())
expected_out = f()
actual_out = f2()
self.assertEqual(actual_out, expected_out)
self.assertExpectedInline(f2.code, """\
def forward(self):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_default = torch.ops.aten.diagonal.default(zeros)
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, 1); diagonal_default = None
return [zeros]
""")
if __name__ == '__main__':
run_tests()

View File

@ -2,7 +2,7 @@ import torch
from torch.fx import Node
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.utils._pytree import tree_map, tree_flatten
from torch.utils._pytree import tree_map, tree_flatten, tree_map_only
from torch.multiprocessing.reductions import StorageWeakRef
import _operator
@ -526,20 +526,18 @@ def reinplace(gm, *sample_args):
nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
for node_to_update in nodes_to_update:
new_args = []
for arg_idx, a in enumerate(node_to_update.args):
if a == old:
new_args.append(new)
else:
new_args.append(a)
new_kwargs = {}
for kwarg_idx, (k, v) in enumerate(node_to_update.kwargs.items()):
if isinstance(v, Node) and v.name == old.name:
new_kwargs[k] = new
else:
new_kwargs[k] = v
node_to_update.args = tuple(new_args)
node_to_update.kwargs = new_kwargs
args = node_to_update.args
def replace_arg(a):
if a == old:
return new
return a
# First, replace usages of "b" with "a"
node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args)
node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs)
# Second, update our storage_to_nodes data structure.
old_flattened_res, _ = tree_flatten(old.meta['fake_result'])
node_flattened_res, _ = tree_flatten(node_to_update.meta['fake_result'])