pytorch/torch/_export/converter.py
Jiashen Cao d66f12674c Handle tuple and dict during TorchScript to ExportedProgram conversion (#127341)
* Add some test cases for testing List, Tuple, and Dict
* Refactor the conversion code slightly
* Add a logic to handle Dict
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127341
Approved by: https://github.com/SherlockNoMad, https://github.com/angelayi
2024-05-30 00:08:09 +00:00

398 lines
14 KiB
Python

from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.export._trace
from torch.export.exported_program import ExportedProgram
from torch.export.graph_signature import (
InputKind,
InputSpec,
OutputKind,
OutputSpec,
TensorArgument,
)
from torch.fx import subgraph_rewriter
from torch.onnx.utils import _create_jit_graph
from torchgen.model import FunctionSchema
def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
def pattern(im, dim, scale):
sym_size_int = torch.ops.aten.sym_size.int(im, dim)
scalar_tensor = torch.ops.aten.scalar_tensor(sym_size_int)
div_scalar_mode = torch.ops.aten.div.Scalar_mode(
scalar_tensor, scale, rounding_mode="trunc"
)
int_tensor = torch.ops.aten.Int.Tensor(div_scalar_mode)
return int_tensor
def replacement(im, dim, scale):
sym_size_int = torch.ops.aten.sym_size.int(im, dim)
return sym_size_int // scale
replaced_patterns = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
def normalize_name(name: str) -> str:
return name.replace(".", "_")
def get_op_overload(node: torch._C.Node):
schema_str = node.schema()
schema = FunctionSchema.parse(schema_str)
ns, op_name = str(schema.name.name).split("::")
override = schema.name.overload_name
op_overload_packet = getattr(torch.ops.aten, op_name)
if override:
op_overload = getattr(op_overload_packet, override)
else:
op_overload = op_overload_packet.default
return op_overload
class TS2EPConverter:
# TorchScript model to ExportedProgram converter
def __init__(
self,
ts_model,
sample_args: Tuple[Any, ...],
sample_kwargs: Optional[Dict[str, Any]] = None,
):
self.ts_model = ts_model
self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)
self.sample_args = sample_args
self.sample_kwargs = sample_kwargs
self.param_names: Set[str] = {name for name, _ in ts_model.named_parameters()}
self.buffer_names: Set[str] = {name for name, _ in ts_model.named_buffers()}
self.fx_graph: torch.fx.Graph = torch.fx.Graph()
self.input_specs: List[InputSpec] = []
self.output_specs: List[OutputSpec] = []
self.name_to_node: Dict[
str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]]
] = {}
self.constant_map: Dict[str, Any] = {}
self.attribute_map: Dict[str, Any] = {}
self.tensor_constants: Dict[str, torch.Tensor] = {}
def get_args_kwargs(self, node: torch._C.Node, schema):
args = []
kwargs = {}
for input, schema_arg in zip(node.inputs(), schema.arguments):
if schema_arg.kwarg_only:
kwargs[schema_arg.name] = self.get_fx_value(input)
else:
args.append(self.get_fx_value(input))
return tuple(args), kwargs
def get_fx_value(self, value: torch._C.Value):
value_name = value.debugName()
if value_name in self.name_to_node:
input_node = self.name_to_node[value_name]
return input_node
elif value_name in self.attribute_map:
attr_name = self.attribute_map[value_name]
if attr_name in self.name_to_node:
input_node = self.name_to_node[attr_name]
return input_node
else:
raise ValueError(f"Value {attr_name} not found")
elif value_name in self.constant_map:
return self.constant_map[value_name]
else:
raise ValueError(f"Input {value_name} not found")
def convert(self) -> ExportedProgram:
self.convert_graph_inputs()
for node in self.ts_graph.nodes():
self.convert_node(node)
self.convert_graph_outputs()
gm = torch.fx.GraphModule({}, self.fx_graph)
inplace_optimize_sym_size_div(gm)
gm.graph.lint()
ep = self.retrace_as_exported_program(gm)
return ep
def convert_graph_inputs(self):
for graph_input in self.ts_graph.inputs():
name = graph_input.debugName()
normalized_name = normalize_name(name)
fx_node = self.fx_graph.placeholder(normalized_name)
# fx_node.meta["val"] = FakeTensor()
# TODO: set fx_node.meta["val"]
self.name_to_node[name] = fx_node
if name in self.param_names:
self.input_specs.append(
InputSpec(
InputKind.PARAMETER,
arg=TensorArgument(name=normalized_name),
target=name,
)
)
elif name in self.buffer_names:
self.input_specs.append(
InputSpec(
InputKind.BUFFER,
arg=TensorArgument(name=normalized_name),
target=name,
persistent=True,
)
)
else:
self.input_specs.append(
InputSpec(
InputKind.USER_INPUT,
arg=TensorArgument(name=normalized_name),
target=name,
)
)
def convert_prim_Constant(self, node: torch._C.Node):
name = node.output().debugName()
value: Any = None
if node.hasAttribute("value"):
constant_kind = node.kindOf("value")
if constant_kind == "i":
value = node.i("value")
elif constant_kind == "f":
value = node.f("value")
elif constant_kind == "s":
value = node.s("value")
elif constant_kind == "t":
# lift tensor constant as a placeholder
placeholder_name = f"constant_{name}"
fx_node = self.fx_graph.placeholder(placeholder_name)
self.name_to_node[name] = fx_node
self.tensor_constants[placeholder_name] = node.t("value")
self.input_specs.append(
InputSpec(
InputKind.CONSTANT_TENSOR,
arg=TensorArgument(name=placeholder_name),
target=placeholder_name,
)
)
value = fx_node
elif constant_kind == "ival":
value = node.ival("value")
else:
raise ValueError(f"Unsupported constant type: {node.kindOf('value')}")
else:
value = None
self.constant_map[name] = value
def convert_prim_GetAttr(self, node: torch._C.Node):
def get_attr(name: str):
if name in self.attribute_map:
return self.attribute_map[name]
else:
raise ValueError(f"Attribute {name} not found")
output_name = node.output().debugName()
attr_name = node.s("name")
input_name = node.input().debugName()
root_attr_name = get_attr(input_name)
self.attribute_map[output_name] = (
f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name
)
def convert_aten_op(self, node: torch._C.Node):
target = get_op_overload(node)
if target is torch.ops.aten.size.int:
target = torch.ops.aten.sym_size.int
args, kwargs = self.get_args_kwargs(node, target._schema)
fx_node = self.fx_graph.call_function(target, args, kwargs)
# TODO: covnert sourceRange() into stack_trace
# fx_node.meta["stack_trace"] = node.sourceRange()
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
def convert_prim_ListConstruct(self, node: torch._C.Node):
output_list = []
for inp in node.inputs():
output_list.append(self.get_fx_value(inp))
output_name = node.output().debugName()
self.name_to_node[output_name] = output_list
def convert_prim_DictConstruct(self, node: torch._C.Node):
output_dict = {}
k, v = None, None
for i, inp in enumerate(node.inputs()):
# We assume key value are stored in pair in the DictConstruct.
# The first element is the key and the following is the value.
if i % 2 == 0:
k = self.get_fx_value(inp)
else:
v = self.get_fx_value(inp)
assert (
k is not None and v is not None
), "DictConstruct has an empty key value pair."
output_dict[k] = v
k, v = None, None
assert (
k is None and v is None
), "DictConstruct has an odd number of elements (violating our assumption)."
output_name = node.output().debugName()
self.name_to_node[output_name] = output_dict
def convert_aten_Int(self, node: torch._C.Node):
# converts aten::Int as aten._to_copy + aten::_local_scalar_dense
target = torch.ops.aten._to_copy.default
args = tuple(self.get_fx_value(input) for input in node.inputs())
to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32})
fx_node = self.fx_graph.call_function(
torch.ops.aten._local_scalar_dense.default, (to_copy_node,)
)
# TODO: covnert sourceRange() into stack_trace
# fx_node.meta["stack_trace"] = node.sourceRange()
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
def convert_prim_NumToTensor(self, node: torch._C.Node):
# converts prim::NumToTensor as aten.scalar_tensor
target = torch.ops.aten.scalar_tensor
args = tuple(self.get_fx_value(input) for input in node.inputs())
fx_node = self.fx_graph.call_function(target, args)
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
def convert_prim_CreateObject(self, node: torch._C.Node):
output_name = node.output().debugName()
self.attribute_map[output_name] = ""
def convert_aten__convolution(self, node: torch._C.Node):
# converts aten::_convolution as aten.convolution, since aten::_convolution
# doesn't have a meta function
target = torch.ops.aten.convolution.default
args, kwargs = self.get_args_kwargs(node, target._schema)
fx_node = self.fx_graph.call_function(target, args, kwargs)
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
def convert_aten_div(self, node: torch._C.Node):
target = get_op_overload(node)
schema = target._schema
args, kwargs = self.get_args_kwargs(node, schema)
# converts aten::div.Tensor_mode(x, tensor_constant)
# as aten.div.Scalar_mode(x, tensor_constant.item())
if schema.overload_name == "Tensor_mode":
arg1_name = args[1].name
if arg1_name in self.tensor_constants:
tensor_constant = self.tensor_constants[arg1_name]
if tensor_constant.numel() == 1:
updated_args = list(args)
updated_args[1] = self.tensor_constants[arg1_name].item()
fx_node = self.fx_graph.call_function(
torch.ops.aten.div.Scalar_mode,
tuple(updated_args),
kwargs,
)
# TODO: covnert sourceRange() into stack_trace
# fx_node.meta["stack_trace"] = node.sourceRange()
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
return
self.convert_aten_op(node)
def convert_node(self, node: torch._C.Node):
node_kind = node.kind()
if node_kind == "prim::CreateObject":
self.convert_prim_CreateObject(node)
elif node_kind == "prim::Constant":
self.convert_prim_Constant(node)
elif node_kind == "prim::GetAttr":
self.convert_prim_GetAttr(node)
elif node_kind == "prim::NumToTensor":
self.convert_prim_NumToTensor(node)
elif node_kind in {"prim::ListConstruct", "prim::TupleConstruct"}:
# Tuple is just a non-mutable List, so we can handle them together.
self.convert_prim_ListConstruct(node)
elif node_kind == "prim::DictConstruct":
self.convert_prim_DictConstruct(node)
# elif node_kind == "aten::Int":
# convert_aten_Int(node)
elif node_kind == "aten::_convolution":
self.convert_aten__convolution(node)
elif node_kind == "aten::div":
self.convert_aten_div(node)
elif node_kind.startswith("aten::"):
self.convert_aten_op(node)
else:
raise ValueError(f"Unsupported node kind: {node_kind}")
def convert_graph_outputs(self):
args = []
for graph_output in self.ts_graph.outputs():
output_name = graph_output.debugName()
if output_name in self.name_to_node:
args.append(self.name_to_node[output_name])
else:
raise ValueError(f"Output {output_name} not found")
self.output_specs.append(
OutputSpec(
OutputKind.USER_OUTPUT,
arg=TensorArgument(name=output_name),
target=output_name,
)
)
self.fx_graph.output(
args[0]
) # Get rid of an extra list wrapped around final output.
def retrace_as_exported_program(self, gm: torch.fx.GraphModule):
# TODO: adjust input orders to match GraphSignature convention
inputs = [*self.sample_args, *self.params, *self.tensor_constants.values()]
ep = torch.export._trace._export(
gm,
tuple(inputs),
strict=False,
pre_dispatch=True,
)
return ep