mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
reinplace pass: bugfix for output node replacement (#83845)
Cleaned up some of the arg replacement logic to use tree_map, so it handles FX nodes that have nested containers. See the added test: when you write a function that returns a list, the `output` node in the FX graph shows up as having `node.args = tuple(immutable_list(...))` Pull Request resolved: https://github.com/pytorch/pytorch/pull/83845 Approved by: https://github.com/ezyang
This commit is contained in:
parent
01434c2d20
commit
75ec7b7547
|
|
@ -295,5 +295,32 @@ def forward(self, a__1):
|
|||
return select_scatter_default
|
||||
""") # noqa: B950
|
||||
|
||||
|
||||
def test_out_node_updated(self):
|
||||
def f():
|
||||
x = torch.zeros(2, 2)
|
||||
y = x.diagonal()
|
||||
y_updated = y.add(1)
|
||||
z = torch.diagonal_scatter(x, y_updated)
|
||||
# reinplace needs to know to replace output [z] with [x]
|
||||
return [z]
|
||||
|
||||
if not HAS_FUNCTIONALIZATION:
|
||||
return
|
||||
f2 = reinplace(make_fx(functionalize(f))())
|
||||
expected_out = f()
|
||||
actual_out = f2()
|
||||
self.assertEqual(actual_out, expected_out)
|
||||
self.assertExpectedInline(f2.code, """\
|
||||
|
||||
|
||||
|
||||
def forward(self):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_default = torch.ops.aten.diagonal.default(zeros)
|
||||
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, 1); diagonal_default = None
|
||||
return [zeros]
|
||||
""")
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch.fx import Node
|
||||
from torch.fx._compatibility import compatibility
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
|
||||
from torch.utils._pytree import tree_map, tree_flatten
|
||||
from torch.utils._pytree import tree_map, tree_flatten, tree_map_only
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
|
||||
import _operator
|
||||
|
|
@ -526,20 +526,18 @@ def reinplace(gm, *sample_args):
|
|||
nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
|
||||
for node_to_update in nodes_to_update:
|
||||
new_args = []
|
||||
for arg_idx, a in enumerate(node_to_update.args):
|
||||
if a == old:
|
||||
new_args.append(new)
|
||||
else:
|
||||
new_args.append(a)
|
||||
new_kwargs = {}
|
||||
for kwarg_idx, (k, v) in enumerate(node_to_update.kwargs.items()):
|
||||
if isinstance(v, Node) and v.name == old.name:
|
||||
new_kwargs[k] = new
|
||||
else:
|
||||
new_kwargs[k] = v
|
||||
node_to_update.args = tuple(new_args)
|
||||
node_to_update.kwargs = new_kwargs
|
||||
args = node_to_update.args
|
||||
|
||||
def replace_arg(a):
|
||||
if a == old:
|
||||
return new
|
||||
return a
|
||||
|
||||
# First, replace usages of "b" with "a"
|
||||
node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args)
|
||||
node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs)
|
||||
|
||||
# Second, update our storage_to_nodes data structure.
|
||||
old_flattened_res, _ = tree_flatten(old.meta['fake_result'])
|
||||
node_flattened_res, _ = tree_flatten(node_to_update.meta['fake_result'])
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user