mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[fx] Improve support for tuple subclasses such as NamedTuple (#73198)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73198
Previously, if an arg to an FX node was a subclass of tuple, it was sanitized essentially back to that base class. An example: when setting an arg to be a TensorMetadata object, which is a NamedTuple, it would be set as a plain tuple instead.
- Change `map_aggregate` to repack the tuple to `type(a)` when it's not directly a tuple (try/except for best attempt)
- During codegen, call `add_global` for `type(a)` if it's not directly a tuple.
- Add an option for an arg to provide a `_custom_fx_repr_fn`, used when stringifying the arg via `_format_arg`
Test Plan: Added unit test coverage, where we inline the named tuple into arg/kwarg.
Reviewed By: jamesr66a
Differential Revision: D34381888
fbshipit-source-id: bd672a8542e2bba5aa604b448bec920efc256440
(cherry picked from commit 68f99c12dd)
This commit is contained in:
parent
715a0dc5c0
commit
987f146185
|
|
@ -25,7 +25,7 @@ from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate
|
|||
import torch.utils._pytree as pytree
|
||||
import torch.fx._pytree as fx_pytree
|
||||
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
|
||||
from torch.fx.node import Target, Argument
|
||||
from torch.fx.node import Target, Argument, _format_arg
|
||||
from torch.fx.passes import shape_prop
|
||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
||||
from torch.fx.experimental.rewriter import RewritingTracer
|
||||
|
|
@ -101,6 +101,11 @@ wrap('len')
|
|||
|
||||
wrap('getattr')


def wrapped_named_tup(p1, *, p2):
    """Return the sum of `p1.x` and `p2.y`.

    Registered below via `wrap` so FX traces calls to it as leaf
    call_function nodes instead of tracing into the body.
    """
    first = p1.x
    second = p2.y
    return first + second


wrap(wrapped_named_tup)
|
||||
|
||||
@wrap
def wrapped_via_decorator(a):
    """Increment `a` by one; registered as an FX leaf via the @wrap decorator."""
    incremented = a + 1
    return incremented
|
||||
|
|
@ -125,6 +130,9 @@ class Pair(NamedTuple):
|
|||
x : torch.Tensor
|
||||
y : torch.Tensor
|
||||
|
||||
def _custom_fx_repr_fn(self) -> str:
|
||||
return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})"
|
||||
|
||||
# for testing pytrees
|
||||
class Foo(object): # noqa: B209
|
||||
def __init__(self, a, b):
|
||||
|
|
@ -2261,6 +2269,40 @@ class TestFX(JitTestCase):
|
|||
input = torch.rand(3, 4)
|
||||
self.assertEqual(traced(input), Pair(input, input))
|
||||
|
||||
def test_named_tuple_inlined(self):
    """NamedTuple args/kwargs survive tracing, graph arg updates, and recompilation."""
    class NamedTupMod(torch.nn.Module):
        def forward(self, inp):
            return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp))

    module = NamedTupMod()
    x = torch.rand(3, 4)
    expected = module(x)
    traced = symbolic_trace(module)

    self.assertEqual(expected, traced(x))

    # Check Pair NamedTuple works when inlined into the function call.
    placeholder = None
    target_node = None
    for node in traced.graph.nodes:
        if node.op == "placeholder":
            placeholder = node
        elif node.op == "call_function" and node.target == wrapped_named_tup:
            node.update_arg(0, Pair(placeholder, 1.2))
            node.update_kwarg("p2", Pair(3.4, placeholder))
            target_node = node
            break
    self.assertTrue(target_node is not None)
    self.assertTrue(isinstance(target_node.args[0], Pair))
    self.assertTrue(isinstance(target_node.kwargs["p2"], Pair))
    self.assertEqual(_format_arg(target_node.args[0]), "Pair(x=%inp, y=1.2)")
    self.assertEqual(_format_arg(target_node.kwargs["p2"]), "Pair(x=3.4, y=%inp)")

    # The inlined NamedTuples must round-trip through codegen as well.
    traced.graph.eliminate_dead_code()
    traced.recompile()
    self.assertEqual(expected, traced(x))
|
||||
|
||||
def test_return_type_exists(self):
|
||||
class ReturnTypeModule(torch.nn.Module):
|
||||
def other(self, x: List[str]) -> List[str]:
|
||||
|
|
|
|||
|
|
@ -196,13 +196,6 @@ class PythonCode:
|
|||
globals: Dict[str, Any]
|
||||
|
||||
|
||||
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
|
||||
args_s = ', '.join(repr(a) for a in args)
|
||||
kwargs_s = ', '.join(f'{k} = {repr(v)}' for k, v in kwargs.items())
|
||||
if args_s and kwargs_s:
|
||||
return f'{args_s}, {kwargs_s}'
|
||||
return args_s or kwargs_s
|
||||
|
||||
def _format_target(base: str, target: str) -> str:
|
||||
elems = target.split('.')
|
||||
r = base
|
||||
|
|
@ -357,6 +350,20 @@ class CodeGen(object):
|
|||
# Common case: this is a regular module name like 'foo.bar.baz'
|
||||
return add_global(typename, o)
|
||||
|
||||
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
    """Render args/kwargs for generated code, registering NamedTuple types.

    NamedTuple values need their class importable in the generated module,
    so the type is routed through the enclosing scope's `add_global`.
    """
    def _get_repr(arg):
        # Handle NamedTuples (if it has `_fields`) via add_global.
        if not (isinstance(arg, tuple) and hasattr(arg, '_fields')):
            return repr(arg)
        tup_type = type(arg)
        global_name = add_global(_get_qualified_name(tup_type), tup_type)
        return f"{global_name}{repr(tuple(arg))}"

    rendered = [_get_repr(a) for a in args]
    rendered.extend(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
    return ', '.join(rendered)
|
||||
|
||||
# Run through reverse nodes and record the first instance of a use
|
||||
# of a given node. This represents the *last* use of the node in the
|
||||
# execution order of the program, which we will use to free unused
|
||||
|
|
|
|||
|
|
@ -69,7 +69,9 @@ def _get_qualified_name(func: Callable[..., Any]) -> str:
|
|||
return f'{module}.{name}'
|
||||
|
||||
def _format_arg(arg) -> str:
|
||||
if isinstance(arg, list):
|
||||
if hasattr(arg, "_custom_fx_repr_fn"):
|
||||
return arg._custom_fx_repr_fn()
|
||||
elif isinstance(arg, list):
|
||||
items = ', '.join(_format_arg(a) for a in arg)
|
||||
return f'[{items}]'
|
||||
elif isinstance(arg, tuple):
|
||||
|
|
@ -587,7 +589,9 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
|
|||
Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
|
||||
"""
|
||||
if isinstance(a, tuple):
|
||||
return tuple(map_aggregate(elem, fn) for elem in a)
|
||||
t = tuple(map_aggregate(elem, fn) for elem in a)
|
||||
# Support NamedTuple (if it has `_fields`) by repacking into original type.
|
||||
return t if not hasattr(a, '_fields') else type(a)(*t)
|
||||
elif isinstance(a, list):
|
||||
return immutable_list(map_aggregate(elem, fn) for elem in a)
|
||||
elif isinstance(a, dict):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user