[ONNX] Opt into ruff fmt (#134120)

Add ONNX directory to use ruff format.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134120
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007
This commit is contained in:
Justin Chu 2024-08-22 22:44:03 +00:00 committed by PyTorch MergeBot
parent 25499de814
commit b319fa3fd9
60 changed files with 313 additions and 276 deletions

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: onnx"] # Owner(s): ["module: onnx"]
"""Unit tests for the internal registration wrapper module.""" """Unit tests for the internal registration wrapper module."""
from __future__ import annotations from __future__ import annotations
import operator import operator

View File

@ -22,8 +22,7 @@ if typing.TYPE_CHECKING:
class _SarifLogBuilder(Protocol): class _SarifLogBuilder(Protocol):
def sarif_log(self) -> sarif.SarifLog: def sarif_log(self) -> sarif.SarifLog: ...
...
def _assert_has_diagnostics( def _assert_has_diagnostics(
@ -344,9 +343,7 @@ class TestTorchScriptOnnxDiagnostics(common_utils.TestCase):
self.assertIn("test_diagnostics.py", frame.location.uri) self.assertIn("test_diagnostics.py", frame.location.uri)
def test_diagnostics_records_cpp_call_stack(self): def test_diagnostics_records_cpp_call_stack(self):
diagnostic = ( diagnostic = self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()
self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()
)
stack = diagnostic.cpp_call_stack stack = diagnostic.cpp_call_stack
assert stack is not None # for mypy assert stack is not None # for mypy
self.assertGreater(len(stack.frames), 0) self.assertGreater(len(stack.frames), 0)
@ -368,9 +365,9 @@ class TestDiagnosticsInfra(common_utils.TestCase):
def setUp(self): def setUp(self):
self.rules = _RuleCollectionForTest() self.rules = _RuleCollectionForTest()
with contextlib.ExitStack() as stack: with contextlib.ExitStack() as stack:
self.context: infra.DiagnosticContext[ self.context: infra.DiagnosticContext[infra.Diagnostic] = (
infra.Diagnostic stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
] = stack.enter_context(infra.DiagnosticContext("test", "1.0.0")) )
self.addCleanup(stack.pop_all().close) self.addCleanup(stack.pop_all().close)
return super().setUp() return super().setUp()
@ -400,12 +397,14 @@ class TestDiagnosticsInfra(common_utils.TestCase):
}, },
): ):
diagnostic1 = infra.Diagnostic( diagnostic1 = infra.Diagnostic(
custom_rules.custom_rule, infra.Level.WARNING # type: ignore[attr-defined] custom_rules.custom_rule, # type: ignore[attr-defined]
infra.Level.WARNING,
) )
self.context.log(diagnostic1) self.context.log(diagnostic1)
diagnostic2 = infra.Diagnostic( diagnostic2 = infra.Diagnostic(
custom_rules.custom_rule_2, infra.Level.ERROR # type: ignore[attr-defined] custom_rules.custom_rule_2, # type: ignore[attr-defined]
infra.Level.ERROR,
) )
self.context.log(diagnostic2) self.context.log(diagnostic2)

View File

@ -49,9 +49,9 @@ class TestGlobalHelpers(common_utils.TestCase):
class TestOverrideDict(common_utils.TestCase): class TestOverrideDict(common_utils.TestCase):
def setUp(self): def setUp(self):
self.override_dict: registration.OverrideDict[ self.override_dict: registration.OverrideDict[str, int] = (
str, int registration.OverrideDict()
] = registration.OverrideDict() )
def test_get_item_returns_base_value_when_no_override(self): def test_get_item_returns_base_value_when_no_override(self):
self.override_dict.set_base("a", 42) self.override_dict.set_base("a", 42)

View File

@ -44,7 +44,7 @@ class _netG(nn.Module):
nn.ReLU(True), nn.ReLU(True),
# state size. (ngf) x 32 x 32 # state size. (ngf) x 32 x 32
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
nn.Tanh() nn.Tanh(),
# state size. (nc) x 64 x 64 # state size. (nc) x 64 x 64
) )

View File

