# Owner(s): ["module: onnx"] from __future__ import annotations import contextlib import copy import dataclasses import io import os import unittest import warnings from typing import ( Any, Callable, Collection, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union, ) import numpy as np import onnxruntime import pytest import pytorch_test_common import torch from torch.onnx import _constants, verification from torch.onnx._internal import _beartype from torch.testing._internal.opinfo import core as opinfo_core from torch.types import Number _NumericType = Union[Number, torch.Tensor, np.ndarray] _ModelType = Union[torch.nn.Module, Callable] _InputArgsType = Optional[ Union[torch.Tensor, int, float, bool, Sequence[Any], Mapping[str, Any]] ] _OutputsType = Sequence[_NumericType] onnx_model_dir = os.path.join( os.path.dirname(os.path.realpath(__file__)), os.pardir, "repos", "onnx", "onnx", "backend", "test", "data", ) pytorch_converted_dir = os.path.join(onnx_model_dir, "pytorch-converted") pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator") def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs): options = verification.VerificationOptions() kwargs["opset_version"] = test_suite.opset_version kwargs["keep_initializers_as_inputs"] = test_suite.keep_initializers_as_inputs if hasattr(test_suite, "check_shape"): options.check_shape = test_suite.check_shape if hasattr(test_suite, "check_dtype"): options.check_dtype = test_suite.check_dtype names = {f.name for f in dataclasses.fields(options)} keywords_to_pop = [] for k, v in kwargs.items(): if k in names: setattr(options, k, v) keywords_to_pop.append(k) for k in keywords_to_pop: kwargs.pop(k) return verification.verify(*args, options=options, **kwargs) def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]): """Combine class name with the parameterized arguments. This function is passed to `parameterized.parameterized_class` as the `class_name_func` argument. """ suffix = "_".join(f"{k}_{v}" for k, v in input_dicts.items()) return f"{cls.__name__}_{suffix}" class _TestONNXRuntime(pytorch_test_common.ExportTestCase): opset_version = _constants.ONNX_DEFAULT_OPSET keep_initializers_as_inputs = True # For IR version 3 type export. is_script = False check_shape = True check_dtype = True def setUp(self): super().setUp() onnxruntime.set_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0" self.is_script_test_enabled = True # The exported ONNX model may have less inputs than the pytorch model because of const folding. # This mostly happens in unit test, where we widely use torch.size or torch.shape. # So the output is only dependent on the input shape, not value. # remained_onnx_input_idx is used to indicate which pytorch model input idx is remained in ONNX model. 
    # The exported ONNX model may have fewer inputs than the PyTorch model because of
    # constant folding. This mostly happens in unit tests, where we widely use
    # torch.size or torch.shape, so the output depends only on the input shape, not
    # its value. remained_onnx_input_idx indicates which PyTorch model input indices
    # remain in the ONNX model.
    def run_test(
        self,
        model,
        input_args,
        input_kwargs=None,
        rtol=1e-3,
        atol=1e-7,
        do_constant_folding=True,
        dynamic_axes=None,
        additional_test_inputs=None,
        input_names=None,
        output_names=None,
        fixed_batch_size=False,
        training=torch.onnx.TrainingMode.EVAL,
        remained_onnx_input_idx=None,
        verbose=False,
    ):
        def _run_test(m, remained_onnx_input_idx, flatten=True, ignore_none=True):
            return run_model_test(
                self,
                m,
                input_args=input_args,
                input_kwargs=input_kwargs,
                rtol=rtol,
                atol=atol,
                do_constant_folding=do_constant_folding,
                dynamic_axes=dynamic_axes,
                additional_test_inputs=additional_test_inputs,
                input_names=input_names,
                output_names=output_names,
                fixed_batch_size=fixed_batch_size,
                training=training,
                remained_onnx_input_idx=remained_onnx_input_idx,
                flatten=flatten,
                ignore_none=ignore_none,
                verbose=verbose,
            )

        if isinstance(remained_onnx_input_idx, dict):
            scripting_remained_onnx_input_idx = remained_onnx_input_idx["scripting"]
            tracing_remained_onnx_input_idx = remained_onnx_input_idx["tracing"]
        else:
            scripting_remained_onnx_input_idx = remained_onnx_input_idx
            tracing_remained_onnx_input_idx = remained_onnx_input_idx

        is_model_script = isinstance(
            model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)
        )

        if self.is_script_test_enabled and self.is_script:
            script_model = model if is_model_script else torch.jit.script(model)
            _run_test(
                script_model,
                scripting_remained_onnx_input_idx,
                flatten=False,
                ignore_none=False,
            )
        if not is_model_script and not self.is_script:
            _run_test(model, tracing_remained_onnx_input_idx)

    @_beartype.beartype
    def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
        self,
        model: _ModelType,
        input_args: Sequence[_InputArgsType],
        input_kwargs: Optional[Mapping[str, _InputArgsType]] = None,
        rtol: Optional[float] = 1e-3,
        atol: Optional[float] = 1e-7,
        opset_version: int = 18,
        has_mutation: bool = False,
        verbose: bool = False,
        additional_test_inputs: Optional[
            List[
                Union[
                    Tuple[Sequence[_InputArgsType], Mapping[str, _InputArgsType]],
                    Tuple[Sequence[_InputArgsType]],
                ]
            ]
        ] = None,
    ):
        """Compare the results of a PyTorch model with its exported ONNX model.

        Args:
            model (_ModelType): PyTorch model.
            input_args (Sequence[_InputArgsType]): torch input arguments.
            input_kwargs (Mapping[str, _InputArgsType]): torch input kwargs.
            rtol (float, optional): relative tolerance. Defaults to 1e-3.
            atol (float, optional): absolute tolerance. Defaults to 1e-7.
            opset_version (int, optional): ONNX opset version. Defaults to 18.
            has_mutation (bool, optional): Whether the model mutates its input or
                state. Setting it to True incurs the extra overhead of cloning the
                inputs and the model. Defaults to False.
            verbose (bool, optional): Whether to save diagnostics as a Sarif log and
                print verbose information. Defaults to False.
            additional_test_inputs: Additional input sets used to test the model with
                dynamic axes. Defaults to None. It is a list of input sets, each a
                tuple whose first element is a tuple of args and whose optional second
                element is a dict of kwargs. Remember to keep the trailing comma even
                when the second element is not provided. For example,
                additional_test_inputs = [((args1, args2), {"kwargs": 1}), ((args1,),), ((), {"kwargs": 1})]
        """
        # Avoid mutable default data structures.
        if input_kwargs is None:
            input_kwargs = {}

        if has_mutation:
            ref_model = _try_clone_model(model)
            ref_input_args, ref_input_kwargs = _try_clone_inputs(
                input_args, input_kwargs
            )
        else:
            ref_model = model
            ref_input_args = input_args
            ref_input_kwargs = input_kwargs
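        # NOTE: an illustrative, hypothetical example of why cloning matters (the
        # forward below is not part of this module): a model that mutates state or
        # its inputs, e.g.
        #
        #     def forward(self, x):
        #         x.add_(1)  # in-place input mutation
        #         return x
        #
        # would make the eager reference run and the later ORT comparison observe
        # different values, so `has_mutation=True` runs the reference on clones.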
        # Feed args and kwargs into the exporter.
        # Note that the exporter should flatten kwargs into positional args in the
        # exported model, since ONNX doesn't represent kwargs.
        export_output = torch.onnx.dynamo_export(
            ref_model,
            *ref_input_args,
            **ref_input_kwargs,
            export_options=torch.onnx.ExportOptions(
                opset_version=opset_version,
                op_level_debug=self.op_level_debug,
                dynamic_shapes=self.dynamic_shapes,
            ),
        )

        if verbose:
            export_output.diagnostic_context.dump(
                f"test_report_{self._testMethodName}"
                f"_op_level_debug_{self.op_level_debug}"
                f"_dynamic_axes_{self.dynamic_shapes}"
                ".sarif",
                compress=False,
            )

        _compare_pytorch_onnx_with_ort(
            export_output,
            model,
            input_args,
            input_kwargs,
            atol,
            rtol,
            has_mutation=has_mutation,
        )

        # This confirms the exported model accepts different input shapes
        # when dynamic shape is enabled.
        if additional_test_inputs and self.dynamic_shapes:
            for another_input in additional_test_inputs:
                if len(another_input) > 2:
                    raise ValueError(
                        f"test_inputs should only have tuple args and dictionary "
                        f"kwargs. But received: {len(another_input)}"
                    )
                additional_input_args = another_input[0]
                additional_input_kwargs = (
                    another_input[1]
                    if len(another_input) == 2 and another_input[1] is not None
                    else {}
                )
                _compare_pytorch_onnx_with_ort(
                    export_output,
                    model,
                    additional_input_args,
                    additional_input_kwargs,
                    atol,
                    rtol,
                    has_mutation=has_mutation,
                )


@_beartype.beartype
def run_ort(
    onnx_model: Union[str, torch.onnx.ExportOutput],
    pytorch_inputs: Sequence[_InputArgsType],
) -> _OutputsType:
    """Run ORT on the given ONNX model and inputs.

    Used in test_fx_to_onnx_with_onnxruntime.py

    Args:
        onnx_model (Union[str, torch.onnx.ExportOutput]): Converted ONNX model.
        pytorch_inputs (Sequence[_InputArgsType]): The given torch inputs.

    Raises:
        AssertionError: ONNX and PyTorch should have the same input sizes.

    Returns:
        _OutputsType: ONNX model predictions.
    """
    if isinstance(onnx_model, torch.onnx.ExportOutput):
        buffer = io.BytesIO()
        onnx_model.save(buffer)
        ort_model = buffer.getvalue()
    else:
        ort_model = onnx_model
    session = onnxruntime.InferenceSession(
        ort_model, providers=["CPUExecutionProvider"]
    )
    input_names = [ort_input.name for ort_input in session.get_inputs()]
    if len(input_names) != len(pytorch_inputs):
        raise AssertionError(
            f"Expected {len(input_names)} inputs, got {len(pytorch_inputs)}"
        )
    return session.run(
        None, {k: v.cpu().numpy() for k, v in zip(input_names, pytorch_inputs)}
    )


@_beartype.beartype
def _try_clone_model(model: _ModelType) -> _ModelType:
    """Used for preserving original model in case forward mutates model states."""
    try:
        return copy.deepcopy(model)
    except Exception:
        warnings.warn(
            "Failed to clone model. Model state might be mutated during verification."
        )
        return model


@_beartype.beartype
def _try_clone_inputs(input_args, input_kwargs):
    ref_input_args = copy.deepcopy(input_args)
    ref_input_kwargs = copy.deepcopy(input_kwargs)
    return ref_input_args, ref_input_kwargs


@_beartype.beartype
def _compare_pytorch_onnx_with_ort(
    export_output: torch.onnx.ExportOutput,
    model: _ModelType,
    input_args: Sequence[_InputArgsType],
    input_kwargs: Mapping[str, _InputArgsType],
    atol: Optional[float] = None,
    rtol: Optional[float] = None,
    has_mutation: bool = False,
):
    if has_mutation:
        ref_model = _try_clone_model(model)
        ref_input_args, ref_input_kwargs = _try_clone_inputs(input_args, input_kwargs)
    else:
        ref_model = model
        ref_input_args = input_args
        ref_input_kwargs = input_kwargs

    # Format original model inputs into the format expected by the exported ONNX model.
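    # Illustrative (hypothetical) example of this adaptation: for a forward such as
    # `forward(self, x, *, scale=None)` called with args `(x,)` and kwargs
    # `{"scale": s}`, the adapter is expected to produce the flat positional sequence
    # `(x, s)` so that it lines up with the exported ONNX graph inputs, since ONNX has
    # no notion of kwargs.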
    onnx_format_args = export_output.adapt_torch_inputs_to_onnx(
        *input_args, **input_kwargs
    )
    ref_outputs = export_output.adapt_torch_outputs_to_onnx(
        ref_model(*ref_input_args, **ref_input_kwargs)
    )
    ort_outputs = run_ort(export_output, onnx_format_args)

    if len(ref_outputs) != len(ort_outputs):
        raise AssertionError(
            f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}"
        )
    for ref_output, ort_output in zip(ref_outputs, ort_outputs):
        torch.testing.assert_close(
            ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
        )


# The min onnx opset version to test for
MIN_ONNX_OPSET_VERSION = 9
# The max onnx opset version to test for
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)

# TODO(titaiwang): Change this when more versions are supported
# The min onnx opset version to test for
FX_MIN_ONNX_OPSET_VERSION = 18
# The max onnx opset version to test for
FX_MAX_ONNX_OPSET_VERSION = 18
FX_TESTED_OPSETS = range(FX_MIN_ONNX_OPSET_VERSION, FX_MAX_ONNX_OPSET_VERSION + 1)

BOOL_TYPES = (torch.bool,)

INT_TYPES = (
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
    torch.uint8,
)

QINT_TYPES = (
    torch.qint8,
    torch.quint8,
)

FLOAT_TYPES = (
    torch.float16,
    torch.float32,
    torch.float64,
)

COMPLEX_TYPES = (
    torch.complex32,
    torch.complex64,
    torch.complex128,
)

TESTED_DTYPES = (
    # Boolean
    torch.bool,
    # Integers
    *INT_TYPES,
    # Floating types
    *FLOAT_TYPES,
)


@dataclasses.dataclass
class DecorateMeta:
    """Information about a test case to skip or xfail.

    Adapted from functorch: functorch/test/common_utils.py

    Attributes:
        op_name: The name of the operator.
        variant_name: The name of the OpInfo variant.
        decorator: The decorator to apply to the test case.
        opsets: The opsets to apply the decorator to.
        dtypes: The dtypes to apply the decorator to.
        reason: The reason for skipping.
        test_behavior: The behavior of the test case. [skip or xfail]
        matcher: The matcher to apply to the test case.
        enabled_if: Whether to enable test behavior. Usually used on onnx/ort version control.
    """

    op_name: str
    variant_name: str
    decorator: Callable
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
    dtypes: Optional[Collection[torch.dtype]]
    reason: str
    test_behavior: str
    matcher: Optional[Callable[[Any], bool]] = None
    enabled_if: bool = True

    def contains_opset(self, opset: int) -> bool:
        if self.opsets is None:
            return True
        return any(
            opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
            for opset_spec in self.opsets
        )


def xfail(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], bool]] = None,
    enabled_if: bool = True,
):
    """Expects an OpInfo test to fail.

    Args:
        op_name: The name of the operator.
        variant_name: The name of the variant.
        opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
        dtypes: The dtypes to expect the failure.
        reason: The reason for the failure.
        matcher: A function that matches the test sample input. It is used only when
            the xfail is in the SKIP_XFAIL_SUBTESTS list.
        enabled_if: Whether to enable xfail.
            Usually used on onnx/ort version control.
    """
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=unittest.expectedFailure,
        opsets=opsets,
        dtypes=dtypes,
        enabled_if=enabled_if,
        matcher=matcher,
        reason=reason,
        test_behavior="xfail",
    )


def skip(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], Any]] = None,
    enabled_if: bool = True,
):
    """Skips a test case in OpInfo that we don't care about.

    Likely because ONNX does not support the use case or it is by design.

    Args:
        op_name: The name of the operator.
        variant_name: The name of the variant.
        opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
        dtypes: The dtypes to expect the failure.
        reason: The reason for the failure.
        matcher: A function that matches the test sample input. It is used only when
            the skip is in the SKIP_XFAIL_SUBTESTS list.
        enabled_if: Whether to enable skip. Usually used on onnx/ort version control.
    """
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=unittest.skip(f"Skip: {reason}"),
        opsets=opsets,
        dtypes=dtypes,
        reason=reason,
        matcher=matcher,
        enabled_if=enabled_if,
        test_behavior="skip",
    )


def add_decorate_info(
    all_opinfos: Sequence[opinfo_core.OpInfo],
    test_class_name: str,
    base_test_name: str,
    opset: int,
    skip_or_xfails: Iterable[DecorateMeta],
):
    """Decorates OpInfo tests with decorators based on the skip_or_xfails list.

    Args:
        all_opinfos: All OpInfos.
        test_class_name: The name of the test class.
        base_test_name: The name of the test method.
        opset: The opset to decorate for.
        skip_or_xfails: DecorateMeta's.
    """
    ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
    for decorate_meta in skip_or_xfails:
        if not decorate_meta.contains_opset(opset):
            # Skip does not apply to this opset.
            continue
        opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
        assert (
            opinfo is not None
        ), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
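        # Wrap the skip/xfail decorator in a DecorateInfo so that it is applied only
        # to the matching test class/method (and, optionally, dtypes), which is how
        # OpInfo-based test suites scope their decorators.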
        decorators = list(opinfo.decorators)
        new_decorator = opinfo_core.DecorateInfo(
            decorate_meta.decorator,
            test_class_name,
            base_test_name,
            dtypes=decorate_meta.dtypes,
            active_if=decorate_meta.enabled_if,
        )
        decorators.append(new_decorator)
        opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way.
    def wrapped(fn):
        return fn

    return wrapped


def opsets_before(opset: int) -> Callable[[int], bool]:
    """Returns a comparison function that decides if the given opset is before the specified."""

    def compare(other_opset: int):
        return other_opset < opset

    return compare


def opsets_after(opset: int) -> Callable[[int], bool]:
    """Returns a comparison function that decides if the given opset is after the specified."""

    def compare(other_opset: int):
        return other_opset > opset

    return compare


def reason_onnx_runtime_does_not_support(
    operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
    """Formats the reason: ONNX Runtime doesn't support the given dtypes."""
    return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"


def reason_onnx_does_not_support(
    operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
    """Formats the reason: ONNX doesn't support the given dtypes."""
    return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"


def reason_jit_tracer_error(info: str) -> str:
    """Formats the reason: JIT tracer errors."""
    return f"JIT tracer error on {info}"


def reason_flaky() -> str:
    """Formats the reason: test is flaky."""
    return "flaky test"


@contextlib.contextmanager
def normal_xfail_skip_test_behaviors(
    test_behavior: Optional[str] = None, reason: Optional[str] = None
):
    """This context manager is used to handle the different behaviors of xfail and skip.

    Args:
        test_behavior (optional[str]): From DecorateMeta name, can be 'skip', 'xfail', or None.
        reason (optional[str]): The reason for the failure or skip.

    Raises:
        e: Any exception raised by the test case if it's not an expected failure.
    """
    # We need to skip as soon as possible, as SegFault might also be a case.
    if test_behavior == "skip":
        pytest.skip(reason=reason)

    try:
        yield
    # We could use `except (AssertionError, RuntimeError, ...) as e:`, but it needs
    # to go over all test cases to find the right exception type.
    except Exception as e:  # pylint: disable=broad-exception-caught
        if test_behavior is None:
            raise e
        if test_behavior == "xfail":
            pytest.xfail(reason=reason)
    else:
        if test_behavior == "xfail":
            pytest.fail("Test unexpectedly passed")
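
# A minimal, hypothetical sketch of how the OpInfo helpers above are typically wired
# together in an OpInfo-driven test file (`TestConsistency`, `test_output_match`,
# `EXPECTED_SKIPS_OR_FAILS`, and the chosen ops are illustrative names, not part of
# this module; `op_db` would come from torch.testing._internal.common_methods_invocations):
#
#     EXPECTED_SKIPS_OR_FAILS = (
#         xfail(
#             "ceil",
#             dtypes=[torch.float64],
#             reason=reason_onnx_runtime_does_not_support("Ceil", ["f64"]),
#         ),
#         skip("cos", opsets=[opsets_before(9)], reason=reason_flaky()),
#     )
#
#     class TestConsistency(_TestONNXRuntime):
#         @add_decorate_info(
#             op_db,
#             "TestConsistency",
#             "test_output_match",
#             opset=MAX_ONNX_OPSET_VERSION,
#             skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
#         )
#         def test_output_match(self, device, dtype, op):
#             ...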