Revert "[ONNX] Remove the depreacated function _export (#109763)"

This reverts commit d7c05bb2e8.

Reverted https://github.com/pytorch/pytorch/pull/109763 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/109763#issuecomment-1734201053))
PyTorch MergeBot 2023-09-25 17:47:21 +00:00
parent 52e14787ae
commit a5364b12bb
6 changed files with 22 additions and 12 deletions
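In short: the reverted PR had removed the deprecated private alias torch.onnx._export and pointed call sites at torch.onnx.utils._export instead; this revert restores the alias (see the hunk below that re-adds the @_deprecation.deprecated wrapper) and switches those call sites back. A minimal sketch of the restored call pattern; MyModel and the input shape here are placeholders, not code from this diff:

import io

import torch


class MyModel(torch.nn.Module):  # placeholder module for illustration
    def forward(self, x):
        return x * 2


f = io.BytesIO()
# Deprecated private alias restored by this revert; it emits a deprecation
# warning and delegates to torch.onnx.utils._export.
torch.onnx._export(MyModel(), (torch.randn(1, 1, 224, 224),), f)
# The supported public entry point that the deprecation message recommends:
torch.onnx.export(MyModel(), (torch.randn(1, 1, 224, 224),), f)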


@@ -32,7 +32,7 @@ class TestExportModes(pytorch_test_common.ExportTestCase):
         torch_model = TestExportModes.MyModel()
         fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
         f = io.BytesIO()
-        torch.onnx.utils._export(
+        torch.onnx._export(
             torch_model,
             (fake_input),
             f,
@@ -44,7 +44,7 @@ class TestExportModes(pytorch_test_common.ExportTestCase):
         torch_model = TestExportModes.MyModel()
         fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
         f = io.BytesIO()
-        torch.onnx.utils._export(
+        torch.onnx._export(
             torch_model,
             (fake_input),
             f,
@@ -56,7 +56,7 @@ class TestExportModes(pytorch_test_common.ExportTestCase):
         torch_model = TestExportModes.MyModel()
         fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
         f = io.BytesIO()
-        torch.onnx.utils._export(
+        torch.onnx._export(
             torch_model,
             (fake_input),
             f,
@@ -68,7 +68,7 @@ class TestExportModes(pytorch_test_common.ExportTestCase):
         torch_model = TestExportModes.MyModel()
         fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
         d = tempfile.mkdtemp()
-        torch.onnx.utils._export(
+        torch.onnx._export(
             torch_model,
             (fake_input),
             d,


@@ -377,7 +377,7 @@ def verify(
     with torch.onnx.select_model_mode_for_export(model, training):
         proto_bytes = io.BytesIO()
-        torch_out = torch.onnx.utils._export(
+        torch_out = torch.onnx._export(
             model,
             args,
             proto_bytes,
@@ -397,7 +397,7 @@ def verify(
     def run(args, remained_onnx_input_idx):
         alt_proto_bytes = io.BytesIO()
-        torch_out = torch.onnx.utils._export(
+        torch_out = torch.onnx._export(
             model,
             args,
             alt_proto_bytes,


@@ -101,7 +101,7 @@ def convert_tests(testcases, sets=1):
         try:
             input = gen_input(t)
             f = io.BytesIO()
-            torch.onnx.utils._export(
+            torch.onnx._export(
                 module,
                 input,
                 f,


@@ -97,7 +97,7 @@ def skipIfNoEmbed(func):
 def do_export(model, inputs, *args, **kwargs):
     f = io.BytesIO()
-    out = torch.onnx.utils._export(model, inputs, f, *args, **kwargs)
+    out = torch.onnx._export(model, inputs, f, *args, **kwargs)
     if isinstance(model, torch.jit.ScriptModule):
         # Special case for common case of passing a single Tensor
         if isinstance(inputs, torch.Tensor):
@@ -320,7 +320,7 @@ class TestCaffe2Backend_opset9(pytorch_test_common.ExportTestCase):
         # Note that the export call explicitly sets the names of not just the input,
         # but also the parameters. This test checks that the model can be loaded and
         # executed in Caffe2 backend correctly.
-        torch.onnx.utils._export(
+        torch.onnx._export(
             model,
             input,
             f,
@@ -353,7 +353,7 @@ class TestCaffe2Backend_opset9(pytorch_test_common.ExportTestCase):
         # But note that the target first parameter name is the same as the second parameter name.
         # This test checks that given this edge condition, the model can be loaded and executed
         # in Caffe2 backend correctly.
-        torch.onnx.utils._export(
+        torch.onnx._export(
             model,
             input,
             f,
@@ -1613,7 +1613,7 @@ class TestCaffe2Backend_opset9(pytorch_test_common.ExportTestCase):
         f = io.BytesIO()
         from torch.onnx import ExportTypes
-        torch.onnx.utils._export(
+        torch.onnx._export(
             MyModel(),
             (torch.rand(3, 4),),
             f,


@@ -8,6 +8,7 @@ from torch._C._onnx import (
 )
 from . import ( # usort:skip. Keep the order instead of sorting lexicographically
+    _deprecation,
     errors,
     symbolic_caffe2,
     symbolic_helper,
@@ -129,6 +130,13 @@ producer_name = "pytorch"
 producer_version = _C_onnx.PRODUCER_VERSION
+
+
+@_deprecation.deprecated(
+    since="1.12.0", removed_in="2.0", instructions="use `torch.onnx.export` instead"
+)
+def _export(*args, **kwargs):
+    return utils._export(*args, **kwargs)
 # TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.
 # Returns True iff ONNX logging is turned on.
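The restored wrapper leans on the internal torch.onnx._deprecation module, whose implementation is not part of this diff. A plausible sketch of such a decorator, assuming it simply emits a FutureWarning built from the since/removed_in/instructions arguments before delegating (the real helper may differ):

import functools
import warnings


def deprecated(since: str, removed_in: str, instructions: str):
    # Hypothetical stand-in for torch.onnx._deprecation.deprecated;
    # only its call signature is visible in the hunk above.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"'{func.__qualname__}' is deprecated in version {since} "
                f"and will be removed in version {removed_in}. "
                f"Please {instructions}.",
                category=FutureWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator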


@@ -5846,7 +5846,9 @@ def index(g: jit_utils.GraphContext, self, index):
     if rank is None:
         return symbolic_helper._unimplemented(
             "aten::index",
-            "operator of advanced indexing on tensor of unknown rank.",
+            "operator of advanced indexing on tensor of unknown rank. "
+            "Try turning on shape inference during export: "
+            "torch.onnx._export(..., onnx_shape_inference=True).",
             self,
         )
     # TODO: If indexing is supported natively in ONNX in future opsets,
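For context, the amended message tells users to re-run export with ONNX shape inference enabled so the exporter can learn the indexed tensor's rank. A sketch of that call, assuming the keyword is accepted exactly as the message states; the module, inputs, and buffer are placeholders:

import io

import torch

model = torch.nn.Linear(3, 4)  # placeholder module
f = io.BytesIO()
# As the new error text suggests: enable shape inference during export.
torch.onnx._export(model, (torch.randn(2, 3),), f, onnx_shape_inference=True)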