@ -294,8 +294,8 @@ def xfail(error_message: str, reason: Optional[str] = None):
except Exception as e: except Exception as e:
if isinstance(e, torch.onnx.OnnxExporterError): if isinstance(e, torch.onnx.OnnxExporterError):
# diagnostic message is in the cause of the exception # diagnostic message is in the cause of the exception
assert error_message in str( assert (
e.__cause__ error_message in str(e.__cause__)
), f"Expected error message: {error_message} NOT in {str(e.__cause__)}" ), f"Expected error message: {error_message} NOT in {str(e.__cause__)}"
else: else:
assert error_message in str( assert error_message in str(

View File

@ -175,9 +175,7 @@ def _init_test_roi_heads_faster_rcnn():
resolution = box_roi_pool.output_size[0] resolution = box_roi_pool.output_size[0]
representation_size = 1024 representation_size = 1024
box_head = faster_rcnn.TwoMLPHead( box_head = faster_rcnn.TwoMLPHead(out_channels * resolution**2, representation_size)
out_channels * resolution**2, representation_size
)
representation_size = 1024 representation_size = 1024
box_predictor = faster_rcnn.FastRCNNPredictor(representation_size, num_classes) box_predictor = faster_rcnn.FastRCNNPredictor(representation_size, num_classes)

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: onnx"] # Owner(s): ["module: onnx"]
"""Test the support on onnxscript in PyTorch-ONNX converter.""" """Test the support on onnxscript in PyTorch-ONNX converter."""
import io import io
from typing import List from typing import List

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: onnx"] # Owner(s): ["module: onnx"]
"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime.""" """Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
from typing import List from typing import List
import onnx_test_common import onnx_test_common

View File

@ -6,6 +6,7 @@ Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
--produce-onnx-test-data: generate onnx test data --produce-onnx-test-data: generate onnx test data
--accept: accept onnx updates and overwrite models --accept: accept onnx updates and overwrite models
""" """
import glob import glob
import inspect import inspect
import io import io
@ -879,7 +880,8 @@ class TestOperators(common_utils.TestCase):
def forward(self, x_in): def forward(self, x_in):
x_out = {} x_out = {}
x_out["test_key_out"] = torch.add( x_out["test_key_out"] = torch.add(
x_in[list(x_in.keys())[0]], list(x_in.keys())[0] # noqa: RUF015 x_in[list(x_in.keys())[0]], # noqa: RUF015
list(x_in.keys())[0], # noqa: RUF015
) )
return x_out return x_out

View File

@ -483,7 +483,8 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
def forward(self, x_in): def forward(self, x_in):
x_out = {} x_out = {}
x_out["test_key_out"] = torch.add( x_out["test_key_out"] = torch.add(
x_in[list(x_in.keys())[0]], list(x_in.keys())[0] # noqa: RUF015 x_in[list(x_in.keys())[0]], # noqa: RUF015
list(x_in.keys())[0], # noqa: RUF015
) )
return x_out return x_out

View File

@ -174,7 +174,9 @@ class TestUnconvertibleOps(pytorch_test_common.ExportTestCase):
_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET + 1, _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET + 1,
) )
], ],
class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}", class_name_func=lambda cls,
num,
params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
) )
class TestUtilityFuns(_BaseTestCase): class TestUtilityFuns(_BaseTestCase):
opset_version = None opset_version = None

View File

@ -185,7 +185,9 @@ class TestVerificationOnWrongExport(pytorch_test_common.ExportTestCase):
# {"onnx_backend": verification.OnnxBackend.ONNX}, # {"onnx_backend": verification.OnnxBackend.ONNX},
{"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU}, {"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU},
], ],
class_name_func=lambda cls, idx, input_dicts: f"{cls.__name__}_{input_dicts['onnx_backend'].name}", class_name_func=lambda cls,
idx,
input_dicts: f"{cls.__name__}_{input_dicts['onnx_backend'].name}",
) )
class TestFindMismatch(pytorch_test_common.ExportTestCase): class TestFindMismatch(pytorch_test_common.ExportTestCase):
onnx_backend: verification.OnnxBackend onnx_backend: verification.OnnxBackend

View File

@ -79,7 +79,9 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
with tempfile.NamedTemporaryFile(suffix=".pte") as f: with tempfile.NamedTemporaryFile(suffix=".pte") as f:
torch.export.save(exported_program, f.name) torch.export.save(exported_program, f.name)
del exported_program # Delete the exported program to ensure that we are loading from file del (
exported_program
) # Delete the exported program to ensure that we are loading from file
loaded_exported_program = torch.export.load(f.name) loaded_exported_program = torch.export.load(f.name)
self._compare_onnx_and_torch_exported_program( self._compare_onnx_and_torch_exported_program(

View File

@ -47,8 +47,12 @@ USE_BLACK_FILELIST = re.compile(
"test/[a-h]*/**", "test/[a-h]*/**",
# test/[i-j]*/** # test/[i-j]*/**
"test/[i-j]*/**", "test/[i-j]*/**",
# test/[k-z]*/** # test/[k-n]*/**
"test/[k-z]*/**", "test/[k-n]*/**",
# test/optim/**
"test/optim/**",
# "test/[p-z]*/**",
"test/[p-z]*/**",
# torch/** # torch/**
# torch/_[a-h]*/** # torch/_[a-h]*/**
"torch/_[a-h]*/**", "torch/_[a-h]*/**",
@ -62,8 +66,10 @@ USE_BLACK_FILELIST = re.compile(
"torch/d*/**", "torch/d*/**",
# torch/[e-n]*/** # torch/[e-n]*/**
"torch/[e-n]*/**", "torch/[e-n]*/**",
# torch/[o-z]*/** # torch/optim/**
"torch/[o-z]*/**", "torch/optim/**",
# torch/[p-z]*/**
"torch/[p-z]*/**",
], ],
), ),
) )

View File

@ -84,7 +84,6 @@ from .utils import (
from . import ( # usort: skip. Keep the order instead of sorting lexicographically from . import ( # usort: skip. Keep the order instead of sorting lexicographically
_deprecation,
errors, errors,
symbolic_caffe2, symbolic_caffe2,
symbolic_helper, symbolic_helper,
@ -215,12 +214,13 @@ def export(
def forward(self, x): def forward(self, x):
return torch.sum(x, dim=1) return torch.sum(x, dim=1)
torch.onnx.export( torch.onnx.export(
SumModule(), SumModule(),
(torch.ones(2, 2),), (torch.ones(2, 2),),
"onnx.pb", "onnx.pb",
input_names=["x"], input_names=["x"],
output_names=["sum"] output_names=["sum"],
) )
Produces:: Produces::
@ -256,7 +256,7 @@ def export(
"x": {0: "my_custom_axis_name"}, "x": {0: "my_custom_axis_name"},
# list value: automatic names # list value: automatic names
"sum": [0], "sum": [0],
} },
) )
Produces:: Produces::

View File

@ -6,6 +6,7 @@ Do not use this module outside of `torch.onnx` and its tests.
Be very judicious when adding any new global variables. Do not create new global Be very judicious when adding any new global variables. Do not create new global
variables unless they are absolutely necessary. variables unless they are absolutely necessary.
""" """
import torch._C._onnx as _C_onnx import torch._C._onnx as _C_onnx
# This module should only depend on _constants and nothing else in torch.onnx to keep # This module should only depend on _constants and nothing else in torch.onnx to keep

View File

@ -107,9 +107,9 @@ class OnnxRegistry:
# NOTE: _registry is the registry maps OpNameto a list of ONNXFunctions. It is important # NOTE: _registry is the registry maps OpNameto a list of ONNXFunctions. It is important
# not to directly modify this variable. Instead, access to it should be done through # not to directly modify this variable. Instead, access to it should be done through
# the public methods: register_custom_op, get_ops, and is_registered_op. # the public methods: register_custom_op, get_ops, and is_registered_op.
self._registry: dict[ self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = (
registration.OpName, list[registration.ONNXFunction] defaultdict(list)
] = defaultdict(list) )
# FIXME: Avoid importing onnxscript into torch # FIXME: Avoid importing onnxscript into torch
from onnxscript.function_libs.torch_lib import ( # type: ignore[import] # noqa: F401 from onnxscript.function_libs.torch_lib import ( # type: ignore[import] # noqa: F401
registration, registration,
@ -392,8 +392,10 @@ class ResolvedExportOptions(ExportOptions):
) )
self.onnx_registry = resolve(options.onnx_registry, OnnxRegistry()) self.onnx_registry = resolve(options.onnx_registry, OnnxRegistry())
self.decomposition_table = decomposition_table.create_onnx_friendly_decomposition_table( # type: ignore[assignment] self.decomposition_table = (
self.onnx_registry decomposition_table.create_onnx_friendly_decomposition_table( # type: ignore[assignment]
self.onnx_registry
)
) )
from torch.onnx._internal.fx import onnxfunction_dispatcher from torch.onnx._internal.fx import onnxfunction_dispatcher
@ -766,6 +768,7 @@ class ONNXProgram:
... self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, bias=False) ... self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, bias=False)
... self.fc1 = torch.nn.Linear(9216, 128, bias=False) ... self.fc1 = torch.nn.Linear(9216, 128, bias=False)
... self.fc2 = torch.nn.Linear(128, 10, bias=False) ... self.fc2 = torch.nn.Linear(128, 10, bias=False)
...
... def forward(self, x, b): ... def forward(self, x, b):
... tensor_x = self.conv1(x) ... tensor_x = self.conv1(x)
... tensor_x = torch.nn.functional.sigmoid(tensor_x) ... tensor_x = torch.nn.functional.sigmoid(tensor_x)
@ -778,11 +781,13 @@ class ONNXProgram:
... tensor_x = self.fc2(tensor_x) ... tensor_x = self.fc2(tensor_x)
... output = torch.nn.functional.log_softmax(tensor_x, dim=1) ... output = torch.nn.functional.log_softmax(tensor_x, dim=1)
... ( ... (
... self.my_buffer2.add_(1.0) + self.my_buffer1 ... self.my_buffer2.add_(1.0) + self.my_buffer1
... ) # Mutate buffer through in-place addition ... ) # Mutate buffer through in-place addition
... return output ... return output
>>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3)) >>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3))
>>> exported_program = torch.export.export(CustomModule(), args=inputs).run_decompositions({}) >>> exported_program = torch.export.export(
... CustomModule(), args=inputs
... ).run_decompositions({})
>>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs) >>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs)
>>> pprint.pprint(onnx_program.model_signature) >>> pprint.pprint(onnx_program.model_signature)
ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>,
@ -1194,9 +1199,7 @@ class Exporter:
with self.options.diagnostic_context, decomposition_skip.enable_decomposition_skips( with self.options.diagnostic_context, decomposition_skip.enable_decomposition_skips(
self.options self.options
), torch._dynamo.config.patch( ), torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
):
graph_module = self.options.fx_tracer.generate_fx( graph_module = self.options.fx_tracer.generate_fx(
self.options, self.model, self.model_args, self.model_kwargs self.options, self.model, self.model_args, self.model_kwargs
) )
@ -1401,17 +1404,19 @@ def dynamo_export(
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.linear = torch.nn.Linear(2, 2) self.linear = torch.nn.Linear(2, 2)
def forward(self, x, bias=None): def forward(self, x, bias=None):
out = self.linear(x) out = self.linear(x)
out = out + bias out = out + bias
return out return out
model = MyModel() model = MyModel()
kwargs = {"bias": 3.} kwargs = {"bias": 3.0}
args = (torch.randn(2, 2, 2),) args = (torch.randn(2, 2, 2),)
onnx_program = torch.onnx.dynamo_export( onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save(
model, "my_simple_model.onnx"
*args, )
**kwargs).save("my_simple_model.onnx")
**Example 2 - Exporting with dynamic shapes** **Example 2 - Exporting with dynamic shapes**
:: ::
@ -1419,10 +1424,8 @@ def dynamo_export(
# The previous model can be exported with dynamic shapes # The previous model can be exported with dynamic shapes
export_options = torch.onnx.ExportOptions(dynamic_shapes=True) export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export( onnx_program = torch.onnx.dynamo_export(
model, model, *args, **kwargs, export_options=export_options
*args, )
**kwargs,
export_options=export_options)
onnx_program.save("my_dynamic_model.onnx") onnx_program.save("my_dynamic_model.onnx")

View File

@ -1,4 +1,5 @@
"""Utility to lazily import modules.""" """Utility to lazily import modules."""
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from __future__ import annotations from __future__ import annotations

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`.""" """Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`."""
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib

View File

@ -22,9 +22,9 @@ class ArtifactContent(object):
properties: Optional[_property_bag.PropertyBag] = dataclasses.field( properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"} default=None, metadata={"schema_property_name": "properties"}
) )
rendered: Optional[ rendered: Optional[_multiformat_message_string.MultiformatMessageString] = (
_multiformat_message_string.MultiformatMessageString dataclasses.field(default=None, metadata={"schema_property_name": "rendered"})
] = dataclasses.field(default=None, metadata={"schema_property_name": "rendered"}) )
text: Optional[str] = dataclasses.field( text: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "text"} default=None, metadata={"schema_property_name": "text"}
) )

