mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[graph_manipulation] Unpack list of outputs (#72940)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72940 att Reviewed By: jackm321 Differential Revision: D34282062 fbshipit-source-id: 743710c18e1f38286d1b91c91868bb22c760f3ca
This commit is contained in:
parent
fb9e92fea5
commit
fd2bdd189d
|
|
@ -451,9 +451,9 @@ def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> D
|
|||
get_output_arg_info,
|
||||
)
|
||||
|
||||
# If there're multiple outputs then node_rep["args"][0] will be a tuple.
|
||||
# In this case we want to unpack the tuple.
|
||||
if isinstance(node_rep["args"][0], tuple):
|
||||
# If there're multiple outputs then node_rep["args"][0] will be a tuple or
|
||||
# list. In this case we want to unpack the tuple or list.
|
||||
if isinstance(node_rep["args"][0], (tuple, list)):
|
||||
node_rep["args"] = node_rep["args"][0]
|
||||
else:
|
||||
node_rep["args"] = map_aggregate(node.args, get_arg_info)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user