# Owner(s): ["module: onnx"] """Simple API tests for the ONNX exporter.""" from __future__ import annotations import io import logging import os from onnxscript import BOOL, FLOAT, opset18 as op import torch from torch.onnx._internal.exporter import _testing as onnx_testing from torch.testing._internal import common_utils class SampleModel(torch.nn.Module): def forward(self, x): y = x + 1 z = y.relu() return (y, z) class SampleModelTwoInputs(torch.nn.Module): def forward(self, x, b): y = x + b z = y.relu() return (y, z) class SampleModelReduction(torch.nn.Module): def forward(self, x): return x.sum() class SampleModelForDynamicShapes(torch.nn.Module): def forward(self, x, b): return x.relu(), b.sigmoid() class NestedModelForDynamicShapes(torch.nn.Module): def forward( self, x: torch.Tensor, ys: list[torch.Tensor], zs: dict[str, torch.Tensor], c: torch.Tensor, ): y = ys[0] + ys[1] + zs["a"] + zs["b"] w = 5 if x.shape[0] < 3 and c.shape[0] != 4: return x + w, x + y, c else: return x - w, x - y, c class SampleModelForDimOne(torch.nn.Module): def forward(self, x, y, z): return torch.cat((x, y), axis=1) + z class TestExportAPIDynamo(common_utils.TestCase): """Tests for the ONNX exporter API when dynamo=True.""" def assert_export( self, *args, strategy: str | None = "TorchExportNonStrictStrategy", **kwargs ): onnx_program = torch.onnx.export( *args, **kwargs, dynamo=True, fallback=False, verbose=False ) assert onnx_program is not None onnx_testing.assert_onnx_program(onnx_program, strategy=strategy) return onnx_program def test_args_normalization_with_no_kwargs(self): self.assert_export( SampleModelTwoInputs(), (torch.randn(1, 1, 2), torch.randn(1, 1, 2)), ) def test_lower_opset_support(self): # First test that opset 18 (torchlib opset works) onnx_program = self.assert_export( SampleModelReduction(), (torch.randn(1, 1, 2),), opset_version=18 ) self.assertEqual(onnx_program.model.opset_imports[""], 18) onnx_program = self.assert_export( SampleModelReduction(), (torch.randn(1, 1, 2),), opset_version=16 ) self.assertEqual(onnx_program.model.opset_imports[""], 16) def test_symbolic_argument_user_input_is_supported_by_report_and_call(self): class constant_plus_tensor_inputs(torch.nn.Module): def forward(self, a, x): return a + torch.tensor(1) + x # Capture log output log_capture = io.StringIO() log_handler = logging.StreamHandler(log_capture) log_handler.setLevel(logging.ERROR) # Get the logger used in _core.py logger = logging.getLogger("torch.onnx._internal.exporter._core") original_level = logger.level logger.addHandler(log_handler) logger.setLevel(logging.ERROR) try: with common_utils.TemporaryDirectoryName() as temp_dir: self.assert_export( constant_plus_tensor_inputs(), ( 1, torch.ones(2), ), dynamic_shapes=( torch.export.Dim.DYNAMIC, {0: torch.export.Dim.DYNAMIC}, ), report=True, artifacts_dir=temp_dir, ) # Check if the expected error was logged log_output = log_capture.getvalue() self.assertNotIn("Failed to save report due to an error", log_output) self.assertNotIn("KeyError: 'tensor_meta'", log_output) # Note: We don't call assert_onnx_program here because it will fail # due to the input name mismatch issue mentioned in your error finally: # Clean up logging logger.removeHandler(log_handler) logger.setLevel(original_level) def test_constant_argument_user_input_is_omitted_in_onnx_graph(self): class constant_plus_tensor_inputs(torch.nn.Module): def forward(self, a, x): return a + torch.tensor(1) + x onnx_program = torch.onnx.export( constant_plus_tensor_inputs(), ( 1, torch.ones(2), ), dynamic_shapes=( None, {0: torch.export.Dim.DYNAMIC}, ), dynamo=True, ) self.assertEqual(len(onnx_program.model.graph.inputs), 1) def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self): self.assert_export( SampleModelForDynamicShapes(), (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), dynamic_axes={ "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"}, "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"}, }, ) def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self): self.assert_export( SampleModelForDynamicShapes(), (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), dynamic_axes={ "x": [0, 1, 2], "b": [0, 1, 2], }, ) def test_dynamic_axes_supports_partial_dynamic_shapes(self): self.assert_export( SampleModelForDynamicShapes(), (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), input_names=["x", "b"], dynamic_axes={ "b": [0, 1, 2], }, ) def test_dynamic_axes_supports_output_names(self): self.assert_export( SampleModelForDynamicShapes(), (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), input_names=["x", "b"], dynamic_axes={ "b": [0, 1, 2], }, ) self.assert_export( SampleModelForDynamicShapes(), ( torch.randn(2, 2, 3), torch.randn(2, 2, 3), ), input_names=["x", "b"], output_names=["x_out", "b_out"], dynamic_axes={"b": [0, 1, 2], "b_out": [0, 1, 2]}, ) def test_from_dynamic_axes_to_dynamic_shapes_deprecation_warning(self): with self.assertWarnsRegex( DeprecationWarning, "from_dynamic_axes_to_dynamic_shapes is deprecated and will be removed in a future release. " "This function converts 'dynamic_axes' format \\(including custom axis names\\) to 'dynamic_shapes' format. " "Instead of relying on this conversion, provide 'dynamic_shapes' directly with custom names.", ): self.assert_export( SampleModelForDynamicShapes(), (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), dynamic_axes={ "x": [0, 1, 2], "b": [0, 1, 2], }, ) def test_from_dynamic_axes_to_dynamic_shapes_keeps_custom_axis_names(self): model = SampleModelForDynamicShapes() input = ( torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}, ) dynamic_axes = { "x": {0: "customx_x_0", 1: "customx_x_1", 2: "customx_x_2"}, "b": {0: "customb_b_0", 1: "customb_b_1", 2: "customb_b_2"}, "x_out": {0: "customx_out_x_0", 1: "customx_out_x_1", 2: "customx_out_x_2"}, "b_out": {0: "customb_out_b_0", 1: "customb_out_b_1", 2: "customb_out_b_2"}, } onnx_program = torch.onnx.export( model, input, dynamic_axes=dynamic_axes, input_names=["x", "b"], output_names=["x_out", "b_out"], dynamo=True, ) # Check whether the dynamic dimension names are preserved self.assertIs(onnx_program.model.graph.inputs[0].shape[0].value, "customx_x_0") self.assertIs(onnx_program.model.graph.inputs[0].shape[1].value, "customx_x_1") self.assertIs(onnx_program.model.graph.inputs[0].shape[2].value, "customx_x_2") self.assertIs(onnx_program.model.graph.inputs[1].shape[0].value, "customb_b_0") self.assertIs(onnx_program.model.graph.inputs[1].shape[1].value, "customb_b_1") self.assertIs(onnx_program.model.graph.inputs[1].shape[2].value, "customb_b_2") def test_saved_f_exists_after_export(self): with common_utils.TemporaryFileName(suffix=".onnx") as path: _ = torch.onnx.export( SampleModel(), (torch.randn(1, 1, 2),), path, dynamo=True ) self.assertTrue(os.path.exists(path)) def test_dynamic_shapes_with_fully_specified_axes(self): ep = torch.export.export( SampleModelForDynamicShapes(), ( torch.randn(2, 2, 3), torch.randn(2, 2, 3), ), dynamic_shapes={ "x": { 0: torch.export.Dim("customx_dim_0"), 1: torch.export.Dim("customx_dim_1"), 2: torch.export.Dim("customx_dim_2"), }, "b": { 0: torch.export.Dim("customb_dim_0"), 1: torch.export.Dim("customb_dim_1"), 2: torch.export.Dim("customb_dim_2"), }, }, strict=True, ) self.assert_export(ep, strategy=None) def test_partial_dynamic_shapes(self): self.assert_export( SampleModelForDynamicShapes(), ( torch.randn(2, 2, 3), torch.randn(2, 2, 3), ), dynamic_shapes={ "x": None, "b": { 0: torch.export.Dim("customb_dim_0"), 1: torch.export.Dim("customb_dim_1"), 2: torch.export.Dim("customb_dim_2"), }, }, ) def test_dynamic_shapes_supports_nested_input_model_with_input_names_assigned(self): # kwargs can still be renamed as long as it's in order input_names = ["input_x", "input_y", "input_z", "d", "e", "f"] dynamic_axes = { "input_x": {0: "dim"}, "input_y": {0: "dim"}, "input_z": {0: "dim"}, "d": {0: "dim"}, "e": {0: "dim"}, } model = NestedModelForDynamicShapes() input = ( torch.ones(5), [torch.zeros(5), torch.ones(5)], {"a": torch.zeros(5), "b": torch.ones(5)}, torch.ones(4), ) self.assert_export( model, input, dynamic_axes=dynamic_axes, input_names=input_names ) # Check whether inputs are dynamically shaped onnx_program = torch.onnx.export( model, input, dynamic_axes=dynamic_axes, input_names=input_names, dynamo=True, ) self.assertTrue( all( [ input.type.tensor_type.shape.dim[0].dim_param for input in onnx_program.model_proto.graph.input ][:-1] ) ) def test_upgraded_torchlib_impl(self): class GeluModel(torch.nn.Module): def forward(self, input): # Use GELU activation function return torch.nn.functional.gelu(input, approximate="tanh") input = (torch.randn(1, 3, 4, 4),) onnx_program_op18 = torch.onnx.export( GeluModel(), input, opset_version=18, dynamo=True, ) all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph] self.assertIn("Tanh", all_nodes_op18) self.assertNotIn("Gelu", all_nodes_op18) onnx_program_op20 = torch.onnx.export( GeluModel(), input, opset_version=20, dynamo=True, ) all_nodes_op20 = [n.op_type for n in onnx_program_op20.model.graph] self.assertIn("Gelu", all_nodes_op20) def test_refine_dynamic_shapes_with_onnx_export(self): # NOTE: From test/export/test_export.py # refine lower, upper bound class TestRefineDynamicShapeModel(torch.nn.Module): def forward(self, x, y): if x.shape[0] >= 6 and y.shape[0] <= 16: return x * 2.0, y + 1 inps = (torch.randn(16), torch.randn(12)) dynamic_shapes = { "x": (torch.export.Dim("dx"),), "y": (torch.export.Dim("dy"),), } self.assert_export( TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes ) def test_zero_output_aten_node(self): class Model(torch.nn.Module): def forward(self, x): torch.ops.aten._assert_async.msg(torch.tensor(True), "assertion failed") return x + x input = torch.randn(2) self.assert_export(Model(), (input)) def test_export_successful_when_dynamic_dimension_is_one(self): self.assert_export( SampleModelForDimOne(), (torch.randn(1, 3), torch.randn(1, 5), torch.randn(1, 8)), dynamic_shapes=( {0: "batch", 1: "sequence"}, {0: "batch", 1: "sequence"}, {0: "batch", 1: "sequence"}, ), ) def test_is_in_onnx_export(self): class Mod(torch.nn.Module): def forward(self, x): def f(x): return x.sin() if torch.onnx.is_in_onnx_export() else x.cos() return f(x) self.assertFalse(torch.onnx.is_in_onnx_export()) onnx_program = torch.onnx.export( Mod(), (torch.randn(3, 4),), dynamo=True, fallback=False, ) self.assertFalse(torch.onnx.is_in_onnx_export()) node_names = [n.op_type for n in onnx_program.model.graph] self.assertIn("Sin", node_names) def test_torchscript_exporter_raises_deprecation_warning(self): # Test that the deprecation warning is raised when using torchscript exporter with self.assertWarnsRegex( DeprecationWarning, "You are using the legacy TorchScript-based ONNX export" ): torch.onnx.export( SampleModel(), (torch.randn(1, 1, 2),), io.BytesIO(), dynamo=False ) def test_model_output_can_be_none(self): class ModelWithNoneOutput(torch.nn.Module): def forward(self, x): return x + 1, None onnx_program = torch.onnx.export( ModelWithNoneOutput(), (torch.randn(1, 1, 2),), dynamo=True, ) onnx_testing.assert_onnx_program(onnx_program) class TestCustomTranslationTable(common_utils.TestCase): def test_custom_translation_table_overrides_ops(self): from onnxscript import opset18 as op class Model(torch.nn.Module): def forward(self, x, y): return x + y def custom_add(self, other): # Replace add with sub return op.Sub(self, other) custom_translation_table = {torch.ops.aten.add.Tensor: custom_add} onnx_program = torch.onnx.export( Model(), (torch.randn(2, 2), torch.randn(2, 2)), custom_translation_table=custom_translation_table, dynamo=True, ) all_nodes = [n.op_type for n in onnx_program.model.graph] self.assertIn("Sub", all_nodes) self.assertNotIn("Add", all_nodes) def test_custom_translation_table_supports_overloading_ops(self): class Model(torch.nn.Module): def forward(self, x, y): return torch.ops.aten.logical_and.default(x, y) def custom_add_bool(self: BOOL, other: BOOL) -> BOOL: # Replace add with sub return op.Sub(self, other) def custom_add(self: FLOAT, other: FLOAT) -> FLOAT: # Replace add with mul return op.Mul(self, other) custom_translation_table = { torch.ops.aten.logical_and.default: [custom_add, custom_add_bool], } onnx_program = torch.onnx.export( Model(), (torch.tensor(1, dtype=torch.bool), torch.tensor(1, dtype=torch.bool)), custom_translation_table=custom_translation_table, dynamo=True, ) all_nodes = [n.op_type for n in onnx_program.model.graph] # The dispatcher should pick the correct overload based on the input types self.assertIn("Sub", all_nodes) self.assertNotIn("Add", all_nodes) self.assertNotIn("Mul", all_nodes) def test_custom_translation_table_supports_custom_op_as_target(self): # Define the custom op and use it in the model @torch.library.custom_op("custom::add", mutates_args=()) def custom_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a + b @custom_add.register_fake def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return torch.empty_like(a) + torch.empty_like(b) class Model(torch.nn.Module): def forward(self, x, y): return custom_add(x, y) def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT: # Replace add with Sub return op.Sub(self, other) custom_translation_table = { torch.ops.custom.add.default: onnx_add, } onnx_program = torch.onnx.export( Model(), (torch.tensor(1, dtype=torch.bool), torch.tensor(1, dtype=torch.bool)), custom_translation_table=custom_translation_table, dynamo=True, ) all_nodes = [n.op_type for n in onnx_program.model.graph] self.assertIn("Sub", all_nodes) self.assertNotIn("Add", all_nodes) def test_custom_translation_table_supports_custom_op_with_its_decomp(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( "mylib::foo", "(Tensor a, Tensor b) -> Tensor", tags=torch.Tag.pt2_compliant_tag, lib=lib, ) @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib) @torch.library.register_fake("mylib::foo") def foo_impl(a, b): return a + b class M(torch.nn.Module): def forward(self, x, y): return torch.ops.mylib.foo(x, y) def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT: # Replace add with Sub return op.Sub(self, other) # With the custom op defined, we can use it in the model # and replace it with a custom translation table custom_translation_table = { torch.ops.mylib.foo.default: onnx_add, } onnx_program = torch.onnx.export( M(), (torch.ones(3, 3), torch.ones(3, 3)), custom_translation_table=custom_translation_table, dynamo=True, ) all_nodes = [n.op_type for n in onnx_program.model.graph] self.assertIn("Sub", all_nodes) self.assertNotIn("Add", all_nodes) # Without the custom op defined, it's going to be decomposed onnx_program_decomp = torch.onnx.export( M(), (torch.ones(3, 3), torch.ones(3, 3)), dynamo=True ) all_nodes_decomp = [n.op_type for n in onnx_program_decomp.model.graph] self.assertIn("Add", all_nodes_decomp) self.assertNotIn("Sub", all_nodes_decomp) def test_01_specialization_with_run_decomp_is_supported(self): # Phi3RMSNorm changes and redo shape inference after `run_decompositions` call # We ned this test to make sure everything we do on fx graph is covered by # backed_size_oblivious class Phi3RMSNorm(torch.nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Phi3RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = torch.nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt( variance + self.variance_epsilon ) return self.weight * hidden_states.to(input_dtype) op = torch.onnx.export( Phi3RMSNorm(256).eval(), args=(), kwargs={"hidden_states": torch.rand((1, 32, 256))}, dynamic_shapes={ "hidden_states": { 0: "batch_size", 1: "seq_len", } }, dynamo=True, ) # batch size is not fixed to 1 self.assertNotEqual(op.model.graph.outputs[0].shape[0], 1) if __name__ == "__main__": common_utils.run_tests()