[ts-migration][1/N]: Add prim::Loop for constant number of iterations and condition (#131418)
#### Description

This PR adds prim::Loop support for the simplest case, where the number of iterations is constant and the loop termination condition is also a constant. [PR by stages](https://docs.google.com/document/d/1q6OprW3HBHbYPwEyE_DikBn-uzmhnN284Cmen_CnlhI/edit?usp=sharing)

#### Test Plan

Adds a repro example:
* `pytest test/export/test_converter.py -s -k test_ts2ep_with_loop`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131418
Approved by: https://github.com/angelayi
This commit is contained in: parent c803e35c4b, commit ca7ce2fca1
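For orientation (an illustrative sketch, not part of the commit): a scripted loop of the shape covered by this first stage looks like the following; `torch.jit.script` lowers the Python `for` into a `prim::Loop` node whose maximum trip count and initial condition are both `prim::Constant`s.

```python
import torch

def f(x):
    # Constant trip count, and the condition never changes inside the body:
    # the only prim::Loop form handled in this first stage.
    for i in range(5):
        x = x + i
    return x

scripted = torch.jit.script(f)
print(scripted.graph)  # the printed IR contains a prim::Loop node
```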
@@ -102,10 +102,7 @@ class TestConverter(TestCase):
                raise RuntimeError(f"Unrecognized mode for torch.jit: {opt}")

            converter = TS2EPConverter(ts_model, inp)
-           print(opt, converter.ts_graph)
-
            ep = converter.convert()
-           print(ep)
            ep_list.append(ep)

        for _ in range(num_iterations):
@@ -1014,7 +1011,6 @@ class TestConverter(TestCase):

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))
        for ep in ep_list:
-           print(ep.constants)
            self.assertEqual(len(ep.constants), 1)

    def test_aten_tensor_prim_dtype(self):
@@ -1327,6 +1323,45 @@ class TestConverter(TestCase):
        inp = (torch.randn(2, 3),)
        self._check_equal_ts_ep_converter(M1(), inp, ["script"])

+    def test_ts2ep_with_loop(self):
+        def func1(x, x_list: List[torch.Tensor]):
+            a, b, c = x, x, x
+            for i in range(1, 5, 2):
+                for k in range(5):
+                    a = a + a + k
+                    b = b + b - k
+                    x_list.append(x_list[k] + x_list[k + 1])
+                for k in range(5):
+                    b = b + b - k
+                    c = c + c * k
+                    x_list.append(x_list[k] + x_list[k + 1] - x_list[k + 2])
+            return x, x_list
+
+        def func2(x):
+            for i in range(x.size(0)):
+                x = x * x * i
+            return x
+
+        def func3(x):
+            while x.sum() < 10:
+                x += x.sin()
+            return x
+
+        inp = (
+            torch.tensor(1),
+            [torch.ones([2, 2]), torch.ones([2, 2]) * 2],
+        )
+        # Trace unrolls the loop.
+        self._check_equal_ts_ep_converter(func1, inp, ["script"])
+
+        # TODO: (2/N)
+        # Trace unrolls the loop.
+        # self._check_equal_ts_ep_converter(func2, inp, ["script"])
+
+        # TODO: (3/N)
+        # Trace unrolls the loop.
+        # self._check_equal_ts_ep_converter(func3, inp, ["script"])
+

if __name__ == "__main__":
    run_tests()
@@ -124,6 +124,26 @@ def list_append(container, element):
    return container + [element]


+def execute_subgraph_from_prim_loop(
+    subgraph, iter_idx, len_loop_local_arguments, *args, **kwargs
+):
+    """
+    subgraph: GraphModule from the sub-block.
+    iter_idx: The index of the current iteration.
+    len_loop_local_arguments: The number of loop-local arguments in args.
+    """
+
+    # Loop-local variables. The TS graph creates these as inputs because their
+    # values are updated inside the loop.
+    loop_local_args = args[:len_loop_local_arguments]
+    # Global variables that are not passed in as inputs to the loop sub-blocks
+    # but are used directly. Most of the time their values are not updated; the
+    # only exception is when some operations perform inplace updates on them.
+    global_args = args[len_loop_local_arguments:]
+    return subgraph(*global_args, iter_idx, *loop_local_args, **kwargs)
+
+
def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
    def pattern(im, dim, scale):
        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
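To make the calling convention concrete, here is a small illustration with hypothetical values (not part of the diff): the FX graph passes loop-local arguments first and global arguments last, and the helper reorders them so the sub-block graph sees the globals, then the iteration index, then the loop-local values.

```python
def toy_subgraph(g0, iter_idx, a, b):
    # Stands in for the GraphModule produced from the loop body.
    return g0 + iter_idx + a + b

# Two loop-local arguments (10, 20) followed by one global argument (100).
out = execute_subgraph_from_prim_loop(toy_subgraph, 3, 2, 10, 20, 100)
assert out == 100 + 3 + 10 + 20  # subgraph is called as toy_subgraph(100, 3, 10, 20)
```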
@@ -415,6 +435,13 @@ class TS2FXGraphConverter:
            lambda node: self._convert_standard_operators(node),
        )

+        # This stores the names of extra return values that do not appear in the
+        # original TS graph's outputs. We maintain it because some operations in a
+        # sub-block may perform inplace updates on a variable defined in the parent
+        # fx graph; after that sub-block executes, the variable in the parent fx
+        # graph also needs to be updated.
+        self.name_update_from_subblock_to_parent: Set[str] = set()
+
    def _is_get_attr_node(self, fqn):
        return (
            fqn in self.name_to_buffer
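As a sketch of why this bookkeeping is needed (illustrative, not part of the diff): a value defined outside the loop can be mutated inside the loop body, so the loop sub-block must hand its updated value back to the parent graph.

```python
from typing import List
import torch

def f(x: torch.Tensor, xs: List[torch.Tensor]):
    for i in range(3):
        # Mutates `xs`, which is defined outside the loop body. When converted,
        # the loop sub-block must also return the updated list so the parent
        # graph can rebind the name `xs` after each iteration.
        xs.append(x)
    return xs
```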
@@ -425,6 +452,57 @@ class TS2FXGraphConverter:
            )
        )

+    def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]):
+        subgraph_nodes, subgraph_converters = [], []
+        for block in node.blocks():
+            subgraph_converter = TS2FXGraphConverter(
+                block,
+                self.name_to_param,
+                self.name_to_buffer,
+                self.blocks_to_lifted_attrs,
+                {},
+                self.name_to_constant,
+            )
+            subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn
+
+            for block_arg in arguments:
+                normalized_block_arg_name = normalize_name(block_arg)
+                placeholder_node = subgraph_converter.fx_graph.placeholder(
+                    normalized_block_arg_name
+                )
+                subgraph_converter.name_to_node[block_arg] = placeholder_node
+
+            subgraph = subgraph_converter.convert()
+            subgraph_name = self.add_subgraph(subgraph)
+            subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))
+            subgraph_converters.append(subgraph_converter)
+        return subgraph_nodes, subgraph_converters
+
+    def _identify_inputs_as_arguments(self, entry):
+        """
+        Identify inputs from the innermost sub-block. This is needed
+        for nested sub-blocks when the input is hidden in the nested sub-block.
+        E.g., example IR where the input is hidden in a nested sub-block:
+        Graph[x.1]
+            %1 = ...
+            Block[]
+                Block[x.1]
+                    %2 = x.1 ...
+        """
+        arguments: Set[str] = set()
+        for block in entry.blocks():
+            for block_node in block.nodes():
+                for block_node_in in block_node.inputs():
+                    if (
+                        block_node_in.debugName() in self.name_to_node
+                        and block_node_in.debugName() not in self.name_to_attribute_fqn
+                    ):
+                        arguments.add(block_node_in.debugName())
+                arguments = arguments.union(
+                    self._identify_inputs_as_arguments(block_node)
+                )
+        return arguments
+
    def is_top_level_graph(self):
        return isinstance(self.ts_graph, torch._C.Graph)

@@ -438,13 +516,13 @@ class TS2FXGraphConverter:
        kwargs = {}
        for input, schema_arg in zip(node.inputs(), schema.arguments):
            if schema_arg.kwarg_only:
-               kwargs[schema_arg.name] = self.get_fx_value(input)
+               kwargs[schema_arg.name] = self.get_fx_value_by_ir_value(input)
            else:
-               args.append(self.get_fx_value(input))
+               args.append(self.get_fx_value_by_ir_value(input))

        return tuple(args), kwargs

-   def get_fx_value(self, value: torch._C.Value):
+   def get_fx_value_by_ir_value(self, value: torch._C.Value):
        value_name = value.debugName()

        if value_name in self.name_to_node:
@@ -457,6 +535,19 @@ class TS2FXGraphConverter:
        else:
            raise ValueError(f"Input {value_name} not found")

+    def get_fx_value_by_fqn(self, name):
+        if name in self.name_to_node:
+            fx_node = self.name_to_node[name]
+        elif name in self.name_to_constant:
+            fx_node = self.name_to_constant[name]
+        elif name in self.name_to_non_tensor_attribute_node:
+            fx_node = self.name_to_non_tensor_attribute_node[name]
+        elif name in self.name_to_non_tensor_attribute:
+            fx_node = self.name_to_non_tensor_attribute[name]
+        else:
+            raise ValueError(f"Attribute {name} not found")
+        return fx_node
+
    def convert(self) -> torch.fx.GraphModule:
        self.convert_graph_inputs()

@@ -585,13 +676,17 @@ class TS2FXGraphConverter:
                "This makes the converter non-functional: the result depends on the order of the append nodes being converted!"
            )

-       args = tuple(self.get_fx_value(inp) for inp in node.inputs())
+       args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs())
        fx_node = self.fx_graph.call_function(list_append, args)
        self.name_to_node[node.output().debugName()] = fx_node

        # inplace mutate arg[0], which is the python list
        self.name_to_node[node.inputsAt(0).debugName()] = fx_node

+       # Variables that need to be updated in the parent module.
+       if not self.is_top_level_graph() and args[0].op == "placeholder":
+           self.name_update_from_subblock_to_parent.add(node.inputsAt(0).debugName())
+
    def convert_prim_Constant(self, node: torch._C.Node):
        name = node.output().debugName()

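For context (an illustration, not part of the diff): `list_append` is the functional stand-in for `aten::append`. Instead of mutating its input it returns a new list, and the converter rebinds the list's debug name to the new node, which is why the result can depend on the order in which append nodes are converted.

```python
def list_append(container, element):
    # Same shape as the helper defined earlier in this file.
    return container + [element]

xs = [1, 2]
ys = list_append(xs, 3)
assert xs == [1, 2]      # the original list is untouched
assert ys == [1, 2, 3]   # the converter rebinds the name to this new value
```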
@@ -621,7 +716,9 @@ class TS2FXGraphConverter:
        self.name_to_constant[name] = value

    def convert_prim_CallMethod(self, node: torch._C.Node):
-       inp_list = [self.get_fx_value(inp) for inp in node.inputs()]  # noqa: C416
+       inp_list = [
+           self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
+       ]  # noqa: C416
        fx_node = self.fx_graph.call_method(
            node.s("name"),
            tuple(inp_list),
@@ -668,7 +765,7 @@ class TS2FXGraphConverter:
    def convert_prim_SetAttr(self, node: torch._C.Node):
        attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
        attr_value = tuple(node.inputs())[1]
-       ts_graph_tensor_input = self.get_fx_value(attr_value)
+       ts_graph_tensor_input = self.get_fx_value_by_ir_value(attr_value)
        if self._is_get_attr_node(attr_fqn):
            fx_attr_node = self.fx_graph.get_attr(attr_fqn)
            self.fx_graph.call_function(

@@ -707,7 +804,7 @@ class TS2FXGraphConverter:
    def _convert_prim_iterator(self, node: torch._C.Node):
        output_list = []
        for inp in node.inputs():
-           output_list.append(self.get_fx_value(inp))
+           output_list.append(self.get_fx_value_by_ir_value(inp))

        output_name = node.output().debugName()
        self.name_to_node[output_name] = output_list

@@ -719,9 +816,9 @@ class TS2FXGraphConverter:
            # We assume key value are stored in pair in the DictConstruct.
            # The first element is the key and the following is the value.
            if i % 2 == 0:
-               k = self.get_fx_value(inp)
+               k = self.get_fx_value_by_ir_value(inp)
            else:
-               v = self.get_fx_value(inp)
+               v = self.get_fx_value_by_ir_value(inp)
                assert (
                    k is not None and v is not None
                ), "DictConstruct has an empty key value pair."

@@ -745,14 +842,14 @@ class TS2FXGraphConverter:
        # Single input and multiple outputs for unpacking.
        for i, outp in enumerate(node.outputs()):
            outp_name = outp.debugName()
-           inp = self.get_fx_value(node.input())
+           inp = self.get_fx_value_by_ir_value(node.input())
            fx_node = self.fx_graph.call_function(operator.getitem, (inp, i))
            self.name_to_node[outp_name] = fx_node

    def convert_aten_Int(self, node: torch._C.Node):
        # converts aten::Int as aten._to_copy + aten::_local_scalar_dense
        target = torch.ops.aten._to_copy.default
-       args = tuple(self.get_fx_value(input) for input in node.inputs())
+       args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
        to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32})

        fx_node = self.fx_graph.call_function(

@@ -773,7 +870,7 @@ class TS2FXGraphConverter:
        # For both of those APIs, torch.jit.trace implicitly sets the output tensor type
        # to be LongTensor.
        target = torch.ops.aten.scalar_tensor
-       args = tuple(self.get_fx_value(input) for input in node.inputs())
+       args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())

        fx_node = self.fx_graph.call_function(target, args, {"dtype": torch.long})
        output_name = node.output().debugName()

@@ -829,7 +926,7 @@ class TS2FXGraphConverter:

    def convert_aten___getitem__(self, node: torch._C.Node):
        input_container, index = tuple(
-           self.get_fx_value(input) for input in node.inputs()
+           self.get_fx_value_by_ir_value(input) for input in node.inputs()
        )
        fx_node = self.fx_graph.call_function(
            operator.getitem, (input_container, index)
@@ -895,6 +992,111 @@ class TS2FXGraphConverter:
        else:
            self.convert_call_function_op(node)

+   def _check_prim_loop_support(self, node):
+       inputs = list(node.inputs())
+
+       # TODO: (1/N) stage.
+       if inputs[0].debugName() not in self.name_to_constant:
+           raise RuntimeError(
+               "prim::Loop currently cannot run with dynamic value of number of iterations."
+           )
+
+       # Make sure the condition is not updated in the subblock.
+       subblock = next(node.blocks())
+       condition_output_name = next(subblock.outputs()).debugName()
+       for node in subblock.nodes():
+           if (
+               node.outputsSize() == 1
+               and node.output().debugName() == condition_output_name
+           ):
+               raise RuntimeError(
+                   "prim::Loop currently cannot run with dynamic value of condition."
+               )
+           if node.outputsSize() >= 2:
+               for outp in node.outputs():
+                   if outp.debugName() == condition_output_name:
+                       raise RuntimeError(
+                           "prim::Loop currently cannot run with dynamic value of condition."
+                       )
+
+   def convert_prim_Loop(self, node: torch._C.Node):
+       inputs = list(node.inputs())
+       self._check_prim_loop_support(node)
+
+       num_iterations = self.get_fx_value_by_ir_value(inputs[0])
+
+       # Find inputs.
+       loop_local_arguments = [inp.debugName() for inp in inputs[2:]]
+
+       global_arguments = self._identify_inputs_as_arguments(node)
+
+       # Lift parameters as inputs.
+       for block in node.blocks():
+           global_arguments = global_arguments.union(
+               self.blocks_to_lifted_attrs[block]
+           )
+
+       global_arguments = list(global_arguments)
+
+       subgraph_nodes, subgraph_converters = self._convert_block_to_subgraph(
+           node, global_arguments
+       )
+
+       assert len(subgraph_nodes) == 1
+       subgraph_converter = subgraph_converters[0]
+       if not self.is_top_level_graph():
+           self.name_update_from_subblock_to_parent = (
+               self.name_update_from_subblock_to_parent.union(
+                   subgraph_converter.name_update_from_subblock_to_parent
+               )
+           )
+
+       fx_block_args = [
+           self.get_fx_value_by_fqn(name)
+           for name in loop_local_arguments + global_arguments
+       ]
+       for iter_idx in range(num_iterations):
+           loop_node = self.fx_graph.call_function(
+               execute_subgraph_from_prim_loop,
+               # See execute_subgraph_from_prim_loop for the expected argument order.
+               (
+                   subgraph_nodes[0],
+                   iter_idx,
+                   len(loop_local_arguments),
+                   *fx_block_args,
+               ),
+               {},
+           )
+
+           # Update the value of loop local variables.
+           if node.outputsSize() >= 1:
+               for i, outp in enumerate(node.outputs()):
+                   output_name = outp.debugName()
+                   self.name_to_node[output_name] = self.fx_graph.call_function(
+                       operator.getitem,
+                       (
+                           loop_node,
+                           i + 1,
+                       ),  # + 1 because the 0th element is the condition.
+                   )
+                   fx_block_args[i] = self.name_to_node[output_name]
+
+           # Update the value of global variables, whose values are modified inplace.
+           for i, name in enumerate(
+               subgraph_converter.name_update_from_subblock_to_parent
+           ):
+               self.name_to_node[name] = self.fx_graph.call_function(
+                   operator.getitem,
+                   (
+                       loop_node,
+                       i + node.outputsSize() + 1,
+                   ),  # + 1 because the 0th element is the condition.
+               )
+               global_argument_index = global_arguments.index(name)
+               fx_block_args[
+                   i + node.outputsSize() + global_argument_index
+               ] = self.name_to_node[name]
+
    def _check_set_attr_in_if_block(self, if_node: torch._C.Node):
        for block in if_node.blocks():
            for node in block.nodes():
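To summarize what the loop conversion produces (an illustrative sketch, not generated converter output, assuming the `execute_subgraph_from_prim_loop` helper from the diff above is in scope): with a constant trip count the loop body subgraph is invoked once per iteration, and each call's outputs are threaded into the next call's loop-local arguments.

```python
def body(iter_idx, x):
    # Stand-in for the GraphModule converted from the loop body: it returns
    # (loop condition, updated loop-local value).
    return (True, x + iter_idx)

x = 0
for iter_idx in range(3):  # the trip count is a constant in this stage
    out = execute_subgraph_from_prim_loop(body, iter_idx, 1, x)
    x = out[1]  # index 0 is the condition; loop-local values start at index 1
assert x == 0 + 1 + 2
```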
@@ -910,82 +1112,21 @@ class TS2FXGraphConverter:

        inputs = list(node.inputs())
        assert len(inputs) == 1
-       predicate = self.get_fx_value(inputs[0])
-
-       def _identify_inputs_as_arguments(entry):
-           """
-           Identify inputs from the innermost sub-block. This is needed
-           for nested sub-blocks when the input is hidden in the nested sub-block.
-           E.g., example IR of input is hidden in the nested sub-block.
-           Graph[x.1]
-               %1 = ...
-               Block[]
-                   Block[x.1]
-                       %2 = x.1 ...
-           """
-           arguments: Set[str] = set()
-           for block in entry.blocks():
-               for block_node in block.nodes():
-                   for block_node_in in block_node.inputs():
-                       if (
-                           block_node_in.debugName() in self.name_to_node
-                           and block_node_in.debugName()
-                           not in self.name_to_attribute_fqn
-                       ):
-                           arguments.add(block_node_in.debugName())
-                   arguments = arguments.union(
-                       _identify_inputs_as_arguments(block_node)
-                   )
-           return arguments
+       predicate = self.get_fx_value_by_ir_value(inputs[0])

        # Find inputs.
-       arguments = _identify_inputs_as_arguments(node)
+       arguments = self._identify_inputs_as_arguments(node)

        # Lift parameters as inputs.
        for block in node.blocks():
            arguments = arguments.union(self.blocks_to_lifted_attrs[block])

        arguments = list(arguments)

-       # Convert blocks to subgraphs
-       subgraph_nodes = []
-       for block in node.blocks():
-           subgraph_converter = TS2FXGraphConverter(
-               block,
-               self.name_to_param,
-               self.name_to_buffer,
-               self.blocks_to_lifted_attrs,
-               {},
-               self.name_to_constant,
-           )
-           subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn
-
-           for block_arg in arguments:
-               normalized_block_arg_name = normalize_name(block_arg)
-               placeholder_node = subgraph_converter.fx_graph.placeholder(
-                   normalized_block_arg_name
-               )
-               subgraph_converter.name_to_node[block_arg] = placeholder_node
-
-           subgraph = subgraph_converter.convert()
-           subgraph_name = self.add_subgraph(subgraph)
-           subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))
+       subgraph_nodes, _ = self._convert_block_to_subgraph(node, arguments)

        assert len(subgraph_nodes) == 2

-       fx_block_args = []
-       for arg_name in arguments:
-           if arg_name in self.name_to_node:
-               arg_node = self.name_to_node[arg_name]
-               fx_block_args.append(arg_node)
-           elif arg_name in self.name_to_non_tensor_attribute_node:
-               arg_node = self.name_to_non_tensor_attribute_node[arg_name]
-               fx_block_args.append(arg_node)
-           elif arg_name in self.name_to_non_tensor_attribute:
-               arg_value = self.name_to_non_tensor_attribute[arg_name]
-               fx_block_args.append(arg_value)
-           else:
-               raise ValueError(f"Attribute {arg_name} not found")
+       fx_block_args = [self.get_fx_value_by_fqn(name) for name in arguments]

        args = (
            predicate,
@@ -1036,14 +1177,14 @@ class TS2FXGraphConverter:
        # currently, _record_function_enter_new and _record_function_exit are
        # discarded during `retrace_as_exported_program`.
        target = torch.ops.profiler._record_function_exit
-       args = tuple(self.get_fx_value(input) for input in node.inputs())
+       args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
        self.fx_graph.call_function(target, args)

    def convert_prim_tolist(self, node: torch._C.Node):
        # prim::tolist cannot be supported by `_convert_standard_operators`
        # since it requires call_method instead of call_function.
        target = "tolist"
-       args = (self.get_fx_value(next(node.inputs())),)
+       args = (self.get_fx_value_by_ir_value(next(node.inputs())),)
        fx_node = self.fx_graph.call_method(target, args)
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

@@ -1058,7 +1199,7 @@ class TS2FXGraphConverter:

    def _convert_standard_operators(self, node: torch._C.Node):
        target = kind_to_standard_operators[node.kind()]
-       args = tuple(self.get_fx_value(input) for input in node.inputs())
+       args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
        fx_node = self.fx_graph.call_function(target, args)
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node
@@ -1084,8 +1225,10 @@ class TS2FXGraphConverter:

    def convert_graph_outputs(self):
        args = []
-       for graph_output in self.ts_graph.outputs():
-           output_name = graph_output.debugName()
+       outp_name_list = [outp.debugName() for outp in self.ts_graph.outputs()] + list(
+           self.name_update_from_subblock_to_parent
+       )
+       for output_name in outp_name_list:
            if output_name in self.name_to_node:
                fx_node = self.name_to_node[output_name]
                # TODO: Revisit this later after HigherOrderOp design changes.

@@ -1118,7 +1261,10 @@ class TS2FXGraphConverter:
            else:
                raise ValueError(f"Output {output_name} not found")

-       if len(args) == 1:
+       if len(args) == 0:
+           # Sub-block of prim::If can have zero output.
+           self.fx_graph.output([])
+       elif len(args) == 1:
            self.fx_graph.output(
                args[0]
            )  # Get rid of an extra list wrapped around final output.
@@ -1127,8 +1273,8 @@ class TS2FXGraphConverter:
                args
            )  # For prim::Loop and prim::If with multiple outputs.
        else:
-           # Sub-block of prim::If can have zero output.
-           self.fx_graph.output([])
+           # Sub-block of prim::Loop can have multiple outputs.
+           self.fx_graph.output(args)


class ExplainTS2FXGraphConverter(TS2FXGraphConverter):