[fx] Improve support for tuple subclasses such as NamedTuple (#73198)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73198

Previously, if an arg to an FX node is a subclass of tuple then it gets sanitized essentially back to that base class. An example here is when setting an arg to be a TensorMetadata object, which is a NamedTuple, it will be set as a tuple instead.

- Change `map_aggregate` to repack the tuple to `type(a)` when it's not directly a tuple (try/except for best attempt)
- During codegen, call `add_global` for `type(a)` if it's not directly a tuple.
- Add an option for an arg to provide a `_custom_fx_repr_fn` for use inside stringifying via `_format_arg`

Test Plan: Added unit test coverage, where we inline the named tuple into arg/kwarg.

Reviewed By: jamesr66a

Differential Revision: D34381888

fbshipit-source-id: bd672a8542e2bba5aa604b448bec920efc256440
(cherry picked from commit 68f99c12dd)
This commit is contained in:
Jordan Fix 2022-02-23 02:38:29 -08:00 committed by PyTorch MergeBot
parent 715a0dc5c0
commit 987f146185
3 changed files with 63 additions and 10 deletions

View File

@ -25,7 +25,7 @@ from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
from torch.fx.node import Target, Argument
from torch.fx.node import Target, Argument, _format_arg
from torch.fx.passes import shape_prop
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.experimental.rewriter import RewritingTracer
@ -101,6 +101,11 @@ wrap('len')
wrap('getattr')
def wrapped_named_tup(p1, *, p2):
return p1.x + p2.y
wrap(wrapped_named_tup)
@wrap
def wrapped_via_decorator(a):
return a + 1
@ -125,6 +130,9 @@ class Pair(NamedTuple):
x : torch.Tensor
y : torch.Tensor
def _custom_fx_repr_fn(self) -> str:
return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})"
# for testing pytrees
class Foo(object): # noqa: B209
def __init__(self, a, b):
@ -2261,6 +2269,40 @@ class TestFX(JitTestCase):
input = torch.rand(3, 4)
self.assertEqual(traced(input), Pair(input, input))
def test_named_tuple_inlined(self):
class NamedTupMod(torch.nn.Module):
def forward(self, inp):
return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp))
m = NamedTupMod()
input = torch.rand(3, 4)
ref = m(input)
traced = symbolic_trace(m)
res = traced(input)
self.assertEqual(ref, res)
# Check Pair NamedTuple works when inlined into the function call.
ph = call_func = None
for node in traced.graph.nodes:
if node.op == "placeholder":
ph = node
elif node.op == "call_function" and node.target == wrapped_named_tup:
node.update_arg(0, Pair(ph, 1.2))
node.update_kwarg("p2", Pair(3.4, ph))
call_func = node
break
self.assertTrue(call_func is not None)
self.assertTrue(isinstance(call_func.args[0], Pair))
self.assertTrue(isinstance(call_func.kwargs["p2"], Pair))
self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)")
self.assertEqual(_format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)")
traced.graph.eliminate_dead_code()
traced.recompile()
res = traced(input)
self.assertEqual(ref, res)
def test_return_type_exists(self):
class ReturnTypeModule(torch.nn.Module):
def other(self, x: List[str]) -> List[str]:

View File

@ -196,13 +196,6 @@ class PythonCode:
globals: Dict[str, Any]
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
args_s = ', '.join(repr(a) for a in args)
kwargs_s = ', '.join(f'{k} = {repr(v)}' for k, v in kwargs.items())
if args_s and kwargs_s:
return f'{args_s}, {kwargs_s}'
return args_s or kwargs_s
def _format_target(base: str, target: str) -> str:
elems = target.split('.')
r = base
@ -357,6 +350,20 @@ class CodeGen(object):
# Common case: this is a regular module name like 'foo.bar.baz'
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, '_fields'):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
args_s = ', '.join(_get_repr(a) for a in args)
kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
if args_s and kwargs_s:
return f'{args_s}, {kwargs_s}'
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
# of a given node. This represents the *last* use of the node in the
# execution order of the program, which we will use to free unused

View File

@ -69,7 +69,9 @@ def _get_qualified_name(func: Callable[..., Any]) -> str:
return f'{module}.{name}'
def _format_arg(arg) -> str:
if isinstance(arg, list):
if hasattr(arg, "_custom_fx_repr_fn"):
return arg._custom_fx_repr_fn()
elif isinstance(arg, list):
items = ', '.join(_format_arg(a) for a in arg)
return f'[{items}]'
elif isinstance(arg, tuple):
@ -587,7 +589,9 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
"""
if isinstance(a, tuple):
return tuple(map_aggregate(elem, fn) for elem in a)
t = tuple(map_aggregate(elem, fn) for elem in a)
# Support NamedTuple (if it has `_fields`) by repacking into original type.
return t if not hasattr(a, '_fields') else type(a)(*t)
elif isinstance(a, list):
return immutable_list(map_aggregate(elem, fn) for elem in a)
elif isinstance(a, dict):