mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Enable experimental exporter logic to dynamo_export and support refine dynamic_shapes (#134976)
(1) Enable experimental exporter logic to dynamo_export (2) Refine dynamic shapes and retry export in export strategies (3) Delete `torch_export_graph_extractor` and use the new export logic (4) Disable ExportedProgram test in `test_fx_onnx_with_onnxruntime.py`, as ONNXProgram is different now. Fixes https://github.com/pytorch/pytorch/issues/126479 Fixes #135183 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134976 Approved by: https://github.com/justinchuby
This commit is contained in:
parent
1e57ef08fa
commit
8f6e73f068
|
|
@ -148,6 +148,54 @@ class TestExportAPIDynamo(common_utils.TestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_auto_convert_all_axes_to_dynamic_shapes_with_dynamo_export(self):
|
||||||
|
os.environ["TORCH_ONNX_USE_EXPERIMENTAL_LOGIC"] = "1"
|
||||||
|
assert os.environ.get("TORCH_ONNX_USE_EXPERIMENTAL_LOGIC") == "1"
|
||||||
|
|
||||||
|
class Nested(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
(a0, a1), (b0, b1), (c0, c1, c2) = x
|
||||||
|
return a0 + a1 + b0 + b1 + c0 + c1 + c2
|
||||||
|
|
||||||
|
inputs = (
|
||||||
|
(1, 2),
|
||||||
|
(
|
||||||
|
torch.randn(4, 4),
|
||||||
|
torch.randn(4, 4),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
torch.randn(4, 4),
|
||||||
|
torch.randn(4, 4),
|
||||||
|
torch.randn(4, 4),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
onnx_program = torch.onnx.dynamo_export(
|
||||||
|
Nested(),
|
||||||
|
inputs,
|
||||||
|
export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
|
||||||
|
)
|
||||||
|
assert onnx_program is not None
|
||||||
|
onnx_testing.assert_onnx_program(onnx_program)
|
||||||
|
|
||||||
|
def test_refine_dynamic_shapes_with_onnx_export(self):
|
||||||
|
# NOTE: From test/export/test_export.py
|
||||||
|
|
||||||
|
# refine lower, upper bound
|
||||||
|
class TestRefineDynamicShapeModel(torch.nn.Module):
|
||||||
|
def forward(self, x, y):
|
||||||
|
if x.shape[0] >= 6 and y.shape[0] <= 16:
|
||||||
|
return x * 2.0, y + 1
|
||||||
|
|
||||||
|
inps = (torch.randn(16), torch.randn(12))
|
||||||
|
dynamic_shapes = {
|
||||||
|
"x": (torch.export.Dim("dx"),),
|
||||||
|
"y": (torch.export.Dim("dy"),),
|
||||||
|
}
|
||||||
|
self.assert_export(
|
||||||
|
TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
common_utils.run_tests()
|
common_utils.run_tests()
|
||||||
|
|
|
||||||
|
|
@ -346,6 +346,28 @@ def skipDtypeChecking(func):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def skip_if_fake_model_and_inititalizer(reason: Optional[str] = None):
|
||||||
|
"""skip test with models using ExportedProgram as input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reason: The reason for skip the ONNX export test.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A decorator for skip tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def skip_dec(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
if kwargs["use_fake_mode"] and kwargs["include_initializer"]:
|
||||||
|
return unittest.SkipTest(reason)
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return skip_dec
|
||||||
|
|
||||||
|
|
||||||
def xfail_if_model_type_is_exportedprogram(
|
def xfail_if_model_type_is_exportedprogram(
|
||||||
error_message: str, reason: Optional[str] = None
|
error_message: str, reason: Optional[str] = None
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -544,34 +544,6 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
|
||||||
onnx.checker.check_model(onnx_program.model_proto)
|
onnx.checker.check_model(onnx_program.model_proto)
|
||||||
onnx.shape_inference.infer_shapes(onnx_program.model_proto)
|
onnx.shape_inference.infer_shapes(onnx_program.model_proto)
|
||||||
|
|
||||||
def test_exported_program_input_with_custom_fx_tracer(self):
|
|
||||||
from torch.onnx._internal import _exporter_legacy
|
|
||||||
from torch.onnx._internal.fx import dynamo_graph_extractor
|
|
||||||
|
|
||||||
class Model(torch.nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
return x + 1
|
|
||||||
|
|
||||||
x = torch.randn(1, 1, 2)
|
|
||||||
exported_program = torch.export.export(Model(), args=(x,))
|
|
||||||
|
|
||||||
export_options = torch.onnx.ExportOptions()
|
|
||||||
export_options = _exporter_legacy.ResolvedExportOptions(
|
|
||||||
export_options, model=exported_program
|
|
||||||
)
|
|
||||||
export_options.fx_tracer = (
|
|
||||||
dynamo_graph_extractor.DynamoExport()
|
|
||||||
) # Override fx_tracer to an unsupported tracer
|
|
||||||
with self.assertRaises(torch.onnx.OnnxExporterError):
|
|
||||||
onnx_program = torch.onnx.dynamo_export(
|
|
||||||
exported_program,
|
|
||||||
x,
|
|
||||||
export_options=export_options,
|
|
||||||
)
|
|
||||||
self.assertTrue(onnx_program._export_exception is not None)
|
|
||||||
with self.assertRaises(torch.onnx.InvalidExportOptionsError):
|
|
||||||
raise self._export_exception
|
|
||||||
|
|
||||||
def test_exported_program_torch_distributions_normal_Normal(self):
|
def test_exported_program_torch_distributions_normal_Normal(self):
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|
@ -606,21 +578,6 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
|
||||||
# with no Cast node in between.
|
# with no Cast node in between.
|
||||||
self.assertEqual(div_node.input[0], model_proto.graph.input[0].name)
|
self.assertEqual(div_node.input[0], model_proto.graph.input[0].name)
|
||||||
|
|
||||||
def test_exported_program_as_input_with_model_signature(self):
|
|
||||||
class Model(torch.nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
return x + 1.0
|
|
||||||
|
|
||||||
x = torch.randn(1, 1, 2, dtype=torch.float)
|
|
||||||
exported_program = torch.export.export(Model(), args=(x,))
|
|
||||||
|
|
||||||
onnx_program = torch.onnx.dynamo_export(
|
|
||||||
exported_program,
|
|
||||||
x,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature)
|
|
||||||
|
|
||||||
@common_utils.parametrize(
|
@common_utils.parametrize(
|
||||||
"float8_type",
|
"float8_type",
|
||||||
[
|
[
|
||||||
|
|
@ -707,6 +664,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
|
||||||
onnx_program.save(tmp_onnx_file.name)
|
onnx_program.save(tmp_onnx_file.name)
|
||||||
onnx.checker.check_model(tmp_onnx_file.name, full_check=True)
|
onnx.checker.check_model(tmp_onnx_file.name, full_check=True)
|
||||||
|
|
||||||
|
@pytorch_test_common.skip_if_fake_model_and_inititalizer("segfault")
|
||||||
@common_utils.parametrize(
|
@common_utils.parametrize(
|
||||||
"include_initializer",
|
"include_initializer",
|
||||||
[
|
[
|
||||||
|
|
|
||||||
|
|
@ -19,12 +19,6 @@ def assert_op_in_onnx_model(model: onnx.ModelProto, op_type: str):
|
||||||
|
|
||||||
|
|
||||||
class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
|
class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
|
||||||
def _test_exported_program_forces_decomposition(self, model, input, op_type):
|
|
||||||
ep = torch.export.export(model, input)
|
|
||||||
onnx_program = torch.onnx.dynamo_export(ep, *input)
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
assert_op_in_onnx_model(onnx_program.model_proto, op_type)
|
|
||||||
|
|
||||||
def test_upsample_bilinear2d(self):
|
def test_upsample_bilinear2d(self):
|
||||||
class TestModel(torch.nn.Module):
|
class TestModel(torch.nn.Module):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|
@ -37,9 +31,6 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
|
||||||
onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2))
|
onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2))
|
||||||
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
|
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
|
||||||
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
|
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
|
||||||
self._test_exported_program_forces_decomposition(
|
|
||||||
TestModel(), (torch.randn(1, 1, 2, 2),), "Resize"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_upsample_bilinear2d_output_size(self):
|
def test_upsample_bilinear2d_output_size(self):
|
||||||
def func(x: torch.Tensor):
|
def func(x: torch.Tensor):
|
||||||
|
|
@ -61,9 +52,6 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
|
||||||
onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2, 3))
|
onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2, 3))
|
||||||
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
|
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
|
||||||
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
|
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
|
||||||
self._test_exported_program_forces_decomposition(
|
|
||||||
TestModel(), (torch.randn(1, 1, 2, 2, 3),), "Resize"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_upsample_trilinear3d_output_size(self):
|
def test_upsample_trilinear3d_output_size(self):
|
||||||
def func(x: torch.Tensor):
|
def func(x: torch.Tensor):
|
||||||
|
|
@ -82,9 +70,6 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
|
||||||
# If decomposition is skipped, the model will contain an InstanceNormalization op
|
# If decomposition is skipped, the model will contain an InstanceNormalization op
|
||||||
# instead of BatchNormalization op w/ training=True.
|
# instead of BatchNormalization op w/ training=True.
|
||||||
assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization")
|
assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization")
|
||||||
self._test_exported_program_forces_decomposition(
|
|
||||||
TestModel(), (torch.randn(1, 1, 2, 2),), "InstanceNormalization"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -45,10 +45,7 @@ def _parameterized_class_attrs_and_values():
|
||||||
input_values.extend(
|
input_values.extend(
|
||||||
itertools.product(
|
itertools.product(
|
||||||
(True, False),
|
(True, False),
|
||||||
(
|
(pytorch_test_common.TorchModelType.TORCH_NN_MODULE,),
|
||||||
pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
|
||||||
pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
|
|
@ -912,10 +909,7 @@ def _parameterized_class_attrs_and_values_with_fake_options():
|
||||||
(True, False),
|
(True, False),
|
||||||
(True, False),
|
(True, False),
|
||||||
(True, False),
|
(True, False),
|
||||||
(
|
(pytorch_test_common.TorchModelType.TORCH_NN_MODULE,),
|
||||||
pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
|
||||||
pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
|
|
@ -986,13 +980,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
# Create the toy model with real weight.
|
# Create the toy model with real weight.
|
||||||
real_model = create_model()
|
real_model = create_model()
|
||||||
state_dict = real_model.state_dict() # concrete (non-fake) state_dict
|
state_dict = real_model.state_dict() # concrete (non-fake) state_dict
|
||||||
if (
|
|
||||||
model_type
|
|
||||||
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
|
||||||
):
|
|
||||||
real_model = torch.export.export(
|
|
||||||
real_model, args=create_args(), kwargs=create_kwargs()
|
|
||||||
)
|
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(
|
with tempfile.NamedTemporaryFile(
|
||||||
prefix=model_name, suffix=".pt"
|
prefix=model_name, suffix=".pt"
|
||||||
|
|
@ -1015,13 +1002,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
)
|
)
|
||||||
|
|
||||||
if export_within_fake_mode:
|
if export_within_fake_mode:
|
||||||
if (
|
|
||||||
model_type
|
|
||||||
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
|
||||||
):
|
|
||||||
fake_model = torch.export.export(
|
|
||||||
fake_model, args=fake_args, kwargs=fake_kwargs
|
|
||||||
)
|
|
||||||
onnx_program = torch.onnx.dynamo_export(
|
onnx_program = torch.onnx.dynamo_export(
|
||||||
fake_model,
|
fake_model,
|
||||||
*fake_args,
|
*fake_args,
|
||||||
|
|
@ -1030,13 +1010,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
)
|
)
|
||||||
|
|
||||||
if not export_within_fake_mode:
|
if not export_within_fake_mode:
|
||||||
if (
|
|
||||||
model_type
|
|
||||||
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
|
||||||
):
|
|
||||||
fake_model = torch.export.export(
|
|
||||||
fake_model, args=fake_args, kwargs=fake_kwargs
|
|
||||||
)
|
|
||||||
onnx_program = torch.onnx.dynamo_export(
|
onnx_program = torch.onnx.dynamo_export(
|
||||||
fake_model, *fake_args, **fake_kwargs, export_options=export_options
|
fake_model, *fake_args, **fake_kwargs, export_options=export_options
|
||||||
)
|
)
|
||||||
|
|
@ -1093,10 +1066,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
|
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
|
||||||
torch.testing.assert_close(ref_output, torch.tensor(ort_output))
|
torch.testing.assert_close(ref_output, torch.tensor(ort_output))
|
||||||
|
|
||||||
@pytorch_test_common.skip_dynamic_fx_test(
|
|
||||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
|
||||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
|
||||||
)
|
|
||||||
def test_fake_tensor_mode_simple(self):
|
def test_fake_tensor_mode_simple(self):
|
||||||
def create_model() -> nn.Module:
|
def create_model() -> nn.Module:
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
|
|
@ -1126,10 +1095,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
model_type=self.model_type,
|
model_type=self.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytorch_test_common.skip_dynamic_fx_test(
|
|
||||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
|
||||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
|
||||||
)
|
|
||||||
@pytorch_test_common.xfail_dynamic_fx_test(
|
@pytorch_test_common.xfail_dynamic_fx_test(
|
||||||
error_message="!(it.GetName().empty())",
|
error_message="!(it.GetName().empty())",
|
||||||
reason="With after onnx==1.16, constant folding in optimizer causes this error.",
|
reason="With after onnx==1.16, constant folding in optimizer causes this error.",
|
||||||
|
|
@ -1166,10 +1131,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
model_type=self.model_type,
|
model_type=self.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytorch_test_common.skip_dynamic_fx_test(
|
|
||||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
|
||||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
|
||||||
)
|
|
||||||
def test_large_scale_exporter_with_toy_mlp(self):
|
def test_large_scale_exporter_with_toy_mlp(self):
|
||||||
class MLPModel(nn.Module):
|
class MLPModel(nn.Module):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|
@ -1208,10 +1169,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
model_type=self.model_type,
|
model_type=self.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytorch_test_common.skip_dynamic_fx_test(
|
|
||||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
|
||||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
|
||||||
)
|
|
||||||
def test_fake_tensor_mode_huggingface_google_t5(self):
|
def test_fake_tensor_mode_huggingface_google_t5(self):
|
||||||
config = transformers.T5Config(
|
config = transformers.T5Config(
|
||||||
vocab_size=8096, d_model=64, num_layers=2, num_heads=2
|
vocab_size=8096, d_model=64, num_layers=2, num_heads=2
|
||||||
|
|
@ -1244,10 +1201,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
model_type=self.model_type,
|
model_type=self.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytorch_test_common.skip_dynamic_fx_test(
|
|
||||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
|
||||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
|
||||||
)
|
|
||||||
@pytorch_test_common.xfail_dynamic_fx_test(
|
@pytorch_test_common.xfail_dynamic_fx_test(
|
||||||
error_message="scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool",
|
error_message="scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool",
|
||||||
reason="Dynamo error: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool",
|
reason="Dynamo error: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool",
|
||||||
|
|
@ -1310,10 +1263,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
model_type=self.model_type,
|
model_type=self.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytorch_test_common.skip_dynamic_fx_test(
|
|
||||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
|
||||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
|
||||||
)
|
|
||||||
def test_fake_tensor_mode_huggingface_mosaicml_mpt(self):
|
def test_fake_tensor_mode_huggingface_mosaicml_mpt(self):
|
||||||
config = transformers.MptConfig(
|
config = transformers.MptConfig(
|
||||||
vocab_size=8096, d_model=64, n_heads=2, n_layers=3
|
vocab_size=8096, d_model=64, n_heads=2, n_layers=3
|
||||||
|
|
@ -1341,10 +1290,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
model_type=self.model_type,
|
model_type=self.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytorch_test_common.skip_dynamic_fx_test(
|
|
||||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
|
||||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
|
||||||
)
|
|
||||||
@pytorch_test_common.xfail_dynamic_fx_test(
|
@pytorch_test_common.xfail_dynamic_fx_test(
|
||||||
error_message="SymIntArrayRef expected to contain only concrete integers",
|
error_message="SymIntArrayRef expected to contain only concrete integers",
|
||||||
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||||
|
|
@ -1374,10 +1319,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
model_type=self.model_type,
|
model_type=self.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytorch_test_common.skip_dynamic_fx_test(
|
|
||||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
|
||||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
|
||||||
)
|
|
||||||
@pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
|
@pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
|
||||||
error_message="Expected 5 inputs, got 3",
|
error_message="Expected 5 inputs, got 3",
|
||||||
reason="https://github.com/pytorch/pytorch/issues/115745",
|
reason="https://github.com/pytorch/pytorch/issues/115745",
|
||||||
|
|
@ -1417,10 +1358,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
model_type=self.model_type,
|
model_type=self.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytorch_test_common.skip_dynamic_fx_test(
|
|
||||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
|
||||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
|
||||||
)
|
|
||||||
@pytorch_test_common.xfail_dynamic_fx_test(
|
@pytorch_test_common.xfail_dynamic_fx_test(
|
||||||
error_message="SymIntArrayRef expected to contain only concrete integers",
|
error_message="SymIntArrayRef expected to contain only concrete integers",
|
||||||
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||||
|
|
|
||||||
|
|
@ -36,14 +36,15 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||||
torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs)
|
torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs)
|
||||||
else:
|
else:
|
||||||
torch_outputs = torch_exported_program(*input_args, **input_kwargs)
|
torch_outputs = torch_exported_program(*input_args, **input_kwargs)
|
||||||
torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx(
|
|
||||||
torch_outputs
|
if isinstance(torch_outputs, torch.Tensor):
|
||||||
)
|
torch_outputs = [torch_outputs]
|
||||||
if len(torch_outputs_onnx_format) != len(onnx_outputs):
|
|
||||||
|
if len(torch_outputs) != len(onnx_outputs):
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}"
|
f"Expected {len(torch_outputs)} outputs, got {len(onnx_outputs)}"
|
||||||
)
|
)
|
||||||
for torch_output, onnx_output in zip(torch_outputs_onnx_format, onnx_outputs):
|
for torch_output, onnx_output in zip(torch_outputs, onnx_outputs):
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
torch_output, torch.tensor(onnx_output), rtol=rtol, atol=atol
|
torch_output, torch.tensor(onnx_output), rtol=rtol, atol=atol
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ __all__ = [
|
||||||
"is_onnxrt_backend_supported",
|
"is_onnxrt_backend_supported",
|
||||||
]
|
]
|
||||||
|
|
||||||
from typing import Any, Collection, Mapping, Sequence, TYPE_CHECKING
|
from typing import Any, Callable, Collection, Mapping, Sequence, TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import _C
|
from torch import _C
|
||||||
|
|
@ -112,7 +112,6 @@ from ._internal._exporter_legacy import ( # usort: skip. needs to be last to av
|
||||||
InvalidExportOptionsError,
|
InvalidExportOptionsError,
|
||||||
OnnxExporterError,
|
OnnxExporterError,
|
||||||
OnnxRegistry,
|
OnnxRegistry,
|
||||||
dynamo_export,
|
|
||||||
enable_fake_mode,
|
enable_fake_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -126,7 +125,6 @@ JitScalarType.__module__ = "torch.onnx"
|
||||||
ExportOptions.__module__ = "torch.onnx"
|
ExportOptions.__module__ = "torch.onnx"
|
||||||
ONNXProgram.__module__ = "torch.onnx"
|
ONNXProgram.__module__ = "torch.onnx"
|
||||||
ONNXRuntimeOptions.__module__ = "torch.onnx"
|
ONNXRuntimeOptions.__module__ = "torch.onnx"
|
||||||
dynamo_export.__module__ = "torch.onnx"
|
|
||||||
InvalidExportOptionsError.__module__ = "torch.onnx"
|
InvalidExportOptionsError.__module__ = "torch.onnx"
|
||||||
OnnxExporterError.__module__ = "torch.onnx"
|
OnnxExporterError.__module__ = "torch.onnx"
|
||||||
enable_fake_mode.__module__ = "torch.onnx"
|
enable_fake_mode.__module__ = "torch.onnx"
|
||||||
|
|
@ -393,6 +391,131 @@ def export(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def dynamo_export(
|
||||||
|
model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined]
|
||||||
|
/,
|
||||||
|
*model_args,
|
||||||
|
export_options: ExportOptions | None = None,
|
||||||
|
**model_kwargs,
|
||||||
|
) -> ONNXProgram | Any:
|
||||||
|
"""Export a torch.nn.Module to an ONNX graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The PyTorch model to be exported to ONNX.
|
||||||
|
model_args: Positional inputs to ``model``.
|
||||||
|
model_kwargs: Keyword inputs to ``model``.
|
||||||
|
export_options: Options to influence the export to ONNX.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An in-memory representation of the exported ONNX model.
|
||||||
|
|
||||||
|
**Example 1 - Simplest export**
|
||||||
|
::
|
||||||
|
|
||||||
|
class MyModel(torch.nn.Module):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.linear = torch.nn.Linear(2, 2)
|
||||||
|
|
||||||
|
def forward(self, x, bias=None):
|
||||||
|
out = self.linear(x)
|
||||||
|
out = out + bias
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
model = MyModel()
|
||||||
|
kwargs = {"bias": 3.0}
|
||||||
|
args = (torch.randn(2, 2, 2),)
|
||||||
|
onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save(
|
||||||
|
"my_simple_model.onnx"
|
||||||
|
)
|
||||||
|
|
||||||
|
**Example 2 - Exporting with dynamic shapes**
|
||||||
|
::
|
||||||
|
|
||||||
|
# The previous model can be exported with dynamic shapes
|
||||||
|
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
|
||||||
|
onnx_program = torch.onnx.dynamo_export(
|
||||||
|
model, *args, **kwargs, export_options=export_options
|
||||||
|
)
|
||||||
|
onnx_program.save("my_dynamic_model.onnx")
|
||||||
|
"""
|
||||||
|
|
||||||
|
# NOTE: The new exporter is experimental and is not enabled by default.
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from torch.onnx import _flags
|
||||||
|
from torch.onnx._internal import exporter
|
||||||
|
from torch.utils import _pytree
|
||||||
|
|
||||||
|
if isinstance(model, torch.export.ExportedProgram):
|
||||||
|
return exporter.export_compat(
|
||||||
|
model, # type: ignore[arg-type]
|
||||||
|
model_args,
|
||||||
|
f=None,
|
||||||
|
kwargs=model_kwargs,
|
||||||
|
opset_version=18,
|
||||||
|
external_data=True,
|
||||||
|
export_params=True,
|
||||||
|
fallback=True,
|
||||||
|
)
|
||||||
|
elif _flags.USE_EXPERIMENTAL_LOGIC:
|
||||||
|
if export_options is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"You are using an experimental ONNX export logic, which currently only supports dynamic shapes. "
|
||||||
|
"For a more comprehensive set of export options, including advanced features, please consider using "
|
||||||
|
"`torch.onnx.export(..., dynamo=True)`. ",
|
||||||
|
category=FutureWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
if export_options is not None and export_options.dynamic_shapes:
|
||||||
|
# Make all shapes dynamic
|
||||||
|
def _to_dynamic_shapes_mapper():
|
||||||
|
arg_order = 0
|
||||||
|
|
||||||
|
def _to_dynamic_shape(x):
|
||||||
|
nonlocal arg_order
|
||||||
|
if isinstance(x, torch.Tensor):
|
||||||
|
rank = len(x.shape)
|
||||||
|
dynamic_shape = {}
|
||||||
|
for i in range(rank):
|
||||||
|
dynamic_shape[i] = torch.export.Dim(
|
||||||
|
f"arg_{arg_order}_dim_{i}"
|
||||||
|
)
|
||||||
|
arg_order += 1
|
||||||
|
return dynamic_shape
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return _to_dynamic_shape
|
||||||
|
|
||||||
|
# model_args could be nested
|
||||||
|
dynamic_shapes = _pytree.tree_map(
|
||||||
|
_to_dynamic_shapes_mapper(),
|
||||||
|
model_args,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dynamic_shapes = None
|
||||||
|
|
||||||
|
return exporter.export_compat(
|
||||||
|
model, # type: ignore[arg-type]
|
||||||
|
model_args,
|
||||||
|
f=None,
|
||||||
|
kwargs=model_kwargs,
|
||||||
|
dynamic_shapes=dynamic_shapes,
|
||||||
|
opset_version=18,
|
||||||
|
external_data=True,
|
||||||
|
export_params=True,
|
||||||
|
fallback=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from torch.onnx._internal._exporter_legacy import dynamo_export
|
||||||
|
|
||||||
|
return dynamo_export(
|
||||||
|
model, *model_args, export_options=export_options, **model_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.
|
# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.
|
||||||
|
|
||||||
# Returns True iff ONNX logging is turned on.
|
# Returns True iff ONNX logging is turned on.
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ from typing_extensions import Self
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._ops
|
import torch._ops
|
||||||
import torch.export as torch_export
|
|
||||||
import torch.utils._pytree as pytree
|
import torch.utils._pytree as pytree
|
||||||
from torch.onnx._internal import io_adapter
|
from torch.onnx._internal import io_adapter
|
||||||
from torch.onnx._internal.diagnostics import infra
|
from torch.onnx._internal.diagnostics import infra
|
||||||
|
|
@ -304,27 +303,17 @@ class ResolvedExportOptions(ExportOptions):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
options: ExportOptions | ResolvedExportOptions,
|
options: ExportOptions | ResolvedExportOptions,
|
||||||
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, # type: ignore[name-defined]
|
model: torch.nn.Module | Callable | None = None, # type: ignore[name-defined]
|
||||||
):
|
):
|
||||||
from torch.onnx._internal.fx import ( # TODO: Prevent circular dep
|
from torch.onnx._internal.fx import ( # TODO: Prevent circular dep
|
||||||
diagnostics,
|
diagnostics,
|
||||||
dynamo_graph_extractor,
|
dynamo_graph_extractor,
|
||||||
torch_export_graph_extractor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(options, ResolvedExportOptions):
|
if isinstance(options, ResolvedExportOptions):
|
||||||
self.dynamic_shapes = options.dynamic_shapes
|
self.dynamic_shapes = options.dynamic_shapes
|
||||||
self.diagnostic_options = options.diagnostic_options
|
self.diagnostic_options = options.diagnostic_options
|
||||||
self.fake_context = options.fake_context
|
self.fake_context = options.fake_context
|
||||||
# private
|
|
||||||
if isinstance(model, torch_export.ExportedProgram) and not isinstance(
|
|
||||||
options.fx_tracer, torch_export_graph_extractor.TorchExport
|
|
||||||
):
|
|
||||||
message = "'model' of type 'ExportedProgram' is only supported with 'TorchExport' FX Tracer"
|
|
||||||
e = InvalidExportOptionsError(message)
|
|
||||||
raise InvalidExportOptionsError(
|
|
||||||
ONNXProgram._from_failure(e, options.diagnostic_context), message
|
|
||||||
)
|
|
||||||
self.fx_tracer = options.fx_tracer
|
self.fx_tracer = options.fx_tracer
|
||||||
self.onnx_registry = options.onnx_registry
|
self.onnx_registry = options.onnx_registry
|
||||||
self.onnxfunction_dispatcher = options.onnxfunction_dispatcher
|
self.onnxfunction_dispatcher = options.onnxfunction_dispatcher
|
||||||
|
|
@ -345,10 +334,8 @@ class ResolvedExportOptions(ExportOptions):
|
||||||
self.diagnostic_options = resolve(
|
self.diagnostic_options = resolve(
|
||||||
options.diagnostic_options, DiagnosticOptions()
|
options.diagnostic_options, DiagnosticOptions()
|
||||||
)
|
)
|
||||||
if isinstance(model, torch_export.ExportedProgram):
|
|
||||||
self.fx_tracer = torch_export_graph_extractor.TorchExport()
|
self.fx_tracer = dynamo_graph_extractor.DynamoExport()
|
||||||
else:
|
|
||||||
self.fx_tracer = dynamo_graph_extractor.DynamoExport()
|
|
||||||
|
|
||||||
self.fake_context = resolve(options.fake_context, None) # type: ignore[arg-type]
|
self.fake_context = resolve(options.fake_context, None) # type: ignore[arg-type]
|
||||||
self.diagnostic_context = diagnostics.DiagnosticContext(
|
self.diagnostic_context = diagnostics.DiagnosticContext(
|
||||||
|
|
@ -492,7 +479,6 @@ class ONNXProgram:
|
||||||
diagnostic_context: Context object for the SARIF diagnostic system responsible for logging errors and metadata.
|
diagnostic_context: Context object for the SARIF diagnostic system responsible for logging errors and metadata.
|
||||||
fake_context: The fake context used for symbolic tracing.
|
fake_context: The fake context used for symbolic tracing.
|
||||||
export_exception: The exception that occurred during export, if any.
|
export_exception: The exception that occurred during export, if any.
|
||||||
model_signature: The model signature for the exported ONNX graph.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_model_proto: Final[onnx.ModelProto] # type: ignore[name-defined, misc]
|
_model_proto: Final[onnx.ModelProto] # type: ignore[name-defined, misc]
|
||||||
|
|
@ -501,9 +487,8 @@ class ONNXProgram:
|
||||||
_diagnostic_context: Final[diagnostics.DiagnosticContext] # type: ignore[misc]
|
_diagnostic_context: Final[diagnostics.DiagnosticContext] # type: ignore[misc]
|
||||||
_fake_context: Final[ONNXFakeContext | None] # type: ignore[misc]
|
_fake_context: Final[ONNXFakeContext | None] # type: ignore[misc]
|
||||||
_export_exception: Final[Exception | None] # type: ignore[misc]
|
_export_exception: Final[Exception | None] # type: ignore[misc]
|
||||||
_model_signature: Final[torch.export.ExportGraphSignature | None] # type: ignore[misc]
|
|
||||||
_model_torch: Final[ # type: ignore[misc]
|
_model_torch: Final[ # type: ignore[misc]
|
||||||
torch.nn.Module | Callable | torch_export.ExportedProgram | None
|
torch.nn.Module | Callable | None
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -515,14 +500,9 @@ class ONNXProgram:
|
||||||
*,
|
*,
|
||||||
fake_context: ONNXFakeContext | None = None,
|
fake_context: ONNXFakeContext | None = None,
|
||||||
export_exception: Exception | None = None,
|
export_exception: Exception | None = None,
|
||||||
model_signature: torch.export.ExportGraphSignature | None = None,
|
model_torch: torch.nn.Module | Callable | None = None,
|
||||||
model_torch: torch.nn.Module
|
|
||||||
| Callable
|
|
||||||
| torch_export.ExportedProgram
|
|
||||||
| None = None,
|
|
||||||
):
|
):
|
||||||
self._model_proto = model_proto
|
self._model_proto = model_proto
|
||||||
self._model_signature = model_signature
|
|
||||||
self._model_torch = model_torch
|
self._model_torch = model_torch
|
||||||
self._input_adapter = input_adapter
|
self._input_adapter = input_adapter
|
||||||
self._output_adapter = output_adapter
|
self._output_adapter = output_adapter
|
||||||
|
|
@ -533,10 +513,7 @@ class ONNXProgram:
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
model_with_state_dict: torch.nn.Module
|
model_with_state_dict: torch.nn.Module | Callable | None = None,
|
||||||
| Callable
|
|
||||||
| torch_export.ExportedProgram
|
|
||||||
| None = None,
|
|
||||||
options: ONNXRuntimeOptions | None = None,
|
options: ONNXRuntimeOptions | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
|
@ -571,8 +548,6 @@ class ONNXProgram:
|
||||||
onnx_model = os.path.join(tmpdir_path, "model.onnx")
|
onnx_model = os.path.join(tmpdir_path, "model.onnx")
|
||||||
if isinstance(model_with_state_dict, torch.nn.Module):
|
if isinstance(model_with_state_dict, torch.nn.Module):
|
||||||
model_state = model_with_state_dict.state_dict()
|
model_state = model_with_state_dict.state_dict()
|
||||||
elif isinstance(model_with_state_dict, torch_export.ExportedProgram):
|
|
||||||
model_state = model_with_state_dict.state_dict
|
|
||||||
else:
|
else:
|
||||||
model_state = None
|
model_state = None
|
||||||
self.save(
|
self.save(
|
||||||
|
|
@ -608,104 +583,6 @@ class ONNXProgram:
|
||||||
raise self._export_exception
|
raise self._export_exception
|
||||||
return self._model_proto
|
return self._model_proto
|
||||||
|
|
||||||
@property
|
|
||||||
def model_signature(self) -> torch.export.ExportGraphSignature | None:
|
|
||||||
"""The model signature for the exported ONNX graph.
|
|
||||||
|
|
||||||
This information is relevant because ONNX specification often differs from PyTorch's, resulting
|
|
||||||
in a ONNX graph with input and output schema different from the actual PyTorch model implementation.
|
|
||||||
By using the model signature, the users can understand the inputs and outputs differences
|
|
||||||
and properly execute the model in ONNX Runtime.
|
|
||||||
|
|
||||||
NOTE: Model signature is only available when the ONNX graph was exported from a
|
|
||||||
:class:`torch.export.ExportedProgram` object.
|
|
||||||
|
|
||||||
NOTE: Any transformation done to the model that changes the model signature must be accompanied
|
|
||||||
by updates to this model signature as well through :class:`InputAdaptStep` and/or :class:`OutputAdaptStep`.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
The following model produces different sets of inputs and outputs.
|
|
||||||
The first 4 inputs are model parameters (namely conv1.weight, conv2.weight, fc1.weight, fc2.weight),
|
|
||||||
and the next 2 inputs are registered buffers (namely my_buffer2, my_buffer1) and finally
|
|
||||||
the last 2 inputs are user inputs (namely x and b).
|
|
||||||
The first output is a buffer mutation (namely my_buffer2) and the last output is the actual model output.
|
|
||||||
|
|
||||||
>>> import pprint
|
|
||||||
>>> class CustomModule(torch.nn.Module):
|
|
||||||
... def __init__(self) -> None:
|
|
||||||
... super().__init__()
|
|
||||||
... self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
|
|
||||||
... self.register_buffer("my_buffer1", torch.tensor(3.0))
|
|
||||||
... self.register_buffer("my_buffer2", torch.tensor(4.0))
|
|
||||||
... self.conv1 = torch.nn.Conv2d(1, 32, 3, 1, bias=False)
|
|
||||||
... self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, bias=False)
|
|
||||||
... self.fc1 = torch.nn.Linear(9216, 128, bias=False)
|
|
||||||
... self.fc2 = torch.nn.Linear(128, 10, bias=False)
|
|
||||||
...
|
|
||||||
... def forward(self, x, b):
|
|
||||||
... tensor_x = self.conv1(x)
|
|
||||||
... tensor_x = torch.nn.functional.sigmoid(tensor_x)
|
|
||||||
... tensor_x = self.conv2(tensor_x)
|
|
||||||
... tensor_x = torch.nn.functional.sigmoid(tensor_x)
|
|
||||||
... tensor_x = torch.nn.functional.max_pool2d(tensor_x, 2)
|
|
||||||
... tensor_x = torch.flatten(tensor_x, 1)
|
|
||||||
... tensor_x = self.fc1(tensor_x)
|
|
||||||
... tensor_x = torch.nn.functional.sigmoid(tensor_x)
|
|
||||||
... tensor_x = self.fc2(tensor_x)
|
|
||||||
... output = torch.nn.functional.log_softmax(tensor_x, dim=1)
|
|
||||||
... (
|
|
||||||
... self.my_buffer2.add_(1.0) + self.my_buffer1
|
|
||||||
... ) # Mutate buffer through in-place addition
|
|
||||||
... return output
|
|
||||||
>>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3))
|
|
||||||
>>> exported_program = torch.export.export(
|
|
||||||
... CustomModule(), args=inputs
|
|
||||||
... ).run_decompositions({})
|
|
||||||
>>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs)
|
|
||||||
>>> pprint.pprint(onnx_program.model_signature)
|
|
||||||
ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>,
|
|
||||||
arg=TensorArgument(name='p_conv1_weight'),
|
|
||||||
target='conv1.weight',
|
|
||||||
persistent=None),
|
|
||||||
InputSpec(kind=<InputKind.PARAMETER: 2>,
|
|
||||||
arg=TensorArgument(name='p_conv2_weight'),
|
|
||||||
target='conv2.weight',
|
|
||||||
persistent=None),
|
|
||||||
InputSpec(kind=<InputKind.PARAMETER: 2>,
|
|
||||||
arg=TensorArgument(name='p_fc1_weight'),
|
|
||||||
target='fc1.weight',
|
|
||||||
persistent=None),
|
|
||||||
InputSpec(kind=<InputKind.PARAMETER: 2>,
|
|
||||||
arg=TensorArgument(name='p_fc2_weight'),
|
|
||||||
target='fc2.weight',
|
|
||||||
persistent=None),
|
|
||||||
InputSpec(kind=<InputKind.BUFFER: 3>,
|
|
||||||
arg=TensorArgument(name='b_my_buffer2'),
|
|
||||||
target='my_buffer2',
|
|
||||||
persistent=True),
|
|
||||||
InputSpec(kind=<InputKind.BUFFER: 3>,
|
|
||||||
arg=TensorArgument(name='b_my_buffer1'),
|
|
||||||
target='my_buffer1',
|
|
||||||
persistent=True),
|
|
||||||
InputSpec(kind=<InputKind.USER_INPUT: 1>,
|
|
||||||
arg=TensorArgument(name='x'),
|
|
||||||
target=None,
|
|
||||||
persistent=None),
|
|
||||||
InputSpec(kind=<InputKind.USER_INPUT: 1>,
|
|
||||||
arg=TensorArgument(name='b'),
|
|
||||||
target=None,
|
|
||||||
persistent=None)],
|
|
||||||
output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>,
|
|
||||||
arg=TensorArgument(name='add'),
|
|
||||||
target='my_buffer2'),
|
|
||||||
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>,
|
|
||||||
arg=TensorArgument(name='_log_softmax'),
|
|
||||||
target=None)])
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self._model_signature
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def diagnostic_context(self) -> diagnostics.DiagnosticContext:
|
def diagnostic_context(self) -> diagnostics.DiagnosticContext:
|
||||||
"""The diagnostic context associated with the export."""
|
"""The diagnostic context associated with the export."""
|
||||||
|
|
@ -721,10 +598,7 @@ class ONNXProgram:
|
||||||
def adapt_torch_inputs_to_onnx(
|
def adapt_torch_inputs_to_onnx(
|
||||||
self,
|
self,
|
||||||
*model_args,
|
*model_args,
|
||||||
model_with_state_dict: torch.nn.Module
|
model_with_state_dict: torch.nn.Module | Callable | None = None,
|
||||||
| Callable
|
|
||||||
| torch_export.ExportedProgram
|
|
||||||
| None = None,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Sequence[torch.Tensor | int | float | bool | torch.dtype]:
|
) -> Sequence[torch.Tensor | int | float | bool | torch.dtype]:
|
||||||
"""Converts the PyTorch model inputs to exported ONNX model inputs format.
|
"""Converts the PyTorch model inputs to exported ONNX model inputs format.
|
||||||
|
|
@ -794,10 +668,7 @@ class ONNXProgram:
|
||||||
def adapt_torch_outputs_to_onnx(
|
def adapt_torch_outputs_to_onnx(
|
||||||
self,
|
self,
|
||||||
model_outputs: Any,
|
model_outputs: Any,
|
||||||
model_with_state_dict: torch.nn.Module
|
model_with_state_dict: torch.nn.Module | Callable | None = None,
|
||||||
| Callable
|
|
||||||
| torch_export.ExportedProgram
|
|
||||||
| None = None,
|
|
||||||
) -> Sequence[torch.Tensor | int | float | bool]:
|
) -> Sequence[torch.Tensor | int | float | bool]:
|
||||||
"""Converts the PyTorch model outputs to exported ONNX model outputs format.
|
"""Converts the PyTorch model outputs to exported ONNX model outputs format.
|
||||||
|
|
||||||
|
|
@ -1050,7 +921,7 @@ class Exporter:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
options: ResolvedExportOptions,
|
options: ResolvedExportOptions,
|
||||||
model: torch.nn.Module | Callable | torch_export.ExportedProgram,
|
model: torch.nn.Module | Callable,
|
||||||
model_args: Sequence[Any],
|
model_args: Sequence[Any],
|
||||||
model_kwargs: Mapping[str, Any],
|
model_kwargs: Mapping[str, Any],
|
||||||
):
|
):
|
||||||
|
|
@ -1138,9 +1009,6 @@ class Exporter:
|
||||||
self.options.fx_tracer.output_adapter,
|
self.options.fx_tracer.output_adapter,
|
||||||
self.options.diagnostic_context,
|
self.options.diagnostic_context,
|
||||||
fake_context=self.options.fake_context,
|
fake_context=self.options.fake_context,
|
||||||
model_signature=getattr(
|
|
||||||
self.model, "graph_signature", None
|
|
||||||
), # Available for isinstance(self.model, ExportedProgram) only
|
|
||||||
model_torch=self.model,
|
model_torch=self.model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1261,12 +1129,12 @@ def _assert_dependencies(export_options: ResolvedExportOptions):
|
||||||
|
|
||||||
|
|
||||||
def dynamo_export(
|
def dynamo_export(
|
||||||
model: torch.nn.Module | Callable | torch_export.ExportedProgram, # type: ignore[name-defined]
|
model: torch.nn.Module | Callable,
|
||||||
/,
|
/,
|
||||||
*model_args,
|
*model_args,
|
||||||
export_options: ExportOptions | None = None,
|
export_options: ExportOptions | None = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> ONNXProgram:
|
) -> ONNXProgram | Any:
|
||||||
"""Export a torch.nn.Module to an ONNX graph.
|
"""Export a torch.nn.Module to an ONNX graph.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -120,9 +120,20 @@ class TorchExportStrategy(CaptureStrategy):
|
||||||
def _capture(
|
def _capture(
|
||||||
self, model, args, kwargs, dynamic_shapes
|
self, model, args, kwargs, dynamic_shapes
|
||||||
) -> torch.export.ExportedProgram:
|
) -> torch.export.ExportedProgram:
|
||||||
return torch.export.export(
|
try:
|
||||||
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
|
return torch.export.export(
|
||||||
)
|
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
|
||||||
|
)
|
||||||
|
except torch._dynamo.exc.UserError as exc:
|
||||||
|
# Refine the dynamic shapes based on the suggested fixes.
|
||||||
|
new_shapes = (
|
||||||
|
torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(
|
||||||
|
exc.msg, dynamic_shapes
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return torch.export.export(
|
||||||
|
model, args, kwargs=kwargs, dynamic_shapes=new_shapes
|
||||||
|
)
|
||||||
|
|
||||||
def _enter(self, model) -> None:
|
def _enter(self, model) -> None:
|
||||||
model_repr = _take_first_line(repr(model))
|
model_repr = _take_first_line(repr(model))
|
||||||
|
|
@ -148,9 +159,20 @@ class TorchExportNonStrictStrategy(CaptureStrategy):
|
||||||
def _capture(
|
def _capture(
|
||||||
self, model, args, kwargs, dynamic_shapes
|
self, model, args, kwargs, dynamic_shapes
|
||||||
) -> torch.export.ExportedProgram:
|
) -> torch.export.ExportedProgram:
|
||||||
return torch.export.export(
|
try:
|
||||||
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False
|
return torch.export.export(
|
||||||
)
|
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False
|
||||||
|
)
|
||||||
|
except torch._dynamo.exc.UserError as exc:
|
||||||
|
# Refine the dynamic shapes based on the suggested fixes.
|
||||||
|
new_shapes = (
|
||||||
|
torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(
|
||||||
|
exc.msg, dynamic_shapes
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return torch.export.export(
|
||||||
|
model, args, kwargs=kwargs, dynamic_shapes=new_shapes, strict=False
|
||||||
|
)
|
||||||
|
|
||||||
def _enter(self, model) -> None:
|
def _enter(self, model) -> None:
|
||||||
model_repr = _take_first_line(repr(model))
|
model_repr = _take_first_line(repr(model))
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ import logging
|
||||||
from typing import Any, Mapping, Sequence, TYPE_CHECKING
|
from typing import Any, Mapping, Sequence, TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.export
|
|
||||||
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
|
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
|
||||||
from torch.onnx._internal.exporter import _core, _onnx_program
|
from torch.onnx._internal.exporter import _core, _onnx_program
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -176,7 +176,7 @@ class Transform(abc.ABC):
|
||||||
|
|
||||||
One important aspect to note is that if the transformation modifies the model input and/or output signature,
|
One important aspect to note is that if the transformation modifies the model input and/or output signature,
|
||||||
(e.g. additional inputs/outputs are added to the model), :class:`InputAdaptStep` and/or :class:`OutputAdaptStep`
|
(e.g. additional inputs/outputs are added to the model), :class:`InputAdaptStep` and/or :class:`OutputAdaptStep`
|
||||||
are needed to reconcile :attr:`ONNXProgram.model_signature` and :attr:`ONNXProgram.model_proto`.
|
are needed to reconcile :attr:`ONNXProgram.model_proto`.
|
||||||
That is, the model signature and the model representation must match.
|
That is, the model signature and the model representation must match.
|
||||||
|
|
||||||
As an additional feature, this class provides builtin support for transformation recording using the diagnostics.
|
As an additional feature, this class provides builtin support for transformation recording using the diagnostics.
|
||||||
|
|
|
||||||
|
|
@ -1,128 +0,0 @@
|
||||||
# mypy: allow-untyped-defs
|
|
||||||
# NOTE: This file is referenced by name at
|
|
||||||
# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES.
|
|
||||||
# introduced by https://github.com/pytorch/pytorch/pull/98894.
|
|
||||||
# If this file is renamed, moved, etc please update the reference there!
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING
|
|
||||||
|
|
||||||
import torch._dynamo
|
|
||||||
import torch.fx
|
|
||||||
from torch.onnx._internal import _exporter_legacy, io_adapter
|
|
||||||
from torch.onnx._internal.diagnostics import infra
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
import torch.onnx
|
|
||||||
from torch.export.exported_program import ExportedProgram
|
|
||||||
|
|
||||||
|
|
||||||
class TorchExport(_exporter_legacy.FXGraphExtractor):
|
|
||||||
"""Generates a FX GraphModule using torch.export API
|
|
||||||
Args:
|
|
||||||
aten_graph: If True, exports a graph with ATen operators.
|
|
||||||
If False, exports a graph with Python operators.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
aten_graph: bool | None = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.aten_graph = aten_graph or True
|
|
||||||
|
|
||||||
def generate_fx(
|
|
||||||
self,
|
|
||||||
options: _exporter_legacy.ResolvedExportOptions,
|
|
||||||
model: ExportedProgram, # type: ignore[override]
|
|
||||||
model_args: Sequence[Any],
|
|
||||||
model_kwargs: Mapping[str, Any],
|
|
||||||
) -> torch.fx.GraphModule:
|
|
||||||
# No need to translate callable to FX graph.
|
|
||||||
# This FX Graph extractor assumes `model` was obtained through
|
|
||||||
# exported_program = torch.export.export(
|
|
||||||
# model,
|
|
||||||
# args=model_args, # type: ignore[arg-type]
|
|
||||||
# kwargs=model_kwargs, # type: ignore[arg-type]
|
|
||||||
# )
|
|
||||||
|
|
||||||
# Export FX graph to ONNX ModelProto.
|
|
||||||
self.input_adapter.append_step(
|
|
||||||
io_adapter.FlattenInputWithTreeSpecValidationInputStep()
|
|
||||||
)
|
|
||||||
self.input_adapter.append_step(
|
|
||||||
io_adapter.PrependParamsBuffersConstantAotAutogradInputStep()
|
|
||||||
)
|
|
||||||
|
|
||||||
# ONNX does not support None inputs. During graph building, all None inputs
|
|
||||||
# are removed. Here we register this step to input adapter.
|
|
||||||
options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep())
|
|
||||||
|
|
||||||
# NOTE: temp workaround for https://github.com/pytorch/pytorch/issues/99534
|
|
||||||
# Dynamo doesn't support non-tensor inputs.
|
|
||||||
options.fx_tracer.input_adapter.append_step(
|
|
||||||
io_adapter.RemoveNonTensorInputStep()
|
|
||||||
)
|
|
||||||
|
|
||||||
# ONNX does not support complex inputs. During graph building, all complex inputs
|
|
||||||
# are converted to real representation inputs. Here we register this step to
|
|
||||||
# input/output adapter.
|
|
||||||
options.fx_tracer.input_adapter.append_step(
|
|
||||||
io_adapter.ConvertComplexToRealRepresentationInputStep()
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_model_args = self.input_adapter.apply(
|
|
||||||
*model_args, model=model, **model_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# ONNX can't represent collection types (e.g., dictionary, tuple of tuple of
|
|
||||||
# tensor, etc), we flatten the collection and register each element as output.
|
|
||||||
options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep())
|
|
||||||
|
|
||||||
# Output post-processing steps should happen after `FlattenOutputStep`.
|
|
||||||
options.fx_tracer.output_adapter.append_step(
|
|
||||||
io_adapter.ConvertComplexToRealRepresentationOutputStep()
|
|
||||||
)
|
|
||||||
|
|
||||||
options.fx_tracer.output_adapter.append_step(
|
|
||||||
io_adapter.PrependParamsAndBuffersAotAutogradOutputStep()
|
|
||||||
)
|
|
||||||
|
|
||||||
# run_decomposition generates a new graph module with decomposed ops.
|
|
||||||
# Thus, we need to run this step after io_adapters.
|
|
||||||
model = model.run_decompositions(options.decomposition_table)
|
|
||||||
|
|
||||||
# Export FX graph to ONNX ModelProto.
|
|
||||||
return self.pre_export_passes( # type: ignore[return-value]
|
|
||||||
options, model, model.graph_module, updated_model_args
|
|
||||||
)
|
|
||||||
|
|
||||||
def pre_export_passes(
|
|
||||||
self,
|
|
||||||
options: _exporter_legacy.ResolvedExportOptions,
|
|
||||||
original_model: torch.nn.Module | Callable,
|
|
||||||
fx_module: torch.fx.GraphModule,
|
|
||||||
fx_module_args: Sequence[Any],
|
|
||||||
):
|
|
||||||
# TODO: Import here to prevent circular dependency
|
|
||||||
from torch.onnx._internal.fx import analysis, passes
|
|
||||||
|
|
||||||
diagnostic_context = options.diagnostic_context
|
|
||||||
|
|
||||||
# ONNX does not support concept of (implicit) type promotion.
|
|
||||||
# Insert type casts explicitly where needed.
|
|
||||||
fx_module = passes.InsertTypePromotion(diagnostic_context, fx_module).run()
|
|
||||||
|
|
||||||
analysis.UnsupportedFxNodesAnalysis(
|
|
||||||
diagnostic_context, fx_module, options.onnxfunction_dispatcher
|
|
||||||
).analyze(infra.levels.ERROR)
|
|
||||||
|
|
||||||
# This operation should be invoked as the last pre export pass.
|
|
||||||
# See [NOTE: Modularize pass ordering]
|
|
||||||
fx_module = passes.Modularize(
|
|
||||||
diagnostic_context, fx_module, is_exported_program=True
|
|
||||||
).run()
|
|
||||||
|
|
||||||
return fx_module
|
|
||||||
Loading…
Reference in New Issue
Block a user