View File

@ -19,10 +19,10 @@ class Conversion(object):
"""Describes how a converter transformed the output of a static analysis tool from the analysis tool's native output format into the SARIF format.""" """Describes how a converter transformed the output of a static analysis tool from the analysis tool's native output format into the SARIF format."""
tool: _tool.Tool = dataclasses.field(metadata={"schema_property_name": "tool"}) tool: _tool.Tool = dataclasses.field(metadata={"schema_property_name": "tool"})
analysis_tool_log_files: Optional[ analysis_tool_log_files: Optional[List[_artifact_location.ArtifactLocation]] = (
List[_artifact_location.ArtifactLocation] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "analysisToolLogFiles"}
default=None, metadata={"schema_property_name": "analysisToolLogFiles"} )
) )
invocation: Optional[_invocation.Invocation] = dataclasses.field( invocation: Optional[_invocation.Invocation] = dataclasses.field(
default=None, metadata={"schema_property_name": "invocation"} default=None, metadata={"schema_property_name": "invocation"}

View File

@ -53,10 +53,10 @@ class ExternalProperties(object):
invocations: Optional[List[_invocation.Invocation]] = dataclasses.field( invocations: Optional[List[_invocation.Invocation]] = dataclasses.field(
default=None, metadata={"schema_property_name": "invocations"} default=None, metadata={"schema_property_name": "invocations"}
) )
logical_locations: Optional[ logical_locations: Optional[List[_logical_location.LogicalLocation]] = (
List[_logical_location.LogicalLocation] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "logicalLocations"}
default=None, metadata={"schema_property_name": "logicalLocations"} )
) )
policies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( policies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "policies"} default=None, metadata={"schema_property_name": "policies"}
@ -76,10 +76,10 @@ class ExternalProperties(object):
taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "taxonomies"} default=None, metadata={"schema_property_name": "taxonomies"}
) )
thread_flow_locations: Optional[ thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = (
List[_thread_flow_location.ThreadFlowLocation] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "threadFlowLocations"}
default=None, metadata={"schema_property_name": "threadFlowLocations"} )
) )
translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "translations"} default=None, metadata={"schema_property_name": "translations"}

View File

@ -36,10 +36,10 @@ class Invocation(object):
environment_variables: Any = dataclasses.field( environment_variables: Any = dataclasses.field(
default=None, metadata={"schema_property_name": "environmentVariables"} default=None, metadata={"schema_property_name": "environmentVariables"}
) )
executable_location: Optional[ executable_location: Optional[_artifact_location.ArtifactLocation] = (
_artifact_location.ArtifactLocation dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "executableLocation"}
default=None, metadata={"schema_property_name": "executableLocation"} )
) )
exit_code: Optional[int] = dataclasses.field( exit_code: Optional[int] = dataclasses.field(
default=None, metadata={"schema_property_name": "exitCode"} default=None, metadata={"schema_property_name": "exitCode"}
@ -71,10 +71,10 @@ class Invocation(object):
properties: Optional[_property_bag.PropertyBag] = dataclasses.field( properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"} default=None, metadata={"schema_property_name": "properties"}
) )
response_files: Optional[ response_files: Optional[List[_artifact_location.ArtifactLocation]] = (
List[_artifact_location.ArtifactLocation] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "responseFiles"}
default=None, metadata={"schema_property_name": "responseFiles"} )
) )
rule_configuration_overrides: Optional[ rule_configuration_overrides: Optional[
List[_configuration_override.ConfigurationOverride] List[_configuration_override.ConfigurationOverride]
@ -96,21 +96,22 @@ class Invocation(object):
stdout_stderr: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( stdout_stderr: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
default=None, metadata={"schema_property_name": "stdoutStderr"} default=None, metadata={"schema_property_name": "stdoutStderr"}
) )
tool_configuration_notifications: Optional[ tool_configuration_notifications: Optional[List[_notification.Notification]] = (
List[_notification.Notification] dataclasses.field(
] = dataclasses.field( default=None,
default=None, metadata={"schema_property_name": "toolConfigurationNotifications"},
metadata={"schema_property_name": "toolConfigurationNotifications"}, )
) )
tool_execution_notifications: Optional[ tool_execution_notifications: Optional[List[_notification.Notification]] = (
List[_notification.Notification] dataclasses.field(
] = dataclasses.field( default=None,
default=None, metadata={"schema_property_name": "toolExecutionNotifications"} metadata={"schema_property_name": "toolExecutionNotifications"},
)
) )
working_directory: Optional[ working_directory: Optional[_artifact_location.ArtifactLocation] = (
_artifact_location.ArtifactLocation dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "workingDirectory"}
default=None, metadata={"schema_property_name": "workingDirectory"} )
) )

View File

@ -24,26 +24,26 @@ class Location(object):
default=None, metadata={"schema_property_name": "annotations"} default=None, metadata={"schema_property_name": "annotations"}
) )
id: int = dataclasses.field(default=-1, metadata={"schema_property_name": "id"}) id: int = dataclasses.field(default=-1, metadata={"schema_property_name": "id"})
logical_locations: Optional[ logical_locations: Optional[List[_logical_location.LogicalLocation]] = (
List[_logical_location.LogicalLocation] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "logicalLocations"}
default=None, metadata={"schema_property_name": "logicalLocations"} )
) )
message: Optional[_message.Message] = dataclasses.field( message: Optional[_message.Message] = dataclasses.field(
default=None, metadata={"schema_property_name": "message"} default=None, metadata={"schema_property_name": "message"}
) )
physical_location: Optional[ physical_location: Optional[_physical_location.PhysicalLocation] = (
_physical_location.PhysicalLocation dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "physicalLocation"}
default=None, metadata={"schema_property_name": "physicalLocation"} )
) )
properties: Optional[_property_bag.PropertyBag] = dataclasses.field( properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"} default=None, metadata={"schema_property_name": "properties"}
) )
relationships: Optional[ relationships: Optional[List[_location_relationship.LocationRelationship]] = (
List[_location_relationship.LocationRelationship] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "relationships"}
default=None, metadata={"schema_property_name": "relationships"} )
) )

View File

