mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add should_traverse_fn to torch.fx.node.map_aggregate (#81510)
Adds an optional callback that checks if map_aggregate should continue recursive traversal. The main motivation is to not traverse torch.Size which is tuple Pull Request resolved: https://github.com/pytorch/pytorch/pull/81510 Approved by: https://github.com/SherlockNoMad, https://github.com/jamesr66a
This commit is contained in:
parent
7af0200a46
commit
d52f8c2533
|
|
@ -57,7 +57,7 @@ torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user
|
|||
torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node')
|
||||
torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
|
||||
torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None
|
||||
torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument]) -> torch.fx.node.Argument
|
||||
torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument], should_traverse_fn: Optional[Callable[[torch.fx.node.Argument], bool]] = None) -> torch.fx.node.Argument
|
||||
torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
|
||||
torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None)
|
||||
torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)
|
||||
|
|
|
|||
|
|
@ -3501,6 +3501,44 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
|
|||
self.assertEqual(gm(2, 3), 6)
|
||||
self.assertIn("a *= b", gm.code)
|
||||
|
||||
def test_map_aggregate_doesnt_traverse_size(self):
|
||||
def dont_traverse_size(a):
|
||||
return type(a) != torch.Size
|
||||
|
||||
size = torch.Size([1, 2, 3])
|
||||
|
||||
res = torch.fx.node.map_aggregate(size, lambda a: a)
|
||||
self.assertEqual(type(res), tuple)
|
||||
self.assertEqual(res, (1, 2, 3))
|
||||
|
||||
res = torch.fx.node.map_aggregate(size, lambda a: a, dont_traverse_size)
|
||||
self.assertEqual(type(res), torch.Size)
|
||||
self.assertEqual(res, size)
|
||||
|
||||
data = (torch.empty(3, 4), size,
|
||||
{'tensor': torch.empty(4, 5), 'size': size, 'list': [size, (size,), torch.empty(5, 6)]})
|
||||
|
||||
res = torch.fx.node.map_aggregate(data, lambda a: a)
|
||||
self.assertEqual(type(res[1]), tuple)
|
||||
self.assertEqual(res[1], (1, 2, 3))
|
||||
self.assertEqual(type(res[2]['size']), tuple)
|
||||
self.assertEqual(res[2]['size'], (1, 2, 3))
|
||||
self.assertEqual(type(res[2]['list'][0]), tuple)
|
||||
self.assertEqual(res[2]['list'][0], (1, 2, 3))
|
||||
self.assertEqual(type(res[2]['list'][1][0]), tuple)
|
||||
self.assertEqual(res[2]['list'][1][0], (1, 2, 3))
|
||||
|
||||
res = torch.fx.node.map_aggregate(data, lambda a: a, dont_traverse_size)
|
||||
self.assertEqual(type(res[1]), torch.Size)
|
||||
self.assertEqual(res[1], size)
|
||||
self.assertEqual(type(res[2]['size']), torch.Size)
|
||||
self.assertEqual(res[2]['size'], size)
|
||||
self.assertEqual(type(res[2]['list'][0]), torch.Size)
|
||||
self.assertEqual(res[2]['list'][0], size)
|
||||
self.assertEqual(type(res[2]['list'][1][0]), torch.Size)
|
||||
self.assertEqual(res[2]['list'][1][0], size)
|
||||
|
||||
|
||||
|
||||
def run_getitem_target():
|
||||
from torch.fx._symbolic_trace import _wrapped_methods_to_patch
|
||||
|
|
|
|||
|
|
@ -600,20 +600,27 @@ def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
|
|||
assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
|
||||
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
|
||||
def map_aggregate(a: Argument, fn: Callable[[Argument], Argument],
|
||||
should_traverse_fn: Optional[Callable[[Argument], bool]] = None) -> Argument:
|
||||
"""
|
||||
Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
|
||||
Traverses list, tuple, slice, or dict if ``should_traverse_fn`` is either None or returns True for supplied argument
|
||||
"""
|
||||
if should_traverse_fn and not should_traverse_fn(a):
|
||||
return fn(a)
|
||||
|
||||
if isinstance(a, tuple):
|
||||
t = tuple(map_aggregate(elem, fn) for elem in a)
|
||||
t = tuple(map_aggregate(elem, fn, should_traverse_fn) for elem in a)
|
||||
# Support NamedTuple (if it has `_fields`) by repacking into original type.
|
||||
return t if not hasattr(a, '_fields') else type(a)(*t)
|
||||
elif isinstance(a, list):
|
||||
return immutable_list(map_aggregate(elem, fn) for elem in a)
|
||||
return immutable_list(map_aggregate(elem, fn, should_traverse_fn) for elem in a)
|
||||
elif isinstance(a, dict):
|
||||
return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items())
|
||||
return immutable_dict((k, map_aggregate(v, fn, should_traverse_fn)) for k, v in a.items())
|
||||
elif isinstance(a, slice):
|
||||
return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn))
|
||||
return slice(map_aggregate(a.start, fn, should_traverse_fn), map_aggregate(a.stop, fn, should_traverse_fn),
|
||||
map_aggregate(a.step, fn, should_traverse_fn))
|
||||
else:
|
||||
return fn(a)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user