Add should_traverse_fn to torch.fx.node.map_aggregate (#81510)

Adds an optional callback that lets map_aggregate decide whether to continue recursive traversal into a given value. The main motivation is to avoid traversing torch.Size, which is a tuple subclass and would otherwise be unpacked and rebuilt as a plain tuple.
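
A minimal usage sketch (illustrative, not part of the diff): without the callback, map_aggregate unpacks torch.Size and rebuilds it as a plain tuple; with a callback that returns False for torch.Size, the value is handed to fn as a leaf and its type is preserved.

    import torch
    import torch.fx

    size = torch.Size([1, 2, 3])

    # Default behavior: the tuple branch rebuilds the Size as a plain tuple.
    res = torch.fx.node.map_aggregate(size, lambda a: a)
    assert type(res) is tuple

    # With should_traverse_fn, traversal stops at torch.Size and fn receives
    # the Size itself, so the original type survives.
    res = torch.fx.node.map_aggregate(
        size, lambda a: a, lambda a: not isinstance(a, torch.Size))
    assert type(res) is torch.Size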

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81510
Approved by: https://github.com/SherlockNoMad, https://github.com/jamesr66a
Pavel Belevich 2022-07-14 23:21:02 -04:00 committed by PyTorch MergeBot
parent 7af0200a46
commit d52f8c2533
3 changed files with 51 additions and 6 deletions


@@ -57,7 +57,7 @@ torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user
 torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node')
 torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
 torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None
-torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument]) -> torch.fx.node.Argument
+torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument], should_traverse_fn: Optional[Callable[[torch.fx.node.Argument], bool]] = None) -> torch.fx.node.Argument
 torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
 torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None)
 torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)


@@ -3501,6 +3501,44 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
         self.assertEqual(gm(2, 3), 6)
         self.assertIn("a *= b", gm.code)

+    def test_map_aggregate_doesnt_traverse_size(self):
+        def dont_traverse_size(a):
+            return type(a) != torch.Size
+
+        size = torch.Size([1, 2, 3])
+        res = torch.fx.node.map_aggregate(size, lambda a: a)
+        self.assertEqual(type(res), tuple)
+        self.assertEqual(res, (1, 2, 3))
+
+        res = torch.fx.node.map_aggregate(size, lambda a: a, dont_traverse_size)
+        self.assertEqual(type(res), torch.Size)
+        self.assertEqual(res, size)
+
+        data = (torch.empty(3, 4), size,
+                {'tensor': torch.empty(4, 5), 'size': size, 'list': [size, (size,), torch.empty(5, 6)]})
+
+        res = torch.fx.node.map_aggregate(data, lambda a: a)
+        self.assertEqual(type(res[1]), tuple)
+        self.assertEqual(res[1], (1, 2, 3))
+        self.assertEqual(type(res[2]['size']), tuple)
+        self.assertEqual(res[2]['size'], (1, 2, 3))
+        self.assertEqual(type(res[2]['list'][0]), tuple)
+        self.assertEqual(res[2]['list'][0], (1, 2, 3))
+        self.assertEqual(type(res[2]['list'][1][0]), tuple)
+        self.assertEqual(res[2]['list'][1][0], (1, 2, 3))
+
+        res = torch.fx.node.map_aggregate(data, lambda a: a, dont_traverse_size)
+        self.assertEqual(type(res[1]), torch.Size)
+        self.assertEqual(res[1], size)
+        self.assertEqual(type(res[2]['size']), torch.Size)
+        self.assertEqual(res[2]['size'], size)
+        self.assertEqual(type(res[2]['list'][0]), torch.Size)
+        self.assertEqual(res[2]['list'][0], size)
+        self.assertEqual(type(res[2]['list'][1][0]), torch.Size)
+        self.assertEqual(res[2]['list'][1][0], size)
+
 def run_getitem_target():
     from torch.fx._symbolic_trace import _wrapped_methods_to_patch


@@ -600,20 +600,27 @@ def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
     assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
     return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)

 @compatibility(is_backward_compatible=True)
-def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
+def map_aggregate(a: Argument, fn: Callable[[Argument], Argument],
+                  should_traverse_fn: Optional[Callable[[Argument], bool]] = None) -> Argument:
     """
     Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
+
+    Traverses list, tuple, slice, or dict if ``should_traverse_fn`` is either None or returns True for supplied argument
     """
+    if should_traverse_fn and not should_traverse_fn(a):
+        return fn(a)
+
     if isinstance(a, tuple):
-        t = tuple(map_aggregate(elem, fn) for elem in a)
+        t = tuple(map_aggregate(elem, fn, should_traverse_fn) for elem in a)
         # Support NamedTuple (if it has `_fields`) by repacking into original type.
         return t if not hasattr(a, '_fields') else type(a)(*t)
     elif isinstance(a, list):
-        return immutable_list(map_aggregate(elem, fn) for elem in a)
+        return immutable_list(map_aggregate(elem, fn, should_traverse_fn) for elem in a)
     elif isinstance(a, dict):
-        return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items())
+        return immutable_dict((k, map_aggregate(v, fn, should_traverse_fn)) for k, v in a.items())
     elif isinstance(a, slice):
-        return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn))
+        return slice(map_aggregate(a.start, fn, should_traverse_fn), map_aggregate(a.stop, fn, should_traverse_fn),
+                     map_aggregate(a.step, fn, should_traverse_fn))
     else:
         return fn(a)
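
For illustration (a sketch, not part of the commit; the 'shape'/'dims' keys are made-up names): because the early-return check runs before any isinstance dispatch, should_traverse_fn is consulted for every value at every nesting level, and when it returns False the whole subtree is handed to fn as one leaf.

    import torch
    import torch.fx

    def dont_traverse_size(a):
        return type(a) != torch.Size

    data = {'shape': torch.Size([4, 5]), 'dims': [2, 3]}

    # fn sees torch.Size([4, 5]) as a single value, but still sees 2 and 3
    # individually, since the list and its ints are traversed as usual.
    seen = []
    torch.fx.node.map_aggregate(data, lambda a: seen.append(a) or a, dont_traverse_size)
    assert seen == [torch.Size([4, 5]), 2, 3]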