@ -21,10 +21,10 @@ class PhysicalLocation(object):
address: Optional[_address.Address] = dataclasses.field( address: Optional[_address.Address] = dataclasses.field(
default=None, metadata={"schema_property_name": "address"} default=None, metadata={"schema_property_name": "address"}
) )
artifact_location: Optional[ artifact_location: Optional[_artifact_location.ArtifactLocation] = (
_artifact_location.ArtifactLocation dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "artifactLocation"}
default=None, metadata={"schema_property_name": "artifactLocation"} )
) )
context_region: Optional[_region.Region] = dataclasses.field( context_region: Optional[_region.Region] = dataclasses.field(
default=None, metadata={"schema_property_name": "contextRegion"} default=None, metadata={"schema_property_name": "contextRegion"}

View File

@ -19,10 +19,10 @@ class ReportingDescriptor(object):
"""Metadata that describes a specific report produced by the tool, as part of the analysis it provides or its runtime reporting.""" """Metadata that describes a specific report produced by the tool, as part of the analysis it provides or its runtime reporting."""
id: str = dataclasses.field(metadata={"schema_property_name": "id"}) id: str = dataclasses.field(metadata={"schema_property_name": "id"})
default_configuration: Optional[ default_configuration: Optional[_reporting_configuration.ReportingConfiguration] = (
_reporting_configuration.ReportingConfiguration dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "defaultConfiguration"}
default=None, metadata={"schema_property_name": "defaultConfiguration"} )
) )
deprecated_guids: Optional[List[str]] = dataclasses.field( deprecated_guids: Optional[List[str]] = dataclasses.field(
default=None, metadata={"schema_property_name": "deprecatedGuids"} default=None, metadata={"schema_property_name": "deprecatedGuids"}
@ -33,17 +33,17 @@ class ReportingDescriptor(object):
deprecated_names: Optional[List[str]] = dataclasses.field( deprecated_names: Optional[List[str]] = dataclasses.field(
default=None, metadata={"schema_property_name": "deprecatedNames"} default=None, metadata={"schema_property_name": "deprecatedNames"}
) )
full_description: Optional[ full_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
_multiformat_message_string.MultiformatMessageString dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "fullDescription"}
default=None, metadata={"schema_property_name": "fullDescription"} )
) )
guid: Optional[str] = dataclasses.field( guid: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "guid"} default=None, metadata={"schema_property_name": "guid"}
) )
help: Optional[ help: Optional[_multiformat_message_string.MultiformatMessageString] = (
_multiformat_message_string.MultiformatMessageString dataclasses.field(default=None, metadata={"schema_property_name": "help"})
] = dataclasses.field(default=None, metadata={"schema_property_name": "help"}) )
help_uri: Optional[str] = dataclasses.field( help_uri: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "helpUri"} default=None, metadata={"schema_property_name": "helpUri"}
) )

View File

@ -28,10 +28,10 @@ class ReportingDescriptorReference(object):
properties: Optional[_property_bag.PropertyBag] = dataclasses.field( properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"} default=None, metadata={"schema_property_name": "properties"}
) )
tool_component: Optional[ tool_component: Optional[_tool_component_reference.ToolComponentReference] = (
_tool_component_reference.ToolComponentReference dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "toolComponent"}
default=None, metadata={"schema_property_name": "toolComponent"} )
) )

View File

@ -38,10 +38,10 @@ class Result(object):
attachments: Optional[List[_attachment.Attachment]] = dataclasses.field( attachments: Optional[List[_attachment.Attachment]] = dataclasses.field(
default=None, metadata={"schema_property_name": "attachments"} default=None, metadata={"schema_property_name": "attachments"}
) )
baseline_state: Optional[ baseline_state: Optional[Literal["new", "unchanged", "updated", "absent"]] = (
Literal["new", "unchanged", "updated", "absent"] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "baselineState"}
default=None, metadata={"schema_property_name": "baselineState"} )
) )
code_flows: Optional[List[_code_flow.CodeFlow]] = dataclasses.field( code_flows: Optional[List[_code_flow.CodeFlow]] = dataclasses.field(
default=None, metadata={"schema_property_name": "codeFlows"} default=None, metadata={"schema_property_name": "codeFlows"}
@ -55,10 +55,10 @@ class Result(object):
fixes: Optional[List[_fix.Fix]] = dataclasses.field( fixes: Optional[List[_fix.Fix]] = dataclasses.field(
default=None, metadata={"schema_property_name": "fixes"} default=None, metadata={"schema_property_name": "fixes"}
) )
graph_traversals: Optional[ graph_traversals: Optional[List[_graph_traversal.GraphTraversal]] = (
List[_graph_traversal.GraphTraversal] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "graphTraversals"}
default=None, metadata={"schema_property_name": "graphTraversals"} )
) )
graphs: Optional[List[_graph.Graph]] = dataclasses.field( graphs: Optional[List[_graph.Graph]] = dataclasses.field(
default=None, metadata={"schema_property_name": "graphs"} default=None, metadata={"schema_property_name": "graphs"}
@ -96,9 +96,9 @@ class Result(object):
related_locations: Optional[List[_location.Location]] = dataclasses.field( related_locations: Optional[List[_location.Location]] = dataclasses.field(
default=None, metadata={"schema_property_name": "relatedLocations"} default=None, metadata={"schema_property_name": "relatedLocations"}
) )
rule: Optional[ rule: Optional[_reporting_descriptor_reference.ReportingDescriptorReference] = (
_reporting_descriptor_reference.ReportingDescriptorReference dataclasses.field(default=None, metadata={"schema_property_name": "rule"})
] = dataclasses.field(default=None, metadata={"schema_property_name": "rule"}) )
rule_id: Optional[str] = dataclasses.field( rule_id: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "ruleId"} default=None, metadata={"schema_property_name": "ruleId"}
) )

View File

@ -16,10 +16,10 @@ from torch.onnx._internal.diagnostics.infra.sarif import (
class ResultProvenance(object): class ResultProvenance(object):
"""Contains information about how and when a result was detected.""" """Contains information about how and when a result was detected."""
conversion_sources: Optional[ conversion_sources: Optional[List[_physical_location.PhysicalLocation]] = (
List[_physical_location.PhysicalLocation] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "conversionSources"}
default=None, metadata={"schema_property_name": "conversionSources"} )
) )
first_detection_run_guid: Optional[str] = dataclasses.field( first_detection_run_guid: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "firstDetectionRunGuid"} default=None, metadata={"schema_property_name": "firstDetectionRunGuid"}

View File

