[ts-migration][1/N]: Add prim::Loop for constant number of iterations and condition (#131418)

#### Description
This PR adds prim::Loop support for the simplest case, where the number of iterations is constant and the loop termination condition is also constant.

[PR by stages](https://docs.google.com/document/d/1q6OprW3HBHbYPwEyE_DikBn-uzmhnN284Cmen_CnlhI/edit?usp=sharing)
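For illustration, a minimal sketch of the kind of loop this stage handles and how it is converted, modeled on the new `test_ts2ep_with_loop` test (the `torch._export.converter` import path is taken from the test file and may differ across versions):

```python
from typing import List

import torch
from torch._export.converter import TS2EPConverter


def func(x: torch.Tensor, x_list: List[torch.Tensor]):
    # Constant trip count and (implicitly) constant loop condition:
    # the only case supported in this stage of the migration.
    for k in range(5):
        x = x + k                                  # loop-local value carried across iterations
        x_list.append(x_list[k] + x_list[k + 1])   # in-place update seen by the parent graph
    return x, x_list


inp = (torch.tensor(1), [torch.ones(2, 2), torch.ones(2, 2) * 2])
ts_model = torch.jit.script(func)
ep = TS2EPConverter(ts_model, inp).convert()
print(ep)
```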

#### Test Plan
Add a repro example:
* `pytest test/export/test_converter.py -s -k test_ts2ep_with_loop`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131418
Approved by: https://github.com/angelayi
Authored by Jiashen Cao on 2024-08-06 16:51:08 +00:00, committed by PyTorch MergeBot
commit ca7ce2fca1 (parent c803e35c4b)
2 changed files with 271 additions and 90 deletions


@ -102,10 +102,7 @@ class TestConverter(TestCase):
raise RuntimeError(f"Unrecognized mode for torch.jit: {opt}")
converter = TS2EPConverter(ts_model, inp)
print(opt, converter.ts_graph)
ep = converter.convert()
print(ep)
ep_list.append(ep)
for _ in range(num_iterations):
@ -1014,7 +1011,6 @@ class TestConverter(TestCase):
ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))
for ep in ep_list:
print(ep.constants)
self.assertEqual(len(ep.constants), 1)
def test_aten_tensor_prim_dtype(self):
@ -1327,6 +1323,45 @@ class TestConverter(TestCase):
inp = (torch.randn(2, 3),)
self._check_equal_ts_ep_converter(M1(), inp, ["script"])
def test_ts2ep_with_loop(self):
def func1(x, x_list: List[torch.Tensor]):
a, b, c = x, x, x
for i in range(1, 5, 2):
for k in range(5):
a = a + a + k
b = b + b - k
x_list.append(x_list[k] + x_list[k + 1])
for k in range(5):
b = b + b - k
c = c + c * k
x_list.append(x_list[k] + x_list[k + 1] - x_list[k + 2])
return x, x_list
def func2(x):
for i in range(x.size(0)):
x = x * x * i
return x
def func3(x):
while x.sum() < 10:
x += x.sin()
return x
inp = (
torch.tensor(1),
[torch.ones([2, 2]), torch.ones([2, 2]) * 2],
)
# Trace unrolls the loop.
self._check_equal_ts_ep_converter(func1, inp, ["script"])
# TODO: (2/N)
# Trace unrolls the loop.
# self._check_equal_ts_ep_converter(func2, inp, ["script"])
# TODO: (3/N)
# Trace unrolls the loop.
# self._check_equal_ts_ep_converter(func3, inp, ["script"])
if __name__ == "__main__":
run_tests()
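In the converter change that follows, the constant trip count lets `convert_prim_Loop` unroll the loop: the body sub-block is converted once into an FX subgraph, and the parent graph calls it once per iteration, feeding each iteration's outputs back in as the next iteration's loop-local inputs. A rough, hand-written sketch of that strategy (illustrative only; the real converter emits fx.Graph call nodes rather than calling Python functions, and `body` stands for the GraphModule built from the sub-block):

```python
def unroll_constant_loop(body, num_iterations, loop_local_args, global_args):
    # Mirrors execute_subgraph_from_prim_loop's calling convention:
    # globals first, then the iteration index, then the loop-carried values.
    carried = list(loop_local_args)
    for iter_idx in range(num_iterations):
        out = body(*global_args, iter_idx, *carried)
        # out[0] is the loop condition (a constant at this stage);
        # out[1:] are the updated loop-carried values for the next iteration.
        carried = [out[i + 1] for i in range(len(carried))]
    return carried
```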


@ -124,6 +124,26 @@ def list_append(container, element):
return container + [element]
def execute_subgraph_from_prim_loop(
subgraph, iter_idx, len_loop_local_arguments, *args, **kwargs
):
"""
subgraph: GraphModule from sub-block.
iter_idx: The index of the iteration.
len_loop_local_arguments: The number of loop local arguments in args.
"""
# Loop-local variables. The TS graph creates these as sub-block inputs because
# their values are updated inside the loop.
loop_local_args = args[:len_loop_local_arguments]
# Global variables that are not passed in as inputs to the loop sub-blocks
# but are used directly. Most of the time their values are not updated; the
# only exception is when some operations perform in-place updates on them.
global_args = args[len_loop_local_arguments:]
return subgraph(*global_args, iter_idx, *loop_local_args, **kwargs)
def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
def pattern(im, dim, scale):
sym_size_int = torch.ops.aten.sym_size.int(im, dim)
@ -415,6 +435,13 @@ class TS2FXGraphConverter:
lambda node: self._convert_standard_operators(node),
)
# This stores the names of return results that do not appear in the original TS
# graph's outputs. We maintain this because some operations in a sub-block
# may perform in-place updates on a variable defined in the parent fx graph;
# after that sub-block executes, the variable in the parent fx graph also
# needs to be updated.
self.name_update_from_subblock_to_parent: Set[str] = set()
def _is_get_attr_node(self, fqn):
return (
fqn in self.name_to_buffer
@ -425,6 +452,57 @@ class TS2FXGraphConverter:
)
)
def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]):
subgraph_nodes, subgraph_converters = [], []
for block in node.blocks():
subgraph_converter = TS2FXGraphConverter(
block,
self.name_to_param,
self.name_to_buffer,
self.blocks_to_lifted_attrs,
{},
self.name_to_constant,
)
subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn
for block_arg in arguments:
normalized_block_arg_name = normalize_name(block_arg)
placeholder_node = subgraph_converter.fx_graph.placeholder(
normalized_block_arg_name
)
subgraph_converter.name_to_node[block_arg] = placeholder_node
subgraph = subgraph_converter.convert()
subgraph_name = self.add_subgraph(subgraph)
subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))
subgraph_converters.append(subgraph_converter)
return subgraph_nodes, subgraph_converters
def _identify_inputs_as_arguments(self, entry):
"""
Identify inputs from the innermost sub-block. This is needed
for nested sub-blocks when the input is hidden in the nested sub-block.
E.g., example IR where the input is hidden in a nested sub-block:
Graph[x.1]
%1 = ...
Block[]
Block[x.1]
%2 = x.1 ...
"""
arguments: Set[str] = set()
for block in entry.blocks():
for block_node in block.nodes():
for block_node_in in block_node.inputs():
if (
block_node_in.debugName() in self.name_to_node
and block_node_in.debugName() not in self.name_to_attribute_fqn
):
arguments.add(block_node_in.debugName())
arguments = arguments.union(
self._identify_inputs_as_arguments(block_node)
)
return arguments
def is_top_level_graph(self):
return isinstance(self.ts_graph, torch._C.Graph)
@ -438,13 +516,13 @@ class TS2FXGraphConverter:
kwargs = {}
for input, schema_arg in zip(node.inputs(), schema.arguments):
if schema_arg.kwarg_only:
kwargs[schema_arg.name] = self.get_fx_value(input)
kwargs[schema_arg.name] = self.get_fx_value_by_ir_value(input)
else:
args.append(self.get_fx_value(input))
args.append(self.get_fx_value_by_ir_value(input))
return tuple(args), kwargs
def get_fx_value(self, value: torch._C.Value):
def get_fx_value_by_ir_value(self, value: torch._C.Value):
value_name = value.debugName()
if value_name in self.name_to_node:
@ -457,6 +535,19 @@ class TS2FXGraphConverter:
else:
raise ValueError(f"Input {value_name} not found")
def get_fx_value_by_fqn(self, name):
if name in self.name_to_node:
fx_node = self.name_to_node[name]
elif name in self.name_to_constant:
fx_node = self.name_to_constant[name]
elif name in self.name_to_non_tensor_attribute_node:
fx_node = self.name_to_non_tensor_attribute_node[name]
elif name in self.name_to_non_tensor_attribute:
fx_node = self.name_to_non_tensor_attribute[name]
else:
raise ValueError(f"Attribute {name} not found")
return fx_node
def convert(self) -> torch.fx.GraphModule:
self.convert_graph_inputs()
@ -585,13 +676,17 @@ class TS2FXGraphConverter:
"This makes the converter non-functional: the result depends on the order of the append nodes being converter!"
)
args = tuple(self.get_fx_value(inp) for inp in node.inputs())
args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs())
fx_node = self.fx_graph.call_function(list_append, args)
self.name_to_node[node.output().debugName()] = fx_node
# inplace mutate arg[0], which is the python list
self.name_to_node[node.inputsAt(0).debugName()] = fx_node
# Variables that need to be propagated back to the parent module.
if not self.is_top_level_graph() and args[0].op == "placeholder":
self.name_update_from_subblock_to_parent.add(node.inputsAt(0).debugName())
def convert_prim_Constant(self, node: torch._C.Node):
name = node.output().debugName()
@ -621,7 +716,9 @@ class TS2FXGraphConverter:
self.name_to_constant[name] = value
def convert_prim_CallMethod(self, node: torch._C.Node):
inp_list = [self.get_fx_value(inp) for inp in node.inputs()] # noqa: C416
inp_list = [
self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
] # noqa: C416
fx_node = self.fx_graph.call_method(
node.s("name"),
tuple(inp_list),
@ -668,7 +765,7 @@ class TS2FXGraphConverter:
def convert_prim_SetAttr(self, node: torch._C.Node):
attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
attr_value = tuple(node.inputs())[1]
ts_graph_tensor_input = self.get_fx_value(attr_value)
ts_graph_tensor_input = self.get_fx_value_by_ir_value(attr_value)
if self._is_get_attr_node(attr_fqn):
fx_attr_node = self.fx_graph.get_attr(attr_fqn)
self.fx_graph.call_function(
@ -707,7 +804,7 @@ class TS2FXGraphConverter:
def _convert_prim_iterator(self, node: torch._C.Node):
output_list = []
for inp in node.inputs():
output_list.append(self.get_fx_value(inp))
output_list.append(self.get_fx_value_by_ir_value(inp))
output_name = node.output().debugName()
self.name_to_node[output_name] = output_list
@ -719,9 +816,9 @@ class TS2FXGraphConverter:
# We assume key value are stored in pair in the DictConstruct.
# The first element is the key and the following is the value.
if i % 2 == 0:
k = self.get_fx_value(inp)
k = self.get_fx_value_by_ir_value(inp)
else:
v = self.get_fx_value(inp)
v = self.get_fx_value_by_ir_value(inp)
assert (
k is not None and v is not None
), "DictConstruct has an empty key value pair."
@ -745,14 +842,14 @@ class TS2FXGraphConverter:
# Single input and multiple outputs for unpacking.
for i, outp in enumerate(node.outputs()):
outp_name = outp.debugName()
inp = self.get_fx_value(node.input())
inp = self.get_fx_value_by_ir_value(node.input())
fx_node = self.fx_graph.call_function(operator.getitem, (inp, i))
self.name_to_node[outp_name] = fx_node
def convert_aten_Int(self, node: torch._C.Node):
# converts aten::Int as aten._to_copy + aten::_local_scalar_dense
target = torch.ops.aten._to_copy.default
args = tuple(self.get_fx_value(input) for input in node.inputs())
args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32})
fx_node = self.fx_graph.call_function(
@ -773,7 +870,7 @@ class TS2FXGraphConverter:
# For both of those APIs, torch.jit.trace implicitly sets the output tensor type
# to be LongTensor.
target = torch.ops.aten.scalar_tensor
args = tuple(self.get_fx_value(input) for input in node.inputs())
args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
fx_node = self.fx_graph.call_function(target, args, {"dtype": torch.long})
output_name = node.output().debugName()
@ -829,7 +926,7 @@ class TS2FXGraphConverter:
def convert_aten___getitem__(self, node: torch._C.Node):
input_container, index = tuple(
self.get_fx_value(input) for input in node.inputs()
self.get_fx_value_by_ir_value(input) for input in node.inputs()
)
fx_node = self.fx_graph.call_function(
operator.getitem, (input_container, index)
@ -895,6 +992,111 @@ class TS2FXGraphConverter:
else:
self.convert_call_function_op(node)
def _check_prim_loop_support(self, node):
inputs = list(node.inputs())
# TODO: (1/N) stage.
if inputs[0].debugName() not in self.name_to_constant:
raise RuntimeError(
"prim::Loop currently cannot run with dynamic value of number of iterations."
)
# Make sure the condition is not updated in the subblock.
subblock = next(node.blocks())
condition_output_name = next(subblock.outputs()).debugName()
for node in subblock.nodes():
if (
node.outputsSize() == 1
and node.output().debugName() == condition_output_name
):
raise RuntimeError(
"prim::Loop currently cannot run with dynamic value of condition."
)
if node.outputsSize() >= 2:
for outp in node.outputs():
if outp.debugName() == condition_output_name:
raise RuntimeError(
"prim::Loop currently cannot run with dynamic value of condition."
)
def convert_prim_Loop(self, node: torch._C.Node):
inputs = list(node.inputs())
self._check_prim_loop_support(node)
num_iterations = self.get_fx_value_by_ir_value(inputs[0])
# Find inputs.
loop_local_arguments = [inp.debugName() for inp in inputs[2:]]
global_arguments = self._identify_inputs_as_arguments(node)
# Lift parameters as inputs.
for block in node.blocks():
global_arguments = global_arguments.union(
self.blocks_to_lifted_attrs[block]
)
global_arguments = list(global_arguments)
subgraph_nodes, subgraph_converters = self._convert_block_to_subgraph(
node, global_arguments
)
assert len(subgraph_nodes) == 1
subgraph_converter = subgraph_converters[0]
if not self.is_top_level_graph():
self.name_update_from_subblock_to_parent = (
self.name_update_from_subblock_to_parent.union(
subgraph_converter.name_update_from_subblock_to_parent
)
)
fx_block_args = [
self.get_fx_value_by_fqn(name)
for name in loop_local_arguments + global_arguments
]
for iter_idx in range(num_iterations):
loop_node = self.fx_graph.call_function(
execute_subgraph_from_prim_loop,
# See execute_subgraph_from_prim_loop for the expected argument order.
(
subgraph_nodes[0],
iter_idx,
len(loop_local_arguments),
*fx_block_args,
),
{},
)
# Update the value of loop local variables.
if node.outputsSize() >= 1:
for i, outp in enumerate(node.outputs()):
output_name = outp.debugName()
self.name_to_node[output_name] = self.fx_graph.call_function(
operator.getitem,
(
loop_node,
i + 1,
), # + 1 because the 0th element is the condition.
)
fx_block_args[i] = self.name_to_node[output_name]
# Update the values of global variables that are modified in place.
for i, name in enumerate(
subgraph_converter.name_update_from_subblock_to_parent
):
self.name_to_node[name] = self.fx_graph.call_function(
operator.getitem,
(
loop_node,
i + node.outputsSize() + 1,
), # + 1 because the 0th element is the condition.
)
global_argument_index = global_arguments.index(name)
fx_block_args[
i + node.outputsSize() + global_argument_index
] = self.name_to_node[name]
def _check_set_attr_in_if_block(self, if_node: torch._C.Node):
for block in if_node.blocks():
for node in block.nodes():
@ -910,82 +1112,21 @@ class TS2FXGraphConverter:
inputs = list(node.inputs())
assert len(inputs) == 1
predicate = self.get_fx_value(inputs[0])
def _identify_inputs_as_arguments(entry):
"""
Identify inputs from the innermost sub-block. This is needed
for nested sub-blocks when the input is hidden in the nested sub-block.
E.g., example IR of input is hidden in the nested sub-block.
Graph[x.1]
%1 = ...
Block[]
Block[x.1]
%2 = x.1 ...
"""
arguments: Set[str] = set()
for block in entry.blocks():
for block_node in block.nodes():
for block_node_in in block_node.inputs():
if (
block_node_in.debugName() in self.name_to_node
and block_node_in.debugName()
not in self.name_to_attribute_fqn
):
arguments.add(block_node_in.debugName())
arguments = arguments.union(
_identify_inputs_as_arguments(block_node)
)
return arguments
predicate = self.get_fx_value_by_ir_value(inputs[0])
# Find inputs.
arguments = _identify_inputs_as_arguments(node)
arguments = self._identify_inputs_as_arguments(node)
# Lift parameters as inputs.
for block in node.blocks():
arguments = arguments.union(self.blocks_to_lifted_attrs[block])
arguments = list(arguments)
# Convert blocks to subgraphs
subgraph_nodes = []
for block in node.blocks():
subgraph_converter = TS2FXGraphConverter(
block,
self.name_to_param,
self.name_to_buffer,
self.blocks_to_lifted_attrs,
{},
self.name_to_constant,
)
subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn
for block_arg in arguments:
normalized_block_arg_name = normalize_name(block_arg)
placeholder_node = subgraph_converter.fx_graph.placeholder(
normalized_block_arg_name
)
subgraph_converter.name_to_node[block_arg] = placeholder_node
subgraph = subgraph_converter.convert()
subgraph_name = self.add_subgraph(subgraph)
subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))
subgraph_nodes, _ = self._convert_block_to_subgraph(node, arguments)
assert len(subgraph_nodes) == 2
fx_block_args = []
for arg_name in arguments:
if arg_name in self.name_to_node:
arg_node = self.name_to_node[arg_name]
fx_block_args.append(arg_node)
elif arg_name in self.name_to_non_tensor_attribute_node:
arg_node = self.name_to_non_tensor_attribute_node[arg_name]
fx_block_args.append(arg_node)
elif arg_name in self.name_to_non_tensor_attribute:
arg_value = self.name_to_non_tensor_attribute[arg_name]
fx_block_args.append(arg_value)
else:
raise ValueError(f"Attribute {arg_name} not found")
fx_block_args = [self.get_fx_value_by_fqn(name) for name in arguments]
args = (
predicate,
@ -1036,14 +1177,14 @@ class TS2FXGraphConverter:
# currently, _record_function_enter_new and _record_function_exit are
# discarded during `retrace_as_exported_program`.
target = torch.ops.profiler._record_function_exit
args = tuple(self.get_fx_value(input) for input in node.inputs())
args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
self.fx_graph.call_function(target, args)
def convert_prim_tolist(self, node: torch._C.Node):
# prim::tolist cannot be supported by `_convert_standard_operators`
# since it requires call_method instead of call_function.
target = "tolist"
args = (self.get_fx_value(next(node.inputs())),)
args = (self.get_fx_value_by_ir_value(next(node.inputs())),)
fx_node = self.fx_graph.call_method(target, args)
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
@ -1058,7 +1199,7 @@ class TS2FXGraphConverter:
def _convert_standard_operators(self, node: torch._C.Node):
target = kind_to_standard_operators[node.kind()]
args = tuple(self.get_fx_value(input) for input in node.inputs())
args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
fx_node = self.fx_graph.call_function(target, args)
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
@ -1084,8 +1225,10 @@ class TS2FXGraphConverter:
def convert_graph_outputs(self):
args = []
for graph_output in self.ts_graph.outputs():
output_name = graph_output.debugName()
outp_name_list = [outp.debugName() for outp in self.ts_graph.outputs()] + list(
self.name_update_from_subblock_to_parent
)
for output_name in outp_name_list:
if output_name in self.name_to_node:
fx_node = self.name_to_node[output_name]
# TODO: Revisit this later after HigherOrderOp design changes.
@ -1118,7 +1261,10 @@ class TS2FXGraphConverter:
else:
raise ValueError(f"Output {output_name} not found")
if len(args) == 1:
if len(args) == 0:
# Sub-block of prim::If can have zero output.
self.fx_graph.output([])
elif len(args) == 1:
self.fx_graph.output(
args[0]
) # Get rid of an extra list wrapped around final output.
@ -1127,8 +1273,8 @@ class TS2FXGraphConverter:
args
) # For prim::Loop and prim::If with multiple outputs.
else:
# Sub-block of prim::If can have zero output.
self.fx_graph.output([])
# Sub-block of prim::Loop can have multiple outputs.
self.fx_graph.output(args)
class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
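A note on the index arithmetic in `convert_prim_Loop` above: each per-iteration call returns a tuple whose element 0 is the loop condition, elements 1 through `outputsSize()` are the loop-carried outputs, and the names recorded in `name_update_from_subblock_to_parent` (variables mutated in place inside the sub-block) come after those. A hedged sketch of how one iteration's result is unpacked (illustrative; the converter emits `operator.getitem` nodes into the FX graph rather than indexing a Python tuple):

```python
import operator


def unpack_iteration_result(loop_out, num_loop_outputs, inplace_updated_names):
    # Index 0 holds the loop condition, so loop-carried outputs start at 1.
    carried = [operator.getitem(loop_out, i + 1) for i in range(num_loop_outputs)]
    # Variables mutated in place inside the sub-block follow the
    # loop-carried outputs and must be written back to the parent graph.
    written_back = {
        name: operator.getitem(loop_out, num_loop_outputs + 1 + i)
        for i, name in enumerate(inplace_updated_names)
    }
    return carried, written_back
```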