"""Functions to verify exported ONNX model is functionally equivalent to original PyTorch model.
|
|
|
|
ONNX Runtime is required, and is used as the ONNX backend for export verification.
|
|
"""

import contextlib
import copy
import io
import os
import tempfile
import warnings
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import numpy as np

import torch
from torch import Tensor
from torch.onnx.utils import unpack_quantized_tensor

_ORT_PROVIDERS = ("CPUExecutionProvider",)


def _flatten_tuples(elem):
    flattened = []
    for t in elem:
        if isinstance(t, tuple):
            flattened.extend(_flatten_tuples(t))
        else:
            flattened.append(t)
    return flattened
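
# A minimal sketch of the flattening behavior (values are hypothetical):
#
#     >>> _flatten_tuples([1, (2, (3, 4)), 5])
#     [1, 2, 3, 4, 5]
#
# Note that only tuples are recursed into; lists are appended unchanged.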


def _to_numpy(elem):
    if isinstance(elem, Tensor):
        if elem.requires_grad:
            return elem.detach().cpu().numpy()
        else:
            return elem.cpu().numpy()
    elif isinstance(elem, (list, tuple)):
        return [_to_numpy(inp) for inp in elem]
    elif isinstance(elem, (bool, int, float)):
        return np.array(elem)
    elif isinstance(elem, dict):
        flattened = []
        for k in elem:
            flattened += [_to_numpy(k)] + [_to_numpy(elem[k])]
        return flattened
    return elem
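
# Conversion rules at a glance (hypothetical values):
#
#     >>> _to_numpy(torch.tensor([1.0, 2.0]))  # Tensor -> np.ndarray
#     array([1., 2.], dtype=float32)
#     >>> _to_numpy(True)  # Python scalar -> 0-d np.ndarray
#     array(True)
#     >>> _to_numpy({"a": 1})  # dict -> flattened [key, value, ...] list
#     ['a', array(1)]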


def _inline_flatten_list(inputs, res_list):
    for i in inputs:
        if isinstance(i, (list, tuple)):
            _inline_flatten_list(i, res_list)
        else:
            res_list.append(i)
    return res_list


def _unpack_to_numpy(values):
    value_unpacked = []
    for value in values:
        value_unpacked.extend(unpack_quantized_tensor(value))
    return [_to_numpy(v) for v in value_unpacked]


def _run_ort(ort_session, inputs):
    kw_inputs = {}
    if inputs and isinstance(inputs[-1], dict):
        kw_inputs = inputs[-1]
        inputs = inputs[:-1]
    inputs = _unpack_to_numpy(_flatten_tuples(inputs))
    ort_inputs = {}
    for input_name, input_value in kw_inputs.items():
        ort_inputs[input_name] = _to_numpy(input_value)
    inputs = _to_numpy(inputs)
    ort_session_inputs = ort_session.get_inputs()
    for i, input_value in enumerate(inputs):
        if i == len(ort_session_inputs) or ort_session_inputs[i].name in ort_inputs:
            raise ValueError(
                f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}"
            )
        ort_inputs[ort_session_inputs[i].name] = input_value
    ort_outs = ort_session.run(None, ort_inputs)
    return _inline_flatten_list(ort_outs, [])


def _ort_session(
    model: Union[str, io.BytesIO], ort_providers: Sequence[str] = _ORT_PROVIDERS
):
    try:
        import onnxruntime  # type: ignore[import]
    except ImportError as e:
        raise ImportError("onnxruntime is required for export verification.") from e

    if ort_providers is None:
        ort_providers = _ORT_PROVIDERS

    session_options = onnxruntime.SessionOptions()
    # Suppress ORT warnings.
    # 0: Verbose, 1: Info, 2: Warning, 3: Error, 4: Fatal. Default is 2.
    session_options.log_severity_level = 3
    ort_session = onnxruntime.InferenceSession(
        model if isinstance(model, str) else model.getvalue(),
        session_options,
        providers=ort_providers,
    )
    return ort_session
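
# Usage sketch (requires onnxruntime; the exported model here is hypothetical):
#
#     model_f = io.BytesIO()
#     torch.onnx.export(torch.nn.Linear(3, 4), (torch.randn(2, 3),), model_f)
#     session = _ort_session(model_f)
#     print([inp.name for inp in session.get_inputs()])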


def _compare_ort_pytorch_outputs(ort_outs, pt_outs, rtol, atol):
    pt_outs, _ = torch.jit._flatten(pt_outs)
    pt_outs = _unpack_to_numpy(pt_outs)

    assert len(pt_outs) == len(ort_outs), "number of outputs differ"

    for ort_out, pt_out in zip(ort_outs, pt_outs):
        np.testing.assert_allclose(ort_out, pt_out, rtol=rtol, atol=atol)


def _prepare_input_for_pytorch(args, kwargs):
    """Prepare input for PyTorch model execution.

    Any future changes/formatting to the input before dispatching to the PyTorch
    model should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.

    Returns:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.
    """
    if isinstance(args, (Tensor, dict)):
        args = (args,)
    # In-place operators will update input tensor data as well.
    # Thus inputs are replicated before every forward call.
    args = copy.deepcopy(args)
    if kwargs:
        kwargs = copy.deepcopy(kwargs)
    else:
        kwargs = {}
    return args, kwargs
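
# Why inputs are deep-copied (a small sketch, hypothetical values): a forward
# method that mutates its input in place would otherwise also alter the
# tensors later fed to ONNX Runtime, skewing the comparison.
#
#     >>> x = torch.ones(2)
#     >>> args, kwargs = _prepare_input_for_pytorch(x, None)
#     >>> _ = args[0].add_(1)  # mutates only the copy
#     >>> x
#     tensor([1., 1.])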


def _prepare_input_for_export(args, kwargs):
    """Prepare input for ONNX model export.

    Any future changes/formatting to the input before dispatching to the
    :func:`torch.onnx.export` api should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.

    Returns:
        onnx_inputs: positional arguments for ONNX model export, as `args` in
            :func:`torch.onnx.export`.
    """
    args, kwargs = _prepare_input_for_pytorch(args, kwargs)
    if not kwargs and args and isinstance(args[-1], dict):
        onnx_inputs = args + ({},)
    elif kwargs:
        onnx_inputs = args + (kwargs,)
    else:
        onnx_inputs = args
    return onnx_inputs
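
# The trailing-dict convention at a glance (hypothetical inputs): when the
# last positional arg is a dict and no kwargs are given, an empty dict is
# appended so :func:`torch.onnx.export` does not interpret it as named
# arguments.
#
#     >>> t = torch.ones(1)
#     >>> _prepare_input_for_export((t, {"k": t}), {})
#     (tensor([1.]), {'k': tensor([1.])}, {})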


def _prepare_input_for_ort(args, kwargs, remained_onnx_input_idx, flatten):
    """Prepare input for ONNX model execution in ONNX Runtime.

    Any future changes/formatting to the input before dispatching to the ONNX Runtime
    InferenceSession run should be made in this function.

    Args:
        args: positional arguments for PyTorch model forward method.
        kwargs: keyword arguments for PyTorch model forward method.
        remained_onnx_input_idx: if provided, indices of the inputs to keep for
            the ONNX model, for models whose unused inputs were removed during
            export.
        flatten: whether to flatten nested list/tuple/dict inputs into a flat
            list of tensors.

    Returns:
        onnx_inputs: positional arguments for ONNX model execution in ONNX Runtime.
    """
    onnx_inputs = _prepare_input_for_export(args, kwargs)
    if flatten:
        onnx_inputs, _ = torch.jit._flatten(onnx_inputs)
    elif onnx_inputs and onnx_inputs[-1] == {}:
        # Handle empty kwargs (normally removed by flatten).
        onnx_inputs = onnx_inputs[:-1]
    if remained_onnx_input_idx is not None:
        return [onnx_inputs[i] for i in remained_onnx_input_idx]
    else:
        return onnx_inputs
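
# Flattening sketch (hypothetical inputs): with `flatten=True`, nested
# structures become a flat tuple of tensors matching the ONNX graph's
# positional inputs.
#
#     >>> t = torch.ones(1)
#     >>> _prepare_input_for_ort(((t, (t,)),), {}, None, flatten=True)
#     (tensor([1.]), tensor([1.]))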


def _try_clone_model(model):
    """Used for preserving original model in case forward mutates model states."""
    try:
        return copy.deepcopy(model)
    except Exception:
        warnings.warn(
            "Failed to clone model. Model state might be mutated during verification."
        )
        return model


def _compare_ort_pytorch_model(
    model,
    ort_session,
    input_args,
    input_kwargs,
    additional_test_inputs,
    remained_onnx_input_idx,
    flatten,
    rtol,
    atol,
):
    """Compare outputs from ONNX model runs with outputs from PyTorch model runs.

    ONNX Runtime is used as the execution backend for the ONNX model.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
    """

    def compare_ort_pytorch_model_with_input(input_args, input_kwargs):
        pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs)
        # TODO: remove this and treat mutating model separately. See #77679
        model_copy = _try_clone_model(model)
        pt_outs = model_copy(*pt_args, **pt_kwargs)

        ort_inputs = _prepare_input_for_ort(
            input_args, input_kwargs, remained_onnx_input_idx, flatten
        )
        ort_outs = _run_ort(ort_session, ort_inputs)

        _compare_ort_pytorch_outputs(ort_outs, pt_outs, rtol, atol)

    compare_ort_pytorch_model_with_input(input_args, input_kwargs)

    if additional_test_inputs:
        for test_input_args in additional_test_inputs:
            compare_ort_pytorch_model_with_input(test_input_args, {})


def verify(
    model: Union[torch.nn.Module, torch.jit.ScriptModule],
    input_args: Tuple[Any, ...],
    input_kwargs: Optional[Mapping[str, Any]] = None,
    do_constant_folding: bool = True,
    dynamic_axes: Optional[
        Mapping[str, Union[Mapping[int, str], Mapping[str, Sequence[int]]]]
    ] = None,
    input_names: Optional[Sequence[str]] = None,
    output_names: Optional[Sequence[str]] = None,
    training: Optional[torch.onnx.TrainingMode] = None,
    opset_version: Optional[int] = None,
    keep_initializers_as_inputs: bool = True,
    verbose: bool = False,
    fixed_batch_size: bool = False,
    use_external_data: bool = False,
    additional_test_inputs: Optional[Sequence[Tuple[Any, ...]]] = None,
    remained_onnx_input_idx: Optional[Sequence[int]] = None,
    flatten: bool = True,
    ort_providers: Sequence[str] = _ORT_PROVIDERS,
    rtol: float = 0.001,
    atol: float = 1e-7,
    **kwargs,
):
    """Verify model export to ONNX, with ONNX Runtime as the execution backend.

    Args:
        model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`.
        input_args (tuple): See :func:`torch.onnx.export`.
        input_kwargs (dict): See :func:`torch.onnx.export`.
        do_constant_folding (bool, optional): See :func:`torch.onnx.export`.
        dynamic_axes (dict, optional): See :func:`torch.onnx.export`.
        input_names (list, optional): See :func:`torch.onnx.export`.
        output_names (list, optional): See :func:`torch.onnx.export`.
        training (torch.onnx.TrainingMode, optional): See :func:`torch.onnx.export`.
        opset_version (int, optional): See :func:`torch.onnx.export`.
        keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`.
        verbose (bool, optional): See :func:`torch.onnx.export`.
        fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases.
        use_external_data (bool, optional): Explicitly specify whether to export the
            model with external data.
        additional_test_inputs (list, optional): List of tuples. Each tuple is a set of
            input arguments to test. Currently only *args are supported.
        remained_onnx_input_idx (list, optional): If provided, only the specified inputs
            will be passed to the ONNX model. Supply a list when there are unused inputs
            in the model. Since unused inputs will be removed in the exported ONNX
            model, supplying all inputs will cause an error on unexpected inputs.
            This parameter tells the verifier which inputs to pass into the ONNX model.
        flatten (bool, optional): Default True. If True, unpack nested list/tuple/dict
            inputs into a flattened list of Tensors for ONNX. Set this to False if nested
            structures are to be preserved for ONNX, which is usually the case with
            exporting ScriptModules.
        ort_providers (sequence, optional): ONNX Runtime providers to use.
        rtol (float, optional): relative tolerance in comparison between ONNX and PyTorch outputs.
        atol (float, optional): absolute tolerance in comparison between ONNX and PyTorch outputs.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
    """
    if training is not None and training == torch.onnx.TrainingMode.TRAINING:
        model.train()
    elif training is None or training == torch.onnx.TrainingMode.EVAL:
        model.eval()
    with torch.no_grad(), contextlib.ExitStack() as stack:
        model_f: Union[str, io.BytesIO] = io.BytesIO()
        if use_external_data:
            tmpdirname = stack.enter_context(tempfile.TemporaryDirectory())
            model_f = os.path.join(tmpdirname, "model.onnx")

        inputs_for_export = _prepare_input_for_export(input_args, input_kwargs)

        # TODO: remove this and treat mutating model separately. See #77679
        model_copy = _try_clone_model(model)
        torch.onnx._export(
            model,
            inputs_for_export,
            model_f,
            opset_version=opset_version,
            do_constant_folding=do_constant_folding,
            keep_initializers_as_inputs=keep_initializers_as_inputs,
            dynamic_axes=dynamic_axes,
            input_names=input_names,
            output_names=output_names,
            fixed_batch_size=fixed_batch_size,
            training=training,
            verbose=verbose,
        )

        ort_session = _ort_session(model_f, ort_providers)

        _compare_ort_pytorch_model(
            model_copy,
            ort_session,
            input_args,
            input_kwargs,
            additional_test_inputs,
            remained_onnx_input_idx,
            flatten,
            rtol,
            atol,
        )
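
# End-to-end usage sketch (hypothetical model and tolerances; requires
# onnxruntime):
#
#     class MLP(torch.nn.Module):
#         def forward(self, x):
#             return torch.relu(x) + 1
#
#     verify(MLP(), (torch.randn(4, 8),), opset_version=13, rtol=1e-3)
#
# Raises AssertionError if ONNX Runtime outputs diverge from PyTorch outputs
# beyond `rtol`/`atol`.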