diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index bd8c0e63a52..c854eef93ad 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -57,7 +57,7 @@ torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node') torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None -torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument]) -> torch.fx.node.Argument +torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument], should_traverse_fn: Optional[Callable[[torch.fx.node.Argument], bool]] = None) -> torch.fx.node.Argument torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None) torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str) diff --git a/test/test_fx.py b/test/test_fx.py index a7d1ca1f150..a713ce5245e 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -3501,6 +3501,44 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: self.assertEqual(gm(2, 3), 6) self.assertIn("a *= b", gm.code) + def test_map_aggregate_doesnt_traverse_size(self): + def dont_traverse_size(a): + return type(a) != torch.Size + + size = torch.Size([1, 2, 3]) + + res = torch.fx.node.map_aggregate(size, lambda a: a) + self.assertEqual(type(res), tuple) + self.assertEqual(res, (1, 2, 3)) + + res = torch.fx.node.map_aggregate(size, lambda a: a, dont_traverse_size) + self.assertEqual(type(res), torch.Size) + self.assertEqual(res, size) + + data = (torch.empty(3, 4), size, + {'tensor': torch.empty(4, 5), 'size': size, 'list': [size, (size,), torch.empty(5, 6)]}) + + res = torch.fx.node.map_aggregate(data, lambda a: a) + self.assertEqual(type(res[1]), tuple) + self.assertEqual(res[1], (1, 2, 3)) + self.assertEqual(type(res[2]['size']), tuple) + self.assertEqual(res[2]['size'], (1, 2, 3)) + self.assertEqual(type(res[2]['list'][0]), tuple) + self.assertEqual(res[2]['list'][0], (1, 2, 3)) + self.assertEqual(type(res[2]['list'][1][0]), tuple) + self.assertEqual(res[2]['list'][1][0], (1, 2, 3)) + + res = torch.fx.node.map_aggregate(data, lambda a: a, dont_traverse_size) + self.assertEqual(type(res[1]), torch.Size) + self.assertEqual(res[1], size) + self.assertEqual(type(res[2]['size']), torch.Size) + self.assertEqual(res[2]['size'], size) + self.assertEqual(type(res[2]['list'][0]), torch.Size) + self.assertEqual(res[2]['list'][0], size) + self.assertEqual(type(res[2]['list'][1][0]), torch.Size) + self.assertEqual(res[2]['list'][1][0], size) + + def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch diff --git a/torch/fx/node.py b/torch/fx/node.py index 66a94154f94..8ddfbfa6575 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -600,20 +600,27 @@ def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable" return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) + @compatibility(is_backward_compatible=True) -def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: +def map_aggregate(a: Argument, fn: Callable[[Argument], Argument], + should_traverse_fn: Optional[Callable[[Argument], bool]] = None) -> Argument: """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. + Traverses list, tuple, slice, or dict if ``should_traverse_fn`` is either None or returns True for supplied argument """ + if should_traverse_fn and not should_traverse_fn(a): + return fn(a) + if isinstance(a, tuple): - t = tuple(map_aggregate(elem, fn) for elem in a) + t = tuple(map_aggregate(elem, fn, should_traverse_fn) for elem in a) # Support NamedTuple (if it has `_fields`) by repacking into original type. return t if not hasattr(a, '_fields') else type(a)(*t) elif isinstance(a, list): - return immutable_list(map_aggregate(elem, fn) for elem in a) + return immutable_list(map_aggregate(elem, fn, should_traverse_fn) for elem in a) elif isinstance(a, dict): - return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items()) + return immutable_dict((k, map_aggregate(v, fn, should_traverse_fn)) for k, v in a.items()) elif isinstance(a, slice): - return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn)) + return slice(map_aggregate(a.start, fn, should_traverse_fn), map_aggregate(a.stop, fn, should_traverse_fn), + map_aggregate(a.step, fn, should_traverse_fn)) else: return fn(a)