@ -38,17 +38,17 @@ class Run(object):
artifacts: Optional[List[_artifact.Artifact]] = dataclasses.field( artifacts: Optional[List[_artifact.Artifact]] = dataclasses.field(
default=None, metadata={"schema_property_name": "artifacts"} default=None, metadata={"schema_property_name": "artifacts"}
) )
automation_details: Optional[ automation_details: Optional[_run_automation_details.RunAutomationDetails] = (
_run_automation_details.RunAutomationDetails dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "automationDetails"}
default=None, metadata={"schema_property_name": "automationDetails"} )
) )
baseline_guid: Optional[str] = dataclasses.field( baseline_guid: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "baselineGuid"} default=None, metadata={"schema_property_name": "baselineGuid"}
) )
column_kind: Optional[ column_kind: Optional[Literal["utf16CodeUnits", "unicodeCodePoints"]] = (
Literal["utf16CodeUnits", "unicodeCodePoints"] dataclasses.field(default=None, metadata={"schema_property_name": "columnKind"})
] = dataclasses.field(default=None, metadata={"schema_property_name": "columnKind"}) )
conversion: Optional[_conversion.Conversion] = dataclasses.field( conversion: Optional[_conversion.Conversion] = dataclasses.field(
default=None, metadata={"schema_property_name": "conversion"} default=None, metadata={"schema_property_name": "conversion"}
) )
@ -73,10 +73,10 @@ class Run(object):
language: str = dataclasses.field( language: str = dataclasses.field(
default="en-US", metadata={"schema_property_name": "language"} default="en-US", metadata={"schema_property_name": "language"}
) )
logical_locations: Optional[ logical_locations: Optional[List[_logical_location.LogicalLocation]] = (
List[_logical_location.LogicalLocation] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "logicalLocations"}
default=None, metadata={"schema_property_name": "logicalLocations"} )
) )
newline_sequences: List[str] = dataclasses.field( newline_sequences: List[str] = dataclasses.field(
default_factory=lambda: ["\r\n", "\n"], default_factory=lambda: ["\r\n", "\n"],
@ -97,23 +97,23 @@ class Run(object):
results: Optional[List[_result.Result]] = dataclasses.field( results: Optional[List[_result.Result]] = dataclasses.field(
default=None, metadata={"schema_property_name": "results"} default=None, metadata={"schema_property_name": "results"}
) )
run_aggregates: Optional[ run_aggregates: Optional[List[_run_automation_details.RunAutomationDetails]] = (
List[_run_automation_details.RunAutomationDetails] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "runAggregates"}
default=None, metadata={"schema_property_name": "runAggregates"} )
) )
special_locations: Optional[ special_locations: Optional[_special_locations.SpecialLocations] = (
_special_locations.SpecialLocations dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "specialLocations"}
default=None, metadata={"schema_property_name": "specialLocations"} )
) )
taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "taxonomies"} default=None, metadata={"schema_property_name": "taxonomies"}
) )
thread_flow_locations: Optional[ thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = (
List[_thread_flow_location.ThreadFlowLocation] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "threadFlowLocations"}
default=None, metadata={"schema_property_name": "threadFlowLocations"} )
) )
translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "translations"} default=None, metadata={"schema_property_name": "translations"}

View File

@ -21,10 +21,10 @@ class ToolComponent(object):
"""A component, such as a plug-in or the driver, of the analysis tool that was run.""" """A component, such as a plug-in or the driver, of the analysis tool that was run."""
name: str = dataclasses.field(metadata={"schema_property_name": "name"}) name: str = dataclasses.field(metadata={"schema_property_name": "name"})
associated_component: Optional[ associated_component: Optional[_tool_component_reference.ToolComponentReference] = (
_tool_component_reference.ToolComponentReference dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "associatedComponent"}
default=None, metadata={"schema_property_name": "associatedComponent"} )
) )
contents: List[Literal["localizedData", "nonLocalizedData"]] = dataclasses.field( contents: List[Literal["localizedData", "nonLocalizedData"]] = dataclasses.field(
default_factory=lambda: ["localizedData", "nonLocalizedData"], default_factory=lambda: ["localizedData", "nonLocalizedData"],
@ -36,10 +36,10 @@ class ToolComponent(object):
download_uri: Optional[str] = dataclasses.field( download_uri: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "downloadUri"} default=None, metadata={"schema_property_name": "downloadUri"}
) )
full_description: Optional[ full_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
_multiformat_message_string.MultiformatMessageString dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "fullDescription"}
default=None, metadata={"schema_property_name": "fullDescription"} )
) )
full_name: Optional[str] = dataclasses.field( full_name: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "fullName"} default=None, metadata={"schema_property_name": "fullName"}
@ -71,10 +71,10 @@ class ToolComponent(object):
"schema_property_name": "minimumRequiredLocalizedDataSemanticVersion" "schema_property_name": "minimumRequiredLocalizedDataSemanticVersion"
}, },
) )
notifications: Optional[ notifications: Optional[List[_reporting_descriptor.ReportingDescriptor]] = (
List[_reporting_descriptor.ReportingDescriptor] dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "notifications"}
default=None, metadata={"schema_property_name": "notifications"} )
) )
organization: Optional[str] = dataclasses.field( organization: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "organization"} default=None, metadata={"schema_property_name": "organization"}
@ -91,9 +91,9 @@ class ToolComponent(object):
release_date_utc: Optional[str] = dataclasses.field( release_date_utc: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "releaseDateUtc"} default=None, metadata={"schema_property_name": "releaseDateUtc"}
) )
rules: Optional[ rules: Optional[List[_reporting_descriptor.ReportingDescriptor]] = (
List[_reporting_descriptor.ReportingDescriptor] dataclasses.field(default=None, metadata={"schema_property_name": "rules"})
] = dataclasses.field(default=None, metadata={"schema_property_name": "rules"}) )
semantic_version: Optional[str] = dataclasses.field( semantic_version: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "semanticVersion"} default=None, metadata={"schema_property_name": "semanticVersion"}
) )
@ -110,10 +110,10 @@ class ToolComponent(object):
taxa: Optional[List[_reporting_descriptor.ReportingDescriptor]] = dataclasses.field( taxa: Optional[List[_reporting_descriptor.ReportingDescriptor]] = dataclasses.field(
default=None, metadata={"schema_property_name": "taxa"} default=None, metadata={"schema_property_name": "taxa"}
) )
translation_metadata: Optional[ translation_metadata: Optional[_translation_metadata.TranslationMetadata] = (
_translation_metadata.TranslationMetadata dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "translationMetadata"}
default=None, metadata={"schema_property_name": "translationMetadata"} )
) )
version: Optional[str] = dataclasses.field( version: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "version"} default=None, metadata={"schema_property_name": "version"}

View File

@ -20,10 +20,10 @@ class TranslationMetadata(object):
download_uri: Optional[str] = dataclasses.field( download_uri: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "downloadUri"} default=None, metadata={"schema_property_name": "downloadUri"}
) )
full_description: Optional[ full_description: Optional[_multiformat_message_string.MultiformatMessageString] = (
_multiformat_message_string.MultiformatMessageString dataclasses.field(
] = dataclasses.field( default=None, metadata={"schema_property_name": "fullDescription"}
default=None, metadata={"schema_property_name": "fullDescription"} )
) )
full_name: Optional[str] = dataclasses.field( full_name: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "fullName"} default=None, metadata={"schema_property_name": "fullName"}

View File

@ -808,9 +808,9 @@ def _exported_program_to_onnx_program(
value, Sequence value, Sequence
), f"Input '{value_name}' should not be a sequence. This is unexpected." ), f"Input '{value_name}' should not be a sequence. This is unexpected."
value.metadata_props[ value.metadata_props["pkg.torch.export.graph_signature.InputSpec.kind"] = (
"pkg.torch.export.graph_signature.InputSpec.kind" input_kind.name
] = input_kind.name )
value.metadata_props[ value.metadata_props[
"pkg.torch.export.graph_signature.InputSpec.persistent" "pkg.torch.export.graph_signature.InputSpec.persistent"
] = str(persistent) ] = str(persistent)
@ -859,9 +859,9 @@ def _exported_program_to_onnx_program(
) )
for value in _values: for value in _values:
value.metadata_props[ value.metadata_props["pkg.torch.export.graph_signature.OutputSpec.kind"] = (
"pkg.torch.export.graph_signature.OutputSpec.kind" output_kind.name
] = output_kind.name )
if output_kind == graph_signature.OutputKind.USER_OUTPUT: if output_kind == graph_signature.OutputKind.USER_OUTPUT:
model.graph.outputs.append(value) model.graph.outputs.append(value)
@ -1218,7 +1218,9 @@ def export(
if byte_size < 2 * 1024 * 1024 * 1024: if byte_size < 2 * 1024 * 1024 * 1024:
# The checker may segfault so we need to run it in a separate process # The checker may segfault so we need to run it in a separate process
_isolated.safe_call( _isolated.safe_call(
onnx.checker.check_model, onnx_program.model_proto, full_check=True # type: ignore[attr-defined] onnx.checker.check_model,
onnx_program.model_proto,
full_check=True, # type: ignore[attr-defined]
) )
export_status.onnx_checker = True export_status.onnx_checker = True
verbose_print("Run `onnx.checker` on the ONNX model... ✅") verbose_print("Run `onnx.checker` on the ONNX model... ✅")
@ -1312,9 +1314,7 @@ def export(
_format_exceptions_for_all_strategies(failed_results) _format_exceptions_for_all_strategies(failed_results)
) )
if onnx_runtime_error_message: if onnx_runtime_error_message:
traceback_lines.append( traceback_lines.append("# ⚠️ ONNX Runtime error -----------------------")
"# ⚠️ ONNX Runtime error -----------------------"
)
traceback_lines.append(onnx_runtime_error_message) traceback_lines.append(onnx_runtime_error_message)
if not traceback_lines: if not traceback_lines:
traceback_lines.append("No errors") traceback_lines.append("No errors")

View File

@ -304,8 +304,8 @@ def _get_allowed_types_from_type_annotation(
allowed_types = set() allowed_types = set()
subtypes = typing.get_args(type_) subtypes = typing.get_args(type_)
for subtype in subtypes: for subtype in subtypes:
assert subtype is not type( assert (
None subtype is not type(None)
), "Union should not contain None type because it is handled by _is_optional." ), "Union should not contain None type because it is handled by _is_optional."
allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) allowed_types.update(_get_allowed_types_from_type_annotation(subtype))
return allowed_types return allowed_types

View File

@ -235,8 +235,7 @@ class Transform(abc.ABC):
) )
@abc.abstractmethod @abc.abstractmethod
def _run(self, *args, **kwargs) -> torch.fx.GraphModule: def _run(self, *args, **kwargs) -> torch.fx.GraphModule: ...
...
@diagnostics.diagnose_call( @diagnostics.diagnose_call(
diagnostics.rules.fx_pass, diagnostics.rules.fx_pass,
@ -321,5 +320,4 @@ class Analysis(abc.ABC):
self.onnxfunction_dispatcher = onnxfunction_dispatcher self.onnxfunction_dispatcher = onnxfunction_dispatcher
@abc.abstractmethod @abc.abstractmethod
def analyze(self, diagnostic_level: diagnostics.infra.Level) -> AnalysisResult: def analyze(self, diagnostic_level: diagnostics.infra.Level) -> AnalysisResult: ...
...

View File

@ -11,6 +11,7 @@ https://github.com/pytorch/pytorch/issues/115883
This solution will no longer be required once the issue is resolved. This solution will no longer be required once the issue is resolved.
""" """
from __future__ import annotations from __future__ import annotations
import abc import abc

View File

@ -94,7 +94,9 @@ class _PyTreeExtensionContext:
for _, class_type in named_model_output_classes: for _, class_type in named_model_output_classes:
self.register_pytree_node( self.register_pytree_node(
class_type, model_output_flatten, model_output_unflatten # type: ignore[arg-type ] class_type,
model_output_flatten,
model_output_unflatten, # type: ignore[arg-type ]
) )

View File

@ -626,7 +626,8 @@ class FxOnnxInterpreter:
): ):
# aten ops and other stateless functions. # aten ops and other stateless functions.
if node.target == operator.getitem and isinstance( if node.target == operator.getitem and isinstance(
fx_name_to_onnxscript_value[node.args[0].name], tuple # type: ignore[union-attr,index] fx_name_to_onnxscript_value[node.args[0].name], # type: ignore[union-attr,index]
tuple,
): ):
onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index] onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index]
index = node.args[1] index = node.args[1]
@ -660,9 +661,10 @@ class FxOnnxInterpreter:
diagnostic_context=self.diagnostic_context, diagnostic_context=self.diagnostic_context,
) )
with onnxscript.evaluator.default_as(onnxscript_tracer): with onnxscript.evaluator.default_as(onnxscript_tracer):
output: onnxscript_graph_building.TorchScriptTensor | tuple[ output: (
onnxscript_graph_building.TorchScriptTensor, ... onnxscript_graph_building.TorchScriptTensor
] = symbolic_fn(*onnx_args, **onnx_kwargs) | tuple[onnxscript_graph_building.TorchScriptTensor, ...]
) = symbolic_fn(*onnx_args, **onnx_kwargs)
assert ( assert (
output is not None output is not None
), f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}" ), f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}"
@ -779,9 +781,10 @@ class FxOnnxInterpreter:
# be considered. # be considered.
unique_module_name = f"{sub_module._get_name()}_{node.target}" unique_module_name = f"{sub_module._get_name()}_{node.target}"
outputs: onnxscript_graph_building.TorchScriptTensor | tuple[ outputs: (
onnxscript_graph_building.TorchScriptTensor, ... onnxscript_graph_building.TorchScriptTensor
] = parent_onnxscript_graph.add_module_call( # type: ignore[assignment] | tuple[onnxscript_graph_building.TorchScriptTensor, ...]
) = parent_onnxscript_graph.add_module_call( # type: ignore[assignment]
unique_module_name, sub_onnxscript_graph, onnx_args unique_module_name, sub_onnxscript_graph, onnx_args
) )

View File

@ -147,8 +147,10 @@ class FXSymbolicTracer(_exporter_legacy.FXGraphExtractor):
for v in x.values(): for v in x.values():
out += v out += v
return out return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7
f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
assert f({"a": 1, "b": 2, "c": 4}) == 7
""" """
def __init__(self, concrete_args: dict[str, Any] | None = None): def __init__(self, concrete_args: dict[str, Any] | None = None):

View File

@ -415,10 +415,9 @@ class _OnnxSchemaChecker:
inputs = (Tensor[2, 3], Tensor[2, 3]) inputs = (Tensor[2, 3], Tensor[2, 3])
attributes = {"alpha": 1.0} attributes = {"alpha": 1.0}
@torch_op("aten::op")
def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal:
...
@torch_op("aten::op")
def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal: ...
``` ```
Result: Perfect match. Result: Perfect match.

View File

@ -295,7 +295,7 @@ def _convert_torch_args_to_onnxfunction_args(
args: list[fx_type_utils.Argument], args: list[fx_type_utils.Argument],
kwargs: dict[str, fx_type_utils.Argument], kwargs: dict[str, fx_type_utils.Argument],
allow_extra_kwargs: bool = False, allow_extra_kwargs: bool = False,
) -> tuple[list[Any], dict[str, Any],]: ) -> tuple[list[Any], dict[str, Any]]:
"""Convert Python args and kwargs to OnnxFunction acceptable with matching ONNX ParamSchema. """Convert Python args and kwargs to OnnxFunction acceptable with matching ONNX ParamSchema.
NOTE: This is different from the param_schema separating in dispatcher, since at this point NOTE: This is different from the param_schema separating in dispatcher, since at this point

View File

@ -3,6 +3,7 @@
These functions should NOT be directly invoked outside of `passes` package. These functions should NOT be directly invoked outside of `passes` package.
""" """
from __future__ import annotations from __future__ import annotations
import collections import collections

View File

@ -66,9 +66,7 @@ class Decompose(_pass.Transform):
# Apply decomposition table to the input graph. # Apply decomposition table to the input graph.
assert fake_mode is not None # for mypy assert fake_mode is not None # for mypy
with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), ( with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), fake_mode:
fake_mode
):
decomposed_module = proxy_tensor.make_fx( decomposed_module = proxy_tensor.make_fx(
module, module,
decomposition_table=self.decomposition_table, decomposition_table=self.decomposition_table,

View File

@ -814,7 +814,9 @@ class Modularize(_pass.Transform):
>>> out = self.linear(out) >>> out = self.linear(out)
>>> return out >>> return out
>>> >>>
>>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)(torch.tensor([0, 1, 2])) >>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)(
... torch.tensor([0, 1, 2])
... )
>>> gm.print_readable() >>> gm.print_readable()
>>> gm = passes.Modularize(infra.DiagnosticContext("test_context", "1.0"), gm).run() >>> gm = passes.Modularize(infra.DiagnosticContext("test_context", "1.0"), gm).run()

View File

@ -76,16 +76,13 @@ class TypePromotionRule(abc.ABC):
# A class that overrides __eq__() and does not define __hash__() will have its __hash__() implicitly set to None. # A class that overrides __eq__() and does not define __hash__() will have its __hash__() implicitly set to None.
# Ref: https://docs.python.org/3/reference/datamodel.html#object.__hash__ # Ref: https://docs.python.org/3/reference/datamodel.html#object.__hash__
@abc.abstractmethod @abc.abstractmethod
def __hash__(self) -> int: def __hash__(self) -> int: ...
...
@abc.abstractmethod @abc.abstractmethod
def __repr__(self): def __repr__(self): ...
...
@abc.abstractmethod @abc.abstractmethod
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool: ...
...
def is_valid(self) -> bool: def is_valid(self) -> bool:
"""Check if the rule is valid.""" """Check if the rule is valid."""

View File

@ -95,7 +95,9 @@ class TorchExport(_exporter_legacy.FXGraphExtractor):
model = model.run_decompositions(options.decomposition_table) model = model.run_decompositions(options.decomposition_table)
# Export FX graph to ONNX ModelProto. # Export FX graph to ONNX ModelProto.
return self.pre_export_passes(options, model, model.graph_module, updated_model_args) # type: ignore[return-value] return self.pre_export_passes( # type: ignore[return-value]
options, model, model.graph_module, updated_model_args
)
def pre_export_passes( def pre_export_passes(
self, self,

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Utilities for converting and operating on ONNX, JIT and torch types.""" """Utilities for converting and operating on ONNX, JIT and torch types."""
from __future__ import annotations from __future__ import annotations
from typing import ( from typing import (
@ -31,8 +32,7 @@ if TYPE_CHECKING:
@runtime_checkable @runtime_checkable
class TensorLike(Protocol): class TensorLike(Protocol):
@property @property
def dtype(self) -> torch.dtype | None: def dtype(self) -> torch.dtype | None: ...
...
def is_torch_complex_dtype(tensor_dtype: torch.dtype) -> bool: def is_torch_complex_dtype(tensor_dtype: torch.dtype) -> bool:

View File

@ -40,8 +40,7 @@ class InputAdaptStep(Protocol):
model_args: Sequence[Any], model_args: Sequence[Any],
model_kwargs: Mapping[str, Any], model_kwargs: Mapping[str, Any],
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
) -> tuple[Sequence[Any], Mapping[str, Any]]: ) -> tuple[Sequence[Any], Mapping[str, Any]]: ...
...
class InputAdapter: class InputAdapter:
@ -98,8 +97,7 @@ class OutputAdaptStep(Protocol):
self, self,
model_outputs: Any, model_outputs: Any,
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
) -> Any: ) -> Any: ...
...
class OutputAdapter: class OutputAdapter:
@ -573,7 +571,8 @@ class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep):
A tuple of the model args and kwargs. A tuple of the model args and kwargs.
""" """
ordered_params = tuple( ordered_params = tuple(
model.state_dict[name] for name in model.graph_signature.parameters # type: ignore[union-attr,index] model.state_dict[name] # type: ignore[union-attr,index]
for name in model.graph_signature.parameters # type: ignore[union-attr]
) )
non_persistent_buffers = set(model.graph_signature.non_persistent_buffers) # type: ignore[union-attr] non_persistent_buffers = set(model.graph_signature.non_persistent_buffers) # type: ignore[union-attr]
ordered_buffers = [] ordered_buffers = []
@ -583,7 +582,8 @@ class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep):
else: else:
ordered_buffers.append(model.state_dict[name]) # type: ignore[union-attr,index] ordered_buffers.append(model.state_dict[name]) # type: ignore[union-attr,index]
ordered_constant_tensors = tuple( ordered_constant_tensors = tuple(
model.constants[fqn] for fqn in model.graph_signature.lifted_tensor_constants # type: ignore[union-attr,index] model.constants[fqn] # type: ignore[union-attr,index]
for fqn in model.graph_signature.lifted_tensor_constants # type: ignore[union-attr]
) )
# NOTE: calling convention is first params, then buffers, then args as user supplied them. # NOTE: calling convention is first params, then buffers, then args as user supplied them.

View File

@ -304,7 +304,7 @@ def _get_onnx_devices(
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
], ],
..., ...,
] ],
) -> Tuple["ORTC.OrtDevice", ...]: ) -> Tuple["ORTC.OrtDevice", ...]:
def _device_id_or_zero(device_id: int) -> int: def _device_id_or_zero(device_id: int) -> int:
return device_id or 0 return device_id or 0
@ -403,7 +403,12 @@ def _adjust_scalar_from_onnx_to_fx(
torch.SymBool, torch.SymBool,
bool, bool,
], ],
) -> Union[torch.Tensor, int, float, bool,]: ) -> Union[
torch.Tensor,
int,
float,
bool,
]:
"""Helper function to wrap ORT-produced torch.Tensor as PyTorch variables""" """Helper function to wrap ORT-produced torch.Tensor as PyTorch variables"""
assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor." assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor."
if isinstance( if isinstance(
@ -561,9 +566,9 @@ class OrtExecutionInfoPerSession:
self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices
# This is the outputs of executing the original torch.fx.GraphModule with example inputs # This is the outputs of executing the original torch.fx.GraphModule with example inputs
# (i.e., args passed into OrtBackend._ort_acclerated_call). # (i.e., args passed into OrtBackend._ort_acclerated_call).
self.example_outputs: Union[ self.example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor] = (
Tuple[torch.Tensor, ...], torch.Tensor example_outputs
] = example_outputs )
def is_supported(self, *args): def is_supported(self, *args):
# Compare the args and the input schema in ONNX model and # Compare the args and the input schema in ONNX model and

View File

@ -276,10 +276,13 @@ def onnx_symbolic(
Usage:: Usage::
``` ```
@onnx_symbolic("aten::symbolic_b", opset=10, decorate=[quantized_aten_handler(scale=1/128, zero_point=0)]) @onnx_symbolic(
"aten::symbolic_b",
opset=10,
decorate=[quantized_aten_handler(scale=1 / 128, zero_point=0)],
)
@symbolic_helper.parse_args("v", "v", "b") @symbolic_helper.parse_args("v", "v", "b")
def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value: def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value: ...
...
``` ```
Args: Args:

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
"""Utilities for converting and operating on ONNX, JIT and torch types.""" """Utilities for converting and operating on ONNX, JIT and torch types."""
from __future__ import annotations from __future__ import annotations
import enum import enum

View File

@ -1,4 +1,5 @@
"""ONNX exporter exceptions.""" """ONNX exporter exceptions."""
from __future__ import annotations from __future__ import annotations
import textwrap import textwrap

View File

@ -234,7 +234,7 @@ def parse_args(
""" """
def decorator( def decorator(
fn: Callable[_Concatenate[_U, _P], _T] fn: Callable[_Concatenate[_U, _P], _T],
) -> Callable[_Concatenate[_U, _P], _T]: ) -> Callable[_Concatenate[_U, _P], _T]:
fn._arg_descriptors = arg_descriptors # type: ignore[attr-defined] fn._arg_descriptors = arg_descriptors # type: ignore[attr-defined]

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type # mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 11.""" """This file exports ONNX ops for opset 11."""
from __future__ import annotations from __future__ import annotations
import functools import functools

View File

@ -148,9 +148,7 @@ def scaled_dot_product_attention(
assert (not is_causal) or ( assert (not is_causal) or (
is_causal and symbolic_helper._is_none(attn_mask) is_causal and symbolic_helper._is_none(attn_mask)
), "is_causal and attn_mask cannot be set at the same time" ), "is_causal and attn_mask cannot be set at the same time"
assert ( assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
not enable_gqa
), "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
scale = symbolic_helper._maybe_get_const(scale, "f") scale = symbolic_helper._maybe_get_const(scale, "f")
if symbolic_helper._is_none(scale): if symbolic_helper._is_none(scale):
@ -254,7 +252,7 @@ def _causal_attention_mask(
Equivalent to:: Equivalent to::
mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_mask = torch.zeros(L, S, dtype=torch.float) attn_mask = torch.zeros(L, S, dtype=torch.float)
attn_mask = attn_mask.masked_fill(not mask, -float('inf')) attn_mask = attn_mask.masked_fill(not mask, -float("inf"))
Args: Args:
query: Tensor of shape [..., L, E] query: Tensor of shape [..., L, E]

View File

@ -56,7 +56,9 @@ def grid_sampler(
if symbolic_helper._get_tensor_rank(input) == 5: if symbolic_helper._get_tensor_rank(input) == 5:
return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input") return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")
mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg] mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg]
padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum] # type: ignore[call-arg] padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg]
padding_mode_enum
]
return g.op( return g.op(
"GridSample", "GridSample",
input, input,

View File

@ -57,7 +57,9 @@ def _grid_sampler(
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index] mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index]
# mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
mode_s = convert_grid_sample_mode(mode_s) mode_s = convert_grid_sample_mode(mode_s)
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum] # type: ignore[call-arg, index] padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index]
padding_mode_enum # type: ignore[index]
]
return g.op( return g.op(
"GridSample", "GridSample",
input, input,

View File

@ -2746,7 +2746,9 @@ def native_layer_norm(
# mean and normalized, so we need to Cast it back # mean and normalized, so we need to Cast it back
if is_type_half: if is_type_half:
denominator = g.op( denominator = g.op(
"Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() # type: ignore[possibly-undefined] "Cast",
denominator,
to_i=_type_utils.JitScalarType(input_dtype).onnx_type(), # type: ignore[possibly-undefined]
) )
rdenominator = g.op("Reciprocal", denominator) rdenominator = g.op("Reciprocal", denominator)
else: else:
@ -4368,7 +4370,8 @@ def _generic_rnn(
reform_weights(g, w, hidden_size, reform_permutation) for w in weights reform_weights(g, w, hidden_size, reform_permutation) for w in weights
) )
return tuple( return tuple(
symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined] symbolic_helper._unsqueeze_helper(g, x, [0])
for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined]
) )
def transform_weights(layer_index): def transform_weights(layer_index):
@ -4498,9 +4501,10 @@ def _lstm_full(
bidirectional, bidirectional,
batch_first, batch_first,
): ):
hidden, weight = symbolic_helper._unpack_list( hidden, weight = (
hidden_v symbolic_helper._unpack_list(hidden_v),
), symbolic_helper._unpack_list(weight_v) symbolic_helper._unpack_list(weight_v),
)
return _generic_rnn( return _generic_rnn(
g, g,
"LSTM", "LSTM",
@ -4529,9 +4533,10 @@ def _lstm_packed(
train, train,
bidirectional, bidirectional,
): ):
hidden, weight = symbolic_helper._unpack_list( hidden, weight = (
hidden_v symbolic_helper._unpack_list(hidden_v),
), symbolic_helper._unpack_list(weight_v) symbolic_helper._unpack_list(weight_v),
)
return _generic_rnn( return _generic_rnn(
g, g,
"LSTM", "LSTM",

View File

@ -4,6 +4,7 @@
These models can be loaded with the ONNX library and then These models can be loaded with the ONNX library and then
converted to models which run on other deep learning frameworks. converted to models which run on other deep learning frameworks.
""" """
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib
@ -224,13 +225,7 @@ def export(
3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS:: 3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS::
args = ( args = (x, {"y": input_y, "z": input_z})
x,
{
"y": input_y,
"z": input_z
}
)
All but the last element of the tuple will be passed as non-keyword arguments, All but the last element of the tuple will be passed as non-keyword arguments,
and named arguments will be set from the last element. If a named argument is and named arguments will be set from the last element. If a named argument is
@ -252,22 +247,14 @@ def export(
( (
x, x,
# WRONG: will be interpreted as named arguments # WRONG: will be interpreted as named arguments
{y: z} {y: z},
), ),
"test.onnx.pb" "test.onnx.pb",
) )
Write:: Write::
torch.onnx.export( torch.onnx.export(model, (x, {y: z}, {}), "test.onnx.pb")
model,
(
x,
{y: z},
{}
),
"test.onnx.pb"
)
f: Path to the output ONNX model file. E.g. "model.onnx". f: Path to the output ONNX model file. E.g. "model.onnx".
kwargs: Named arguments to the model. kwargs: Named arguments to the model.
@ -369,12 +356,13 @@ def export(
def forward(self, x): def forward(self, x):
return torch.sum(x, dim=1) return torch.sum(x, dim=1)
torch.onnx.export( torch.onnx.export(
SumModule(), SumModule(),
(torch.ones(2, 2),), (torch.ones(2, 2),),
"onnx.pb", "onnx.pb",
input_names=["x"], input_names=["x"],
output_names=["sum"] output_names=["sum"],
) )
Produces:: Produces::
@ -410,7 +398,7 @@ def export(
"x": {0: "my_custom_axis_name"}, "x": {0: "my_custom_axis_name"},
# list value: automatic names # list value: automatic names
"sum": [0], "sum": [0],
} },
) )
Produces:: Produces::
@ -1398,9 +1386,9 @@ def _setup_trace_module_map(
and start from the first non-numeric atom. and start from the first non-numeric atom.
Example: Example:
>>> _unqualified_variable_name('__main__.Foo.bar') >>> _unqualified_variable_name("__main__.Foo.bar")
'bar' 'bar'
>>> _unqualified_variable_name('__main__.Foo.bar.0') >>> _unqualified_variable_name("__main__.Foo.bar.0")
'bar.0' 'bar.0'
""" """
name_atoms = qualified_name.split(".") name_atoms = qualified_name.split(".")
@ -1605,7 +1593,9 @@ def _export(
if keep_initializers_as_inputs is not True: if keep_initializers_as_inputs is not True:
params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment] params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment]
graph, params_dict, getattr(model, "training", False) # type: ignore[arg-type] graph,
params_dict, # type: ignore[arg-type]
getattr(model, "training", False), # type: ignore[arg-type]
) )
_C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph) _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph)
if export_params: if export_params:
@ -1863,7 +1853,9 @@ def _run_symbolic_function(
} }
if namespace == "onnx": if namespace == "onnx":
# Clone node to trigger ONNX shape inference # Clone node to trigger ONNX shape inference
return graph_context.op(op_name, *inputs, **attrs, outputs=node.outputsSize()) # type: ignore[attr-defined] return graph_context.op(
op_name, *inputs, **attrs, outputs=node.outputsSize()
) # type: ignore[attr-defined]
raise errors.UnsupportedOperatorError( raise errors.UnsupportedOperatorError(
symbolic_function_name, symbolic_function_name,

View File

@ -217,8 +217,8 @@ def _compare_onnx_pytorch_outputs_in_np(
pt_outs: _OutputsType, pt_outs: _OutputsType,
options: VerificationOptions, options: VerificationOptions,
): ):
assert len(onnx_outs) == len( assert (
pt_outs len(onnx_outs) == len(pt_outs)
), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})" ), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"
acceptable_error_percentage = options.acceptable_error_percentage acceptable_error_percentage = options.acceptable_error_percentage
if acceptable_error_percentage and ( if acceptable_error_percentage and (