diff --git a/.ci/onnx/test.sh b/.ci/onnx/test.sh index f29188c6fd5..0438779bb3c 100755 --- a/.ci/onnx/test.sh +++ b/.ci/onnx/test.sh @@ -64,7 +64,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then # TODO: change this when onnx 1.13.1 is released. pip install --no-use-pep517 'onnx @ git+https://github.com/onnx/onnx@e192ba01e438d22ca2dedd7956e28e3551626c91' # TODO: change this when onnx-script is on testPypi - pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script@0298154caf6b46fc4e30abba034095c1290c26e3' + pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script@29241e15f5182be1384f1cf6ba203d7e2e125196' # numba requires numpy <= 1.20, onnxruntime requires numpy >= 1.21. # We don't actually need it for our tests, but it's imported if it's present, so uninstall. pip uninstall -q --yes numba diff --git a/torch/onnx/_internal/fx/exporter.py b/torch/onnx/_internal/fx/exporter.py index e6193cdf501..1dcb217f9fa 100644 --- a/torch/onnx/_internal/fx/exporter.py +++ b/torch/onnx/_internal/fx/exporter.py @@ -252,9 +252,6 @@ def _export_fx_node_to_onnxscript( fx_name_to_onnxscipt_value: Dict[ str, Union[torch._C.Value, Tuple[torch._C.Value, ...]] ], - onnxscript_value_name_to_real_tensor: Dict[ - str, Union[torch.Tensor, Tuple[torch._C.Value, ...]] - ], tracer: graph_building.TorchScriptTracingEvaluator, fx_module_with_metadata: torch.fx.GraphModule, options: options.ExportOptions, @@ -388,7 +385,9 @@ def _export_fx_node_to_onnxscript( assert isinstance(input_, graph_building.TorchScriptTensor) assert isinstance(input_, onnxscript.tensor.Tensor) fx_name_to_onnxscipt_value[node.name] = input_ - onnxscript_value_name_to_real_tensor[input_.name] = current_attr # type: ignore[assignment] + # FIXME: Refactor logic getting 'current_attr'. + assert isinstance(current_attr, torch.Tensor) + onnxscript_graph.add_initializer(input_.name, current_attr) else: # TODO(wechi): Support get_attr, call_module, call_method. raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}") @@ -413,18 +412,11 @@ def _export_fx_to_onnxscript( fx_name_to_onnxscipt_value: Dict[ str, Union[torch._C.Value, Tuple[torch._C.Value, ...]] ] = {} - # Similar to fx_name_to_onnxscipt_value, we need a mapping fo real tensors (usually tensor parameters - # in nn.Module). Note that TorchScript's cannot store real tensors; TorchScript values are all - # symbolic. This is passed into ONNX ModelProto as the initializers. - onnxscript_value_name_to_real_tensor: Dict[ - str, Union[torch.Tensor, Tuple[torch._C.Value, ...]] - ] = {} for node in fx_module_with_metadata.graph.nodes: _export_fx_node_to_onnxscript( node, onnxscript_graph, fx_name_to_onnxscipt_value, - onnxscript_value_name_to_real_tensor, tracer, fx_module_with_metadata, options, @@ -439,7 +431,7 @@ def _export_fx_to_onnxscript( opset_version=options.opset_version, ) - return onnxscript_graph, onnxscript_value_name_to_real_tensor + return onnxscript_graph @_beartype.beartype @@ -531,13 +523,9 @@ def _export( # ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf is not compatible # with FakeTensorMode. with torch.utils._mode_utils.no_dispatch(): - onnxscript_graph, initializers = _export_fx_to_onnxscript( - decomposed_module, export_options - ) + onnxscript_graph = _export_fx_to_onnxscript(decomposed_module, export_options) # Export TorchScript graph to ONNX ModelProto. - onnx_model = onnxscript_graph.to_model_proto( - initializers, export_options.opset_version - ) + onnx_model = onnxscript_graph.to_model_proto(export_options.opset_version) if export_options.use_binary_format: # Return ModelProto in binary format.