[ONNX] Remove fx_onnx_interpreter.py (#158282)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158282
Approved by: https://github.com/Skylion007, https://github.com/justinchuby
ghstack dependencies: #158258, #158262
Ti-Tai Wang 2025-07-15 17:32:59 +00:00 committed by PyTorch MergeBot
parent cc0faeb80f
commit e4c17d5e1c


@@ -1,718 +0,0 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import inspect
import operator
from typing import Callable, TYPE_CHECKING
import onnxscript
from onnxscript.function_libs.torch_lib import (
graph_building as onnxscript_graph_building,
)
import torch
import torch.fx
from torch.onnx import _type_utils as jit_type_utils
from torch.onnx._internal.fx import (
_pass,
onnxfunction_dispatcher,
type_utils as fx_type_utils,
)
from torch.utils import _pytree
if TYPE_CHECKING:
from collections.abc import Sequence
def _fx_node_to_onnx_message_formatter(
fn: Callable,
self,
node: torch.fx.Node,
*args,
**kwargs,
) -> str:
return f"FX Node: {node.op}:{node.target}[name={node.name}]. "
def _fx_graph_to_onnx_message_formatter(
fn: Callable,
self,
fx_graph_module: torch.fx.GraphModule,
*args,
**kwargs,
) -> str:
return f"FX Graph: {fx_graph_module._get_name()}. "
def _retrieve_or_adapt_input_to_graph_set(
fx_node_arg: fx_type_utils.Argument,
fx_name_to_onnxscript_value: dict[
str,
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...],
],
tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
):
"""Map FX value to TorchScript value.
When creating TorchScript graph from FX graph, we need a mapping from FX variable
to TorchScript variable. This function maps FX variable, fx_node_arg, to torch.jit.Value.
"""
from onnxscript import opset18 as op
onnx_tensor = fx_node_arg
if isinstance(onnx_tensor, torch.fx.Node):
        # 1. fx_node_arg is a torch.fx.Node, which means
        #    fx_node_arg stands for the output of that torch.fx.Node.
        # 2. fx_node_arg (a variable in torch.fx.Graph) is mapped to a
        #    torch.jit.Value, fx_name_to_onnxscript_value[fx_node_arg.name],
        #    in the TorchScript graph.
return fx_name_to_onnxscript_value[onnx_tensor.name]
elif isinstance(onnx_tensor, (tuple, list)) and any(
isinstance(node, torch.fx.Node)
and fx_type_utils.is_torch_symbolic_type(node.meta.get("val"))
for node in onnx_tensor
):
        # This intends to handle dynamic axes. For example, if the input size of
        # op.Expand is dynamic, each dimension is a variable (i.e., a sym variable
        # in the PyTorch FX graph; note that sym variables are mapped to tensors
        # in the ONNX Script world) calculated by other operators.
sequence_mixed_elements: list[
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...]
| list[int]
] = []
        # onnx_tensor contains a list of scalars, each of which could be
        # - a tensor with empty shape,
        # - a tensor with shape (1,),
        # - a torch.SymInt,
        # - an int,
        # - ...
        # They should all be promoted to tensors with shape (1,)
        # in order to call ONNX's Concat.
for tensor in onnx_tensor:
# Prepare `tensor` as input of ONNX's Concat.
if isinstance(
tensor, torch.fx.Node
) and fx_type_utils.is_torch_symbolic_type(tensor.meta.get("val")):
# In this case, tensor is a torch.SymInt from Dynamo's perspective.
# It might be mapped to tensor with shape () or (1,) in ONNX.
element_value = fx_name_to_onnxscript_value[tensor.name]
if isinstance(
element_value, onnxscript_graph_building.TorchScriptTensor
):
                    # All elements in sequence_mixed_elements will be sent to
                    # ONNX's Concat as inputs. Therefore, they are required to
                    # have the same rank. Since tensors with rank=0 (i.e.,
                    # scalars) cannot be concatenated, all scalars are promoted
                    # to tensors with shape (1,).
with onnxscript.evaluator.default_as(tracer):
element_value = op.Reshape(
element_value, # type: ignore[arg-type, type-var]
[1], # type: ignore[arg-type, type-var]
)
sequence_mixed_elements.append(element_value)
elif isinstance(tensor, int):
                # NOTE: op.Concat doesn't support scalars, so we wrap the int in
                # a list; onnx-script will promote it to a tensor(int64).
sequence_mixed_elements.append([tensor])
else:
raise RuntimeError(
f"Unsupported type in sequence_mixed_elements: {type(tensor)}"
)
        # Concat all the elements in the sequence.
        # Shapes are mapped to tensors in the ONNX graph (TorchScriptGraph),
        # so the list of sym_ints is concatenated into a tensor before calling the ONNX op.
        # For example:
        #    inputs: [[2], [4], fx.Node(SymIntA), [1], fx.Node(SymIntB)]
        #    output: op.Concat([op.Constant(2), op.Constant(4), TorchScriptTensor(A), op.Constant(1), TorchScriptTensor(B)])
        # onnx-script automatically wraps Python numbers with op.Constant,
        # so we don't need to process them specially.
with onnxscript.evaluator.default_as(tracer):
output = op.Concat(*sequence_mixed_elements, axis=0) # type: ignore[type-var]
output.dtype = torch.int64 # type: ignore[union-attr]
output.shape = [len(sequence_mixed_elements)] # type: ignore[union-attr]
return output
elif isinstance(onnx_tensor, (tuple, list)) and all(
isinstance(node, torch.fx.Node) or node is None for node in onnx_tensor
):
sequence_elements: list[
onnxscript_graph_building.TorchScriptTensor
| None
| tuple[onnxscript_graph_building.TorchScriptTensor, ...]
] = []
for tensor in onnx_tensor:
sequence_elements.append(
fx_name_to_onnxscript_value[tensor.name] if tensor is not None else None # type: ignore[index, union-attr]
)
return sequence_elements
if isinstance(onnx_tensor, torch.dtype):
onnx_tensor = int( # type: ignore[call-overload]
jit_type_utils.JitScalarType.from_dtype(onnx_tensor).onnx_type()
)
    # NOTE: If a device is specified in kwargs (not consumed), it can safely be
    # ignored. But if it's in args, we need to turn it into a string for the
    # dispatcher to match the schema.
if isinstance(onnx_tensor, torch.device):
# torch.device is not supported by onnxscript (no op). We turn it into
# a string.
return str(onnx_tensor)
    # In all other cases, we do nothing.
return onnx_tensor
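
# Illustrative sketch (not from the original file): how the scalar branches above
# behave. A torch.dtype is turned into its ONNX TensorProto enum value and a
# torch.device into a plain string; the enum value shown is an assumption based
# on ONNX's TensorProto numbering, and `tracer` is a hypothetical placeholder.
#
#   >>> _retrieve_or_adapt_input_to_graph_set(torch.float32, {}, tracer)
#   1    # onnx.TensorProto.FLOAT
#   >>> _retrieve_or_adapt_input_to_graph_set(torch.device("cpu"), {}, tracer)
#   'cpu'
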
def filter_incompatible_and_dtype_convert_kwargs(kwargs):
"""Filter out kwargs that are not supported by onnxscript."""
filtered = {}
for key, value in kwargs.items():
if key in {
"layout",
"device",
"requires_grad",
"pin_memory",
"memory_format",
"implicit",
}:
continue
if key == "dtype":
if value is None:
                # We omit dtype if it is not provided, because onnxscript
                # handles the default case.
continue
else:
value = int(jit_type_utils.JitScalarType.from_dtype(value).onnx_type()) # type: ignore[call-overload]
filtered[key] = value
return filtered
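
# Illustrative sketch (not from the original file): typical behavior of the filter.
# Unsupported keys are dropped and dtype is converted to its ONNX enum value
# (1 == TensorProto.FLOAT, per ONNX's numbering; shown values are assumptions).
#
#   >>> filter_incompatible_and_dtype_convert_kwargs(
#   ...     {"dtype": torch.float32, "device": torch.device("cpu"), "pin_memory": False}
#   ... )
#   {'dtype': 1}
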
def _fill_tensor_shape_type(
onnxscript_values: onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...],
name: str,
expected_values: fx_type_utils.META_VALUE_TYPE
| list[fx_type_utils.META_VALUE_TYPE]
| tuple[fx_type_utils.META_VALUE_TYPE | None, ...],
):
"""Fill the meta information of onnxscript_values with that from the fx FakeTensor."""
if isinstance(expected_values, (list, tuple)) and not isinstance(
onnxscript_values, (list, tuple)
):
# ex: aten::split - in onnx_dtype: seq(tensor)
# onnxscript_values is a single tensor, but expected_values is a list of tensors.
return
flat_onnxscript_values, _ = _pytree.tree_flatten(onnxscript_values)
flat_expected_values, _ = _pytree.tree_flatten(expected_values)
for i, (onnxscript_value, expected_value) in enumerate(
zip(flat_onnxscript_values, flat_expected_values)
):
if expected_value is None:
            # There is no shape/type to take from None.
            # NOTE: According to https://github.com/pytorch/pytorch/blob/main/torch/_meta_registrations.py,
            # None can be a valid return value, so we need to handle it,
            # e.g., the function meta__scaled_dot_product_flash() in CPU mode.
continue
elif fx_type_utils.is_torch_symbolic_type(expected_value):
            # The output of aten::sym_size is an int, not a tensor, and stands
            # for the size of one dim. We treat it as a 1-D tensor.
onnxscript_value.dtype = fx_type_utils.from_sym_value_to_torch_dtype(
expected_value
)
onnxscript_value.shape = torch.Size([1])
elif isinstance(expected_value, (int, float, bool)):
onnxscript_value.dtype = fx_type_utils.from_scalar_type_to_torch_dtype(
type(expected_value)
)
onnxscript_value.shape = torch.Size([])
elif isinstance(expected_value, complex):
# From complex scalar to real representation
onnxscript_value_to_torch_dtype = (
fx_type_utils.from_scalar_type_to_torch_dtype(type(expected_value))
)
onnxscript_value.dtype = (
fx_type_utils.from_complex_to_float(onnxscript_value_to_torch_dtype)
if onnxscript_value_to_torch_dtype is not None
else None
)
onnxscript_value.shape = torch.Size([2])
elif fx_type_utils.is_torch_complex_dtype(expected_value.dtype):
# Like torch.view_as_real, we flatten complex tensors to real tensors with
# additional last dimension of 2
onnxscript_value.shape = torch.Size((*expected_value.size(), 2))
# complex64 -> float32, complex128 -> float64, etc.
onnxscript_value.dtype = fx_type_utils.from_complex_to_float(
expected_value.dtype
)
# Dispatcher needs to know the value is complex
onnxscript_value.is_complex = True
else:
# We set node output sizes to be dynamic to continue the model conversion,
# and inputs are also set to be dynamic in add_input().
onnxscript_value.shape = expected_value.size()
onnxscript_value.dtype = expected_value.dtype
# naming
if i > 0:
onnxscript_value.name = f"{name}_{i}"
else:
onnxscript_value.name = name
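
# Illustrative sketch (not from the original file): metadata this function would
# assign for two representative expected values.
#
#   - a torch.SymInt               -> dtype=torch.int64, shape=torch.Size([1])
#   - a complex64 tensor of (2, 3) -> dtype=torch.float32, shape=torch.Size([2, 3, 2]),
#                                     is_complex=True (real view with a last dim of 2)
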
def _fill_in_default_kwargs(
node: torch.fx.Node,
) -> tuple[list[fx_type_utils.Argument], dict[str, fx_type_utils.Argument]]:
"""Find and Fill in the not provided kwargs with default values."""
    # TODO: aten::sym_size has an overload, but the fx graph uses the
    # overloadpacket for some reason:
    # https://github.com/pytorch/pytorch/issues/97201
    # We manually assign the overload for aten::sym_size.
if hasattr(node.target, "_schema"):
node_schema = node.target._schema # type: ignore[union-attr]
else:
node_schema = torch.ops.aten.sym_size.int._schema # type: ignore[union-attr]
    # This function assumes that the order of arguments in the FX op is the
    # same as the order of arguments in the TorchScript op.
complete_args: list[fx_type_utils.Argument] = []
complete_kwargs: dict[str, fx_type_utils.Argument] = {}
if inspect.isbuiltin(node.target):
complete_args = list(node.args)
else:
for i, expected_arg in enumerate(node_schema.arguments):
if i < len(node.args):
complete_args.append(node.args[i])
elif expected_arg.name in node.kwargs:
complete_kwargs[expected_arg.name] = node.kwargs[expected_arg.name]
else:
# Get default from schema.
complete_kwargs[expected_arg.name] = expected_arg.default_value
return complete_args, complete_kwargs
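
# Illustrative sketch (not from the original file): for a hypothetical FX node
# calling aten.add.Tensor with node.args == (x, y) and no kwargs, the schema's
# remaining keyword-only argument is filled from its default:
#
#   >>> _fill_in_default_kwargs(node)
#   ([x, y], {'alpha': 1})
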
def _wrap_fx_args_as_onnxscript_args(
complete_args: list[fx_type_utils.Argument],
complete_kwargs: dict[str, fx_type_utils.Argument],
fx_name_to_onnxscript_value: dict[
str,
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...],
],
tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
) -> tuple[
Sequence[
onnxscript_graph_building.TorchScriptTensor
| str
| int
| float
| bool
| list
| complex
| None
],
dict[str, fx_type_utils.Argument],
]:
"""Map all FX arguments of a node to arguments in TorchScript graph."""
onnxscript_args = tuple(
_retrieve_or_adapt_input_to_graph_set(arg, fx_name_to_onnxscript_value, tracer)
for arg in complete_args
)
onnxscript_kwargs = filter_incompatible_and_dtype_convert_kwargs(complete_kwargs)
return onnxscript_args, onnxscript_kwargs
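
# Illustrative sketch (not from the original file): this composes the two helpers
# above. Assuming complete_args == [fx_node_x, torch.float32] and
# complete_kwargs == {"dtype": None, "device": torch.device("cpu")}, it would
# return ((<TorchScriptTensor for fx_node_x>, 1), {}): dtype=None is omitted
# and device is filtered out.
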
class FxOnnxInterpreter:
"""Stateless class to process FX graph Nodes and translate them into their ONNX counterparts.
All FX nodes described by [FX Graph](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph) are supported.
Similarly to [FX Interpreter pattern](https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter), each FX node
must be implemented on its own method in this class.
Each operator's implementation returns either an `onnxscript.OnnxFunction` or
`onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. They can
also raise RuntimeError: If there are no overloaded functions available for the given FX node.
"""
def run_node(
self,
node,
fx_graph_module: torch.fx.GraphModule,
onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
fx_name_to_onnxscript_value: dict[
str,
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...],
],
):
"""Execute a single FX node to produce its ONNX counterpart.
Args:
node: The FX node to be translated.
fx_graph_module: The FX graph module containing the node.
onnxfunction_dispatcher: The dispatcher to find the best matched ONNX op.
onnxscript_graph: The ONNX graph to be populated.
onnxscript_tracer: The tracer to trace the ONNX graph.
fx_name_to_onnxscript_value: The mapping from FX node name to ONNX Script value.
Raises:
RuntimeError: When a node.op is not supported.
"""
if node.op == "placeholder":
self.placeholder(node, onnxscript_graph, fx_name_to_onnxscript_value)
elif node.op == "get_attr":
self.get_attr(
node,
onnxscript_graph,
fx_name_to_onnxscript_value,
fx_graph_module,
)
elif node.op == "call_function":
self.call_function(
node,
onnxscript_tracer,
fx_name_to_onnxscript_value,
onnxfunction_dispatcher,
fx_graph_module,
)
elif node.op == "call_method":
self.call_method(node)
elif node.op == "call_module":
self.call_module(
node,
onnxscript_graph,
fx_name_to_onnxscript_value,
onnxscript_tracer,
fx_graph_module,
onnxfunction_dispatcher,
)
elif node.op == "output":
self.output(node, onnxscript_graph, fx_name_to_onnxscript_value)
else:
raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}")
def run(
self,
fx_graph_module: torch.fx.GraphModule,
onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph
| None = None,
) -> onnxscript_graph_building.TorchScriptGraph:
"""Analyze all FX nodes and trigger their ONNX translation.
Args:
fx_graph_module: FX graph module to be translated.
onnxfunction_dispatcher: ONNX function dispatcher.
parent_onnxscript_graph: The parent TorchScript graph. Must be provided if
`fx_graph_module` is a submodule. If not provided,
`fx_graph_module` is assumed to be the root module.
"""
if parent_onnxscript_graph is not None:
# If parent_onnxscript_graph is provided, we assume fx_graph_module is a
# submodule representing a forward call of an nn.Module.
# Compose package and version where the nn.Module is defined as domain name
# for the local function.
onnx_meta: _pass.GraphModuleOnnxMeta | None = fx_graph_module.meta.get(
"onnx"
)
if onnx_meta is None:
raise RuntimeError(
f"ONNX meta is not found in submodule {fx_graph_module._get_name()}. "
f"Only submodules produced by `Modularize` pass is supported in ONNX export."
)
onnx_domain = onnx_meta.package_info.to_onnx_domain_string()
else:
# Leave as default domain name for the root module.
onnx_domain = None
onnxscript_graph = onnxscript_graph_building.TorchScriptGraph(
parent_onnxscript_graph, domain_name=onnx_domain
)
onnxscript_tracer = onnxscript_graph_building.TorchScriptTracingEvaluator(
onnxscript_graph
)
# In the following loop, a TorchScript graph is created to
# represent the input FX graph with ONNX symbols (e.g., onnx::add).
# To connect the values to nodes in the TorchScript graph, we maintain
# fx_name_to_onnxscript_value. Basically, we want to translate
# fx_tensor_x (type: torch.fx.Node) -> fx_node_1 -> fx_tensor_y (type: torch.fx.Node)
# to
# fx_name_to_onnxscript_value[fx_tensor_x.name] -> onnx_node_1 -> fx_name_to_onnxscript_value[fx_tensor_y.name]
fx_name_to_onnxscript_value: dict[
str,
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...],
] = {}
        # TODO: Fix the FakeTensorMode limitation asap.
        # We want to pass lists of ints and floats to the TorchScript graph correctly
        # in _export_fx_to_ts, so we must disable FakeTensorMode. Otherwise, the graph
        # may receive FakeTensors, resulting in a runtime error. In addition, the
        # TorchScript-based ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf
        # is not compatible with FakeTensorMode.
with torch.utils._mode_utils.no_dispatch():
for node in fx_graph_module.graph.nodes:
self.run_node(
node,
fx_graph_module,
onnxfunction_dispatcher,
onnxscript_graph,
onnxscript_tracer,
fx_name_to_onnxscript_value,
)
return onnxscript_graph
def placeholder(
self,
node: torch.fx.Node,
onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
fx_name_to_onnxscript_value: dict[
str,
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...],
],
):
        # Input of the graph.
        # node.meta["val"] is generated by FakeTensorProp.
        # NOTE: add_input() intends to create inputs with shape/type.
        fake_tensor = node.meta.get("val", None)
        # NOTE: During tracing, constant inputs are represented by nodes whose
        # node.meta['val'] is None (nn.Module passed to dynamo_export) or a
        # builtin value (ExportedProgram passed to dynamo_export). Nonetheless,
        # such nodes are not consumed by others, so we don't need to create a
        # TorchScriptTensor for them.
if fake_tensor is None or isinstance(fake_tensor, (int, float, bool, str)):
output = onnxscript_graph.add_input(
input_name=None,
)
elif isinstance(fake_tensor, torch.Tensor):
# NOTE: ONNX doesn't support tensor of complex64/complex128, so we
# convert them to float32/float64 with real representation.
if fx_type_utils.is_torch_complex_dtype(fake_tensor.dtype):
fake_tensor = torch.view_as_real(fake_tensor.resolve_conj())
output = onnxscript_graph.add_input(
input_name=node.name,
shape=fake_tensor.shape,
dtype=fake_tensor.dtype,
)
elif fx_type_utils.is_torch_symbolic_type(fake_tensor):
output = onnxscript_graph.add_input(
input_name=node.name,
shape=torch.Size([]),
dtype=fx_type_utils.from_sym_value_to_torch_dtype(fake_tensor),
)
else:
raise RuntimeError(
f"Unsupported type(node.meta['val']) for placeholder: {type(fake_tensor)}"
)
assert output is not None, (
f"Node creates None with target={node.target} and name={node.name}"
)
assert isinstance(output, onnxscript_graph_building.TorchScriptTensor)
assert isinstance(output, onnxscript.tensor.Tensor)
fx_name_to_onnxscript_value[node.name] = output
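
    # Illustrative sketch (not from the original file): a complex64 placeholder of
    # shape (3,) is registered as a real-valued graph input, since ONNX has no
    # complex tensor type:
    #
    #   node.meta["val"]: FakeTensor(shape=(3,), dtype=torch.complex64)
    #   -> onnxscript_graph.add_input(node.name, shape=(3, 2), dtype=torch.float32)
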
def call_function(
self,
node: torch.fx.Node,
onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
fx_name_to_onnxscript_value: dict[
str,
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...],
],
onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
fx_graph_module: torch.fx.GraphModule,
):
# aten ops and other stateless functions.
if node.target == operator.getitem and isinstance(
fx_name_to_onnxscript_value[node.args[0].name], # type: ignore[union-attr,index]
tuple,
):
onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index]
index = node.args[1]
value = onnx_tensor_tuple[index] # type: ignore[index]
assert value is not None, (
f"Node creates None with target={node.target} and name={node.name}"
)
assert isinstance(
value, (onnxscript_graph_building.TorchScriptTensor, tuple)
), type(value)
fx_name_to_onnxscript_value[node.name] = value
return
        # Map FX inputs to ONNX inputs and fill optional inputs with default values.
fx_args, fx_kwargs = _fill_in_default_kwargs(node)
onnx_args, onnx_kwargs = _wrap_fx_args_as_onnxscript_args(
fx_args,
fx_kwargs,
fx_name_to_onnxscript_value,
onnxscript_tracer,
)
        # Dispatch to the ONNX op through OpSchema. The input argument dtypes are
        # compared to the function signatures in OpSchema to find the best-matched
        # overload.
symbolic_fn = onnxfunction_dispatcher.dispatch(
node=node,
onnx_args=onnx_args, # type: ignore[arg-type]
onnx_kwargs=onnx_kwargs,
)
with onnxscript.evaluator.default_as(onnxscript_tracer):
output: (
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...]
) = symbolic_fn(*onnx_args, **onnx_kwargs)
assert output is not None, (
f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}"
)
# Assign type and shape from fx graph.
_fill_tensor_shape_type(output, node.name, node.meta["val"])
        # One fx node can produce multiple outputs (e.g., a tuple of tensors); in
        # that case, output is a tuple of TorchScriptTensors.
assert isinstance(
output, (onnxscript_graph_building.TorchScriptTensor, tuple)
), type(output)
fx_name_to_onnxscript_value[node.name] = output
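
    # Illustrative sketch (not from the original file): getitem over a
    # multi-output node emits no ONNX op; it simply indexes the cached tuple:
    #
    #   split = torch.ops.aten.split.Tensor(x, 2)  # maps to a tuple of TorchScriptTensors
    #   y = split[0]                               # call_function(operator.getitem, (split, 0))
    #   -> fx_name_to_onnxscript_value[y.name] = fx_name_to_onnxscript_value[split.name][0]
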
def output(
self,
node: torch.fx.Node,
onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
fx_name_to_onnxscript_value: dict[
str,
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...],
],
):
if isinstance(node.args[0], torch.fx.Node):
onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name]
onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)
else:
            # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of
            # tensor, etc.), so we flatten the collection and register each element as an output.
flat_args, _ = _pytree.tree_flatten(node.args[0])
for arg in flat_args:
assert isinstance(arg, torch.fx.Node), (
f"arg must be a torch.fx.Node, not {type(arg)}"
)
onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[arg.name]
onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)
def call_method(self, node: torch.fx.Node):
# TODO(wechi): Support call_method.
raise RuntimeError("call_method is not supported yet.")
def call_module(
self,
node: torch.fx.Node,
parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
fx_name_to_onnxscript_value: dict[
str,
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...],
],
tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
root_fx_graph_module: torch.fx.GraphModule,
onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
) -> None:
"""Export a fx.GraphModule submodule to ONNXScript graph.
The export process specifically targets `call_module` nodes that are created by
the exporter's `Modularize` pass. Each `call_module` node has an associated fx.GraphModule
by `node.target` underneath the root fx.GraphModule. These `call_module` nodes are exported as ONNX
function nodes. The related `sub_module` is then exported as an ONNX model local function,
which is represented by another `TorchScriptGraph`. This `TorchScriptGraph` sets the current
`onnxscript_graph` as its parent.
Args:
node: The call_module node in the FX graph that represents the submodule call.
parent_onnxscript_graph: The parent ONNXScript graph to which the ONNX function and
function node belong.
fx_name_to_onnxscript_value: The mapping from FX node name to ONNXScript value.
tracer: The tracer used to trace the ONNXScript graph.
root_fx_graph_module: The root FX module.
onnxfunction_dispatcher: The dispatcher.
"""
assert isinstance(node.target, str), (
f"node.target must be a str, not {type(node.target)} for node {node}."
)
sub_module = root_fx_graph_module.get_submodule(node.target)
assert isinstance(sub_module, torch.fx.GraphModule), (
f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}."
)
sub_onnxscript_graph = self.run(
sub_module, onnxfunction_dispatcher, parent_onnxscript_graph
)
onnx_args, _ = _wrap_fx_args_as_onnxscript_args(
list(node.args), {}, fx_name_to_onnxscript_value, tracer
)
        # TODO: We may want to consider other naming styles. The goal is to be stable
        # and unique such that it can be easily identified in case of kernel substitution.
        # An example of the current style is the combination of the qualified module class
        # name and the module attribute name: `torch_nn_modules_conv_Conv2d_conv1`.
        # Other naming styles, such as a qualified module class name made unique, could
        # also be considered.
unique_module_name = f"{sub_module._get_name()}_{node.target}"
outputs: (
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...]
) = parent_onnxscript_graph.add_module_call( # type: ignore[assignment]
unique_module_name, sub_onnxscript_graph, onnx_args
)
assert isinstance(
outputs, (onnxscript_graph_building.TorchScriptTensor, tuple)
), f"Unexpected outputs type {type(outputs)} for node {node}."
_fill_tensor_shape_type(outputs, node.name, node.meta["val"])
fx_name_to_onnxscript_value[node.name] = outputs
# Skip op_level_validation for call_module. Subgraph nodes are validated individually.
def get_attr(
self,
node: torch.fx.Node,
onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
fx_name_to_onnxscript_value: dict[
str,
onnxscript_graph_building.TorchScriptTensor
| tuple[onnxscript_graph_building.TorchScriptTensor, ...],
],
fx_graph_module: torch.fx.GraphModule,
):
assert isinstance(node.target, str), f"node.target {node.target} is not a str."
attr_tensor = getattr(fx_graph_module, node.target)
assert isinstance(attr_tensor, torch.Tensor), f"{attr_tensor} is not a tensor."
        # Parameter/buffer names cannot contain ".";
        # revert "/" back to "." to restore the namespace formatting.
input_ = onnxscript_graph.add_initializer(
name=node.target.replace("/", "."),
value=attr_tensor,
)
assert isinstance(input_, onnxscript_graph_building.TorchScriptTensor)
assert isinstance(input_, onnxscript.tensor.Tensor)
fx_name_to_onnxscript_value[node.name] = input_
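
# Minimal usage sketch (assumptions: `fx_graph_module` and `dispatcher` come from
# the wider exporter pipeline, and the `Modularize` pass has run if submodules
# are present):
#
#   interpreter = FxOnnxInterpreter()
#   onnxscript_graph = interpreter.run(fx_graph_module, dispatcher)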