diff --git a/test/export/test_converter.py b/test/export/test_converter.py
index accd25a86ac..eb6d997b2c5 100644
--- a/test/export/test_converter.py
+++ b/test/export/test_converter.py
@@ -102,10 +102,7 @@ class TestConverter(TestCase):
                 raise RuntimeError(f"Unrecognized mode for torch.jit: {opt}")

             converter = TS2EPConverter(ts_model, inp)
-            print(opt, converter.ts_graph)
-
             ep = converter.convert()
-            print(ep)
             ep_list.append(ep)

         for _ in range(num_iterations):
@@ -1014,7 +1011,6 @@ class TestConverter(TestCase):
         ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))

         for ep in ep_list:
-            print(ep.constants)
             self.assertEqual(len(ep.constants), 1)

     def test_aten_tensor_prim_dtype(self):
@@ -1327,6 +1323,45 @@ class TestConverter(TestCase):
         inp = (torch.randn(2, 3),)
         self._check_equal_ts_ep_converter(M1(), inp, ["script"])

+    def test_ts2ep_with_loop(self):
+        def func1(x, x_list: List[torch.Tensor]):
+            a, b, c = x, x, x
+            for i in range(1, 5, 2):
+                for k in range(5):
+                    a = a + a + k
+                    b = b + b - k
+                    x_list.append(x_list[k] + x_list[k + 1])
+                for k in range(5):
+                    b = b + b - k
+                    c = c + c * k
+                    x_list.append(x_list[k] + x_list[k + 1] - x_list[k + 2])
+            return x, x_list
+
+        def func2(x):
+            for i in range(x.size(0)):
+                x = x * x * i
+            return x
+
+        def func3(x):
+            while x.sum() < 10:
+                x += x.sin()
+            return x
+
+        inp = (
+            torch.tensor(1),
+            [torch.ones([2, 2]), torch.ones([2, 2]) * 2],
+        )
+        # Trace unrolls the loop.
+        self._check_equal_ts_ep_converter(func1, inp, ["script"])
+
+        # TODO: (2/N)
+        # Trace unrolls the loop.
+        # self._check_equal_ts_ep_converter(func2, inp, ["script"])
+
+        # TODO: (3/N)
+        # Trace unrolls the loop.
+        # self._check_equal_ts_ep_converter(func3, inp, ["script"])
+

 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_export/converter.py b/torch/_export/converter.py
index 0cacbcb683d..b9f1e205279 100644
--- a/torch/_export/converter.py
+++ b/torch/_export/converter.py
@@ -124,6 +124,26 @@ def list_append(container, element):
     return container + [element]


+def execute_subgraph_from_prim_loop(
+    subgraph, iter_idx, len_loop_local_arguments, *args, **kwargs
+):
+    """
+    subgraph: GraphModule from the sub-block.
+    iter_idx: The index of the current iteration.
+    len_loop_local_arguments: The number of loop-local arguments in args.
+    """
+
+    # Loop-local variables. The TS graph creates these as inputs because their
+    # values are updated inside the loop.
+    loop_local_args = args[:len_loop_local_arguments]
+    # Global variables that are not passed in as inputs to the loop sub-blocks
+    # but are directly used. Most of the time their values are not updated;
+    # the only exception is when some operations perform inplace updates on
+    # them.
+    global_args = args[len_loop_local_arguments:]
+    return subgraph(*global_args, iter_idx, *loop_local_args, **kwargs)
+
+
 def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
     def pattern(im, dim, scale):
         sym_size_int = torch.ops.aten.sym_size.int(im, dim)
@@ -415,6 +435,13 @@ class TS2FXGraphConverter:
             lambda node: self._convert_standard_operators(node),
         )

+        # This stores the names of results that do not appear in the original TS
+        # graph's outputs. We maintain this because some operations in a sub-block
+        # may perform inplace updates on a variable defined in the parent fx graph.
+        # After that sub-block is executed, the variable defined in the parent fx
+        # graph also needs to be updated.
+        self.name_update_from_subblock_to_parent: Set[str] = set()
+
     def _is_get_attr_node(self, fqn):
         return (
             fqn in self.name_to_buffer
@@ -425,6 +452,57 @@ class TS2FXGraphConverter:
             )
         )

+    def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]):
+        subgraph_nodes, subgraph_converters = [], []
+        for block in node.blocks():
+            subgraph_converter = TS2FXGraphConverter(
+                block,
+                self.name_to_param,
+                self.name_to_buffer,
+                self.blocks_to_lifted_attrs,
+                {},
+                self.name_to_constant,
+            )
+            subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn
+
+            for block_arg in arguments:
+                normalized_block_arg_name = normalize_name(block_arg)
+                placeholder_node = subgraph_converter.fx_graph.placeholder(
+                    normalized_block_arg_name
+                )
+                subgraph_converter.name_to_node[block_arg] = placeholder_node
+
+            subgraph = subgraph_converter.convert()
+            subgraph_name = self.add_subgraph(subgraph)
+            subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))
+            subgraph_converters.append(subgraph_converter)
+        return subgraph_nodes, subgraph_converters
+
+    def _identify_inputs_as_arguments(self, entry):
+        """
+        Identify inputs from the innermost sub-block. This is needed
+        for nested sub-blocks when the input is hidden in the nested sub-block.
+        E.g., in the example IR below, the input is hidden in the nested sub-block.
+            Graph[x.1]
+                %1 = ...
+                Block[]
+                    Block[x.1]
+                        %2 = x.1 ...
+        """
+        arguments: Set[str] = set()
+        for block in entry.blocks():
+            for block_node in block.nodes():
+                for block_node_in in block_node.inputs():
+                    if (
+                        block_node_in.debugName() in self.name_to_node
+                        and block_node_in.debugName() not in self.name_to_attribute_fqn
+                    ):
+                        arguments.add(block_node_in.debugName())
+                arguments = arguments.union(
+                    self._identify_inputs_as_arguments(block_node)
+                )
+        return arguments
+
     def is_top_level_graph(self):
         return isinstance(self.ts_graph, torch._C.Graph)

@@ -438,13 +516,13 @@ class TS2FXGraphConverter:
         kwargs = {}
         for input, schema_arg in zip(node.inputs(), schema.arguments):
             if schema_arg.kwarg_only:
-                kwargs[schema_arg.name] = self.get_fx_value(input)
+                kwargs[schema_arg.name] = self.get_fx_value_by_ir_value(input)
             else:
-                args.append(self.get_fx_value(input))
+                args.append(self.get_fx_value_by_ir_value(input))

         return tuple(args), kwargs

-    def get_fx_value(self, value: torch._C.Value):
+    def get_fx_value_by_ir_value(self, value: torch._C.Value):
         value_name = value.debugName()

         if value_name in self.name_to_node:
@@ -457,6 +535,19 @@ class TS2FXGraphConverter:
         else:
             raise ValueError(f"Input {value_name} not found")

+    def get_fx_value_by_fqn(self, name):
+        if name in self.name_to_node:
+            fx_node = self.name_to_node[name]
+        elif name in self.name_to_constant:
+            fx_node = self.name_to_constant[name]
+        elif name in self.name_to_non_tensor_attribute_node:
+            fx_node = self.name_to_non_tensor_attribute_node[name]
+        elif name in self.name_to_non_tensor_attribute:
+            fx_node = self.name_to_non_tensor_attribute[name]
+        else:
+            raise ValueError(f"Attribute {name} not found")
+        return fx_node
+
     def convert(self) -> torch.fx.GraphModule:
         self.convert_graph_inputs()

@@ -585,13 +676,17 @@ class TS2FXGraphConverter:
                 "This makes the converter non-functional: the result depends on the order of the append nodes being converter!"
             )

-        args = tuple(self.get_fx_value(inp) for inp in node.inputs())
+        args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs())
         fx_node = self.fx_graph.call_function(list_append, args)
         self.name_to_node[node.output().debugName()] = fx_node

         # inplace mutate arg[0], which is the python list
         self.name_to_node[node.inputsAt(0).debugName()] = fx_node

+        # Record variables that need to be propagated back to the parent module.
+        if not self.is_top_level_graph() and args[0].op == "placeholder":
+            self.name_update_from_subblock_to_parent.add(node.inputsAt(0).debugName())
+
     def convert_prim_Constant(self, node: torch._C.Node):
         name = node.output().debugName()

@@ -621,7 +716,9 @@ class TS2FXGraphConverter:
         self.name_to_constant[name] = value

     def convert_prim_CallMethod(self, node: torch._C.Node):
-        inp_list = [self.get_fx_value(inp) for inp in node.inputs()]  # noqa: C416
+        inp_list = [
+            self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
+        ]  # noqa: C416
         fx_node = self.fx_graph.call_method(
             node.s("name"),
             tuple(inp_list),
@@ -668,7 +765,7 @@ class TS2FXGraphConverter:
     def convert_prim_SetAttr(self, node: torch._C.Node):
         attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
         attr_value = tuple(node.inputs())[1]
-        ts_graph_tensor_input = self.get_fx_value(attr_value)
+        ts_graph_tensor_input = self.get_fx_value_by_ir_value(attr_value)
         if self._is_get_attr_node(attr_fqn):
             fx_attr_node = self.fx_graph.get_attr(attr_fqn)
             self.fx_graph.call_function(
@@ -707,7 +804,7 @@ class TS2FXGraphConverter:
     def _convert_prim_iterator(self, node: torch._C.Node):
         output_list = []
         for inp in node.inputs():
-            output_list.append(self.get_fx_value(inp))
+            output_list.append(self.get_fx_value_by_ir_value(inp))

         output_name = node.output().debugName()
         self.name_to_node[output_name] = output_list
@@ -719,9 +816,9 @@ class TS2FXGraphConverter:
             # We assume key value are stored in pair in the DictConstruct.
             # The first element is the key and the following is the value.
             if i % 2 == 0:
-                k = self.get_fx_value(inp)
+                k = self.get_fx_value_by_ir_value(inp)
             else:
-                v = self.get_fx_value(inp)
+                v = self.get_fx_value_by_ir_value(inp)
             assert (
                 k is not None and v is not None
             ), "DictConstruct has an empty key value pair."
@@ -745,14 +842,14 @@ class TS2FXGraphConverter:
         # Single input and multiple outputs for unpacking.
         for i, outp in enumerate(node.outputs()):
             outp_name = outp.debugName()
-            inp = self.get_fx_value(node.input())
+            inp = self.get_fx_value_by_ir_value(node.input())
             fx_node = self.fx_graph.call_function(operator.getitem, (inp, i))
             self.name_to_node[outp_name] = fx_node

     def convert_aten_Int(self, node: torch._C.Node):
         # converts aten::Int as aten._to_copy + aten::_local_scalar_dense
         target = torch.ops.aten._to_copy.default
-        args = tuple(self.get_fx_value(input) for input in node.inputs())
+        args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
         to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32})

         fx_node = self.fx_graph.call_function(
@@ -773,7 +870,7 @@ class TS2FXGraphConverter:
         # For both of those APIs, torch.jit.trace implicitly sets the output tensor type
         # to be LongTensor.
         target = torch.ops.aten.scalar_tensor
-        args = tuple(self.get_fx_value(input) for input in node.inputs())
+        args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
         fx_node = self.fx_graph.call_function(target, args, {"dtype": torch.long})

         output_name = node.output().debugName()
@@ -829,7 +926,7 @@ class TS2FXGraphConverter:

     def convert_aten___getitem__(self, node: torch._C.Node):
         input_container, index = tuple(
-            self.get_fx_value(input) for input in node.inputs()
+            self.get_fx_value_by_ir_value(input) for input in node.inputs()
         )
         fx_node = self.fx_graph.call_function(
             operator.getitem, (input_container, index)
@@ -895,6 +992,111 @@ class TS2FXGraphConverter:
         else:
             self.convert_call_function_op(node)

+    def _check_prim_loop_support(self, node):
+        inputs = list(node.inputs())
+
+        # TODO: (1/N) stage.
+        if inputs[0].debugName() not in self.name_to_constant:
+            raise RuntimeError(
+                "prim::Loop currently cannot run with dynamic value of number of iterations."
+            )
+
+        # Make sure the condition is not updated in the subblock.
+        subblock = next(node.blocks())
+        condition_output_name = next(subblock.outputs()).debugName()
+        for node in subblock.nodes():
+            if (
+                node.outputsSize() == 1
+                and node.output().debugName() == condition_output_name
+            ):
+                raise RuntimeError(
+                    "prim::Loop currently cannot run with dynamic value of condition."
+                )
+            if node.outputsSize() >= 2:
+                for outp in node.outputs():
+                    if outp.debugName() == condition_output_name:
+                        raise RuntimeError(
+                            "prim::Loop currently cannot run with dynamic value of condition."
+                        )
+
+    def convert_prim_Loop(self, node: torch._C.Node):
+        inputs = list(node.inputs())
+        self._check_prim_loop_support(node)
+
+        num_iterations = self.get_fx_value_by_ir_value(inputs[0])
+
+        # Find inputs.
+        loop_local_arguments = [inp.debugName() for inp in inputs[2:]]
+
+        global_arguments = self._identify_inputs_as_arguments(node)
+
+        # Lift parameters as inputs.
+        for block in node.blocks():
+            global_arguments = global_arguments.union(
+                self.blocks_to_lifted_attrs[block]
+            )
+
+        global_arguments = list(global_arguments)
+
+        subgraph_nodes, subgraph_converters = self._convert_block_to_subgraph(
+            node, global_arguments
+        )
+
+        assert len(subgraph_nodes) == 1
+        subgraph_converter = subgraph_converters[0]
+        if not self.is_top_level_graph():
+            self.name_update_from_subblock_to_parent = (
+                self.name_update_from_subblock_to_parent.union(
+                    subgraph_converter.name_update_from_subblock_to_parent
+                )
+            )
+
+        fx_block_args = [
+            self.get_fx_value_by_fqn(name)
+            for name in loop_local_arguments + global_arguments
+        ]
+        for iter_idx in range(num_iterations):
+            loop_node = self.fx_graph.call_function(
+                execute_subgraph_from_prim_loop,
+                # See execute_subgraph_from_prim_loop for the expected argument order.
+                (
+                    subgraph_nodes[0],
+                    iter_idx,
+                    len(loop_local_arguments),
+                    *fx_block_args,
+                ),
+                {},
+            )
+
+            # Update the values of loop-local variables.
+            if node.outputsSize() >= 1:
+                for i, outp in enumerate(node.outputs()):
+                    output_name = outp.debugName()
+                    self.name_to_node[output_name] = self.fx_graph.call_function(
+                        operator.getitem,
+                        (
+                            loop_node,
+                            i + 1,
+                        ),  # + 1 because the 0th element is the condition.
+                    )
+                    fx_block_args[i] = self.name_to_node[output_name]
+
+            # Update the values of global variables that are modified inplace.
+            for i, name in enumerate(
+                subgraph_converter.name_update_from_subblock_to_parent
+            ):
+                self.name_to_node[name] = self.fx_graph.call_function(
+                    operator.getitem,
+                    (
+                        loop_node,
+                        i + node.outputsSize() + 1,
+                    ),  # + 1 because the 0th element is the condition.
+                )
+                global_argument_index = global_arguments.index(name)
+                fx_block_args[
+                    i + node.outputsSize() + global_argument_index
+                ] = self.name_to_node[name]
+
     def _check_set_attr_in_if_block(self, if_node: torch._C.Node):
         for block in if_node.blocks():
             for node in block.nodes():
@@ -910,82 +1112,21 @@ class TS2FXGraphConverter:

         inputs = list(node.inputs())
         assert len(inputs) == 1
-        predicate = self.get_fx_value(inputs[0])
-
-        def _identify_inputs_as_arguments(entry):
-            """
-            Identify inputs from the innermost sub-block. This is needed
-            for nested sub-blocks when the input is hidden in the nested sub-block.
-            E.g., example IR of input is hidden in the nested sub-block.
-                Graph[x.1]
-                    %1 = ...
-                    Block[]
-                        Block[x.1]
-                            %2 = x.1 ...
-            """
-            arguments: Set[str] = set()
-            for block in entry.blocks():
-                for block_node in block.nodes():
-                    for block_node_in in block_node.inputs():
-                        if (
-                            block_node_in.debugName() in self.name_to_node
-                            and block_node_in.debugName()
-                            not in self.name_to_attribute_fqn
-                        ):
-                            arguments.add(block_node_in.debugName())
-                    arguments = arguments.union(
-                        _identify_inputs_as_arguments(block_node)
-                    )
-            return arguments
+        predicate = self.get_fx_value_by_ir_value(inputs[0])

         # Find inputs.
-        arguments = _identify_inputs_as_arguments(node)
+        arguments = self._identify_inputs_as_arguments(node)

         # Lift parameters as inputs.
         for block in node.blocks():
             arguments = arguments.union(self.blocks_to_lifted_attrs[block])

         arguments = list(arguments)
-
-        # Convert blocks to subgraphs
-        subgraph_nodes = []
-        for block in node.blocks():
-            subgraph_converter = TS2FXGraphConverter(
-                block,
-                self.name_to_param,
-                self.name_to_buffer,
-                self.blocks_to_lifted_attrs,
-                {},
-                self.name_to_constant,
-            )
-            subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn
-
-            for block_arg in arguments:
-                normalized_block_arg_name = normalize_name(block_arg)
-                placeholder_node = subgraph_converter.fx_graph.placeholder(
-                    normalized_block_arg_name
-                )
-                subgraph_converter.name_to_node[block_arg] = placeholder_node
-
-            subgraph = subgraph_converter.convert()
-            subgraph_name = self.add_subgraph(subgraph)
-            subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))
+        subgraph_nodes, _ = self._convert_block_to_subgraph(node, arguments)

         assert len(subgraph_nodes) == 2

-        fx_block_args = []
-        for arg_name in arguments:
-            if arg_name in self.name_to_node:
-                arg_node = self.name_to_node[arg_name]
-                fx_block_args.append(arg_node)
-            elif arg_name in self.name_to_non_tensor_attribute_node:
-                arg_node = self.name_to_non_tensor_attribute_node[arg_name]
-                fx_block_args.append(arg_node)
-            elif arg_name in self.name_to_non_tensor_attribute:
-                arg_value = self.name_to_non_tensor_attribute[arg_name]
-                fx_block_args.append(arg_value)
-            else:
-                raise ValueError(f"Attribute {arg_name} not found")
+        fx_block_args = [self.get_fx_value_by_fqn(name) for name in arguments]

         args = (
             predicate,
@@ -1036,14 +1177,14 @@ class TS2FXGraphConverter:
         # currently, _record_function_enter_new and _record_function_exit are
         # discarded during `retrace_as_exported_program`.
         target = torch.ops.profiler._record_function_exit
-        args = tuple(self.get_fx_value(input) for input in node.inputs())
+        args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
         self.fx_graph.call_function(target, args)

     def convert_prim_tolist(self, node: torch._C.Node):
         # prim::tolist cannot be supported by `_convert_standard_operators`
         # since it requires call_method instead of call_function.
         target = "tolist"
-        args = (self.get_fx_value(next(node.inputs())),)
+        args = (self.get_fx_value_by_ir_value(next(node.inputs())),)
         fx_node = self.fx_graph.call_method(target, args)
         output_name = node.output().debugName()
         self.name_to_node[output_name] = fx_node
@@ -1058,7 +1199,7 @@ class TS2FXGraphConverter:

     def _convert_standard_operators(self, node: torch._C.Node):
         target = kind_to_standard_operators[node.kind()]
-        args = tuple(self.get_fx_value(input) for input in node.inputs())
+        args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
         fx_node = self.fx_graph.call_function(target, args)
         output_name = node.output().debugName()
         self.name_to_node[output_name] = fx_node
@@ -1084,8 +1225,10 @@ class TS2FXGraphConverter:

     def convert_graph_outputs(self):
         args = []
-        for graph_output in self.ts_graph.outputs():
-            output_name = graph_output.debugName()
+        outp_name_list = [outp.debugName() for outp in self.ts_graph.outputs()] + list(
+            self.name_update_from_subblock_to_parent
+        )
+        for output_name in outp_name_list:
             if output_name in self.name_to_node:
                 fx_node = self.name_to_node[output_name]
                 # TODO: Revisit this later after HigherOrderOp design changes.
@@ -1118,7 +1261,10 @@ class TS2FXGraphConverter:
             else:
                 raise ValueError(f"Output {output_name} not found")

-        if len(args) == 1:
+        if len(args) == 0:
+            # Sub-block of prim::If can have zero output.
+            self.fx_graph.output([])
+        elif len(args) == 1:
             self.fx_graph.output(
                 args[0]
             )  # Get rid of an extra list wrapped around final output.
@@ -1127,8 +1273,8 @@ class TS2FXGraphConverter:
                 args
             )  # For prim::Loop and prim::If with multiple outputs.
         else:
-            # Sub-block of prim::If can have zero output.
-            self.fx_graph.output([])
+            # Sub-block of prim::Loop can have multiple outputs.
+            self.fx_graph.output(args)


 class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
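Illustrative usage (a minimal sketch, not part of the patch): the constant-trip-count prim::Loop path added above can be exercised the same way test_ts2ep_with_loop does, by scripting a function whose loop survives as prim::Loop and converting it with TS2EPConverter. The import path and call sequence below mirror the test file and are assumptions if your checkout differs.

import torch
from torch._export.converter import TS2EPConverter

def fn(x):
    # Constant trip count, so the (1/N) support added in this patch applies;
    # a data-dependent count such as range(x.size(0)) is still a TODO (2/N).
    for i in range(4):
        x = x + i
    return x

ts_model = torch.jit.script(fn)  # scripting keeps the loop as prim::Loop
ep = TS2EPConverter(ts_model, (torch.ones(2, 2),)).convert()
print(ep)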