mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Rebase of #111313 onto `main`, for CI validation Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/115969 Approved by: https://github.com/PaliC, https://github.com/thiagocrepaldi
103 lines
3.3 KiB
Python
103 lines
3.3 KiB
Python
# Owner(s): ["oncall: export"]
|
|
|
|
import copy
|
|
import unittest
|
|
|
|
import torch._dynamo as torchdynamo
|
|
from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel
|
|
from torch._export.db.examples import (
|
|
filter_examples_by_support_level,
|
|
get_rewrite_cases,
|
|
)
|
|
from torch.export import export
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
run_tests,
|
|
TestCase,
|
|
IS_WINDOWS
|
|
)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
|
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
|
|
class ExampleTests(TestCase):
|
|
# TODO Maybe we should make this tests actually show up in a file?
|
|
@parametrize(
|
|
"name,case",
|
|
filter_examples_by_support_level(SupportLevel.SUPPORTED).items(),
|
|
name_fn=lambda name, case: f"case_{name}",
|
|
)
|
|
def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
|
|
model = case.model
|
|
|
|
inputs_export = normalize_inputs(case.example_inputs)
|
|
inputs_model = copy.deepcopy(inputs_export)
|
|
exported_program = export(
|
|
model,
|
|
inputs_export.args,
|
|
inputs_export.kwargs,
|
|
dynamic_shapes=case.dynamic_shapes,
|
|
)
|
|
exported_program.graph_module.print_readable()
|
|
|
|
self.assertEqual(
|
|
exported_program.module()(*inputs_export.args, **inputs_export.kwargs),
|
|
model(*inputs_model.args, **inputs_model.kwargs),
|
|
)
|
|
|
|
if case.extra_inputs is not None:
|
|
inputs = normalize_inputs(case.extra_inputs)
|
|
self.assertEqual(
|
|
exported_program.module()(*inputs.args, **inputs.kwargs),
|
|
model(*inputs.args, **inputs.kwargs),
|
|
)
|
|
|
|
@parametrize(
|
|
"name,case",
|
|
filter_examples_by_support_level(SupportLevel.NOT_SUPPORTED_YET).items(),
|
|
name_fn=lambda name, case: f"case_{name}",
|
|
)
|
|
def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None:
|
|
model = case.model
|
|
# pyre-ignore
|
|
with self.assertRaises((torchdynamo.exc.Unsupported, AssertionError, RuntimeError)):
|
|
inputs = normalize_inputs(case.example_inputs)
|
|
exported_model = export(
|
|
model,
|
|
inputs.args,
|
|
inputs.kwargs,
|
|
dynamic_shapes=case.dynamic_shapes,
|
|
)
|
|
|
|
exportdb_not_supported_rewrite_cases = [
|
|
(name, rewrite_case)
|
|
for name, case in filter_examples_by_support_level(
|
|
SupportLevel.NOT_SUPPORTED_YET
|
|
).items()
|
|
for rewrite_case in get_rewrite_cases(case)
|
|
]
|
|
if exportdb_not_supported_rewrite_cases:
|
|
@parametrize(
|
|
"name,rewrite_case",
|
|
exportdb_not_supported_rewrite_cases,
|
|
name_fn=lambda name, case: f"case_{name}_{case.name}",
|
|
)
|
|
def test_exportdb_not_supported_rewrite(
|
|
self, name: str, rewrite_case: ExportCase
|
|
) -> None:
|
|
# pyre-ignore
|
|
inputs = normalize_inputs(rewrite_case.example_inputs)
|
|
exported_model = export(
|
|
rewrite_case.model,
|
|
inputs.args,
|
|
inputs.kwargs,
|
|
dynamic_shapes=rewrite_case.dynamic_shapes,
|
|
)
|
|
|
|
|
|
instantiate_parametrized_tests(ExampleTests)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|