mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This reverts commit ad4ccf9689.
Reverted https://github.com/pytorch/pytorch/pull/110794 on behalf of https://github.com/ezyang due to looks like this actually fails internal tests ([comment](https://github.com/pytorch/pytorch/pull/110794#issuecomment-1778002262))
98 lines
2.9 KiB
Python
98 lines
2.9 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
|
|
import unittest
|
|
|
|
import torch._dynamo as torchdynamo
|
|
from torch.export import export
|
|
from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel
|
|
from torch._export.db.examples import (
|
|
filter_examples_by_support_level,
|
|
get_rewrite_cases,
|
|
)
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
run_tests,
|
|
TestCase,
|
|
)
|
|
|
|
|
|
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class ExampleTests(TestCase):
    """Run ``export`` over the ExportDB example cases.

    Each test is parametrized over the cases returned by
    ``filter_examples_by_support_level`` for one ``SupportLevel``. The whole
    class is skipped when ``torchdynamo.is_dynamo_supported()`` is false.
    """

    def _assert_same_outputs(self, exported_program, model, inputs) -> None:
        """Assert that ``exported_program`` and ``model`` agree on ``inputs``.

        Args:
            exported_program: The callable returned by ``export``.
            model: The original callable that was exported.
            inputs: A value returned by ``normalize_inputs``; its ``args``
                and ``kwargs`` are forwarded unchanged to both callables.
        """
        self.assertEqual(
            exported_program(*inputs.args, **inputs.kwargs),
            model(*inputs.args, **inputs.kwargs),
        )

    # TODO Maybe we should make these tests actually show up in a file?
    @parametrize(
        "name,case",
        filter_examples_by_support_level(SupportLevel.SUPPORTED).items(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
        """Export a SUPPORTED case and compare its outputs with the model's.

        The comparison runs on ``case.example_inputs`` and, when present,
        on ``case.extra_inputs`` as well.
        """
        model = case.model

        inputs = normalize_inputs(case.example_inputs)
        exported_program = export(
            model,
            inputs.args,
            inputs.kwargs,
            dynamic_shapes=case.dynamic_shapes,
        )
        # Also run the readable graph printer on every exported case.
        exported_program.graph_module.print_readable()

        self._assert_same_outputs(exported_program, model, inputs)

        if case.extra_inputs is not None:
            self._assert_same_outputs(
                exported_program,
                model,
                normalize_inputs(case.extra_inputs),
            )

    @parametrize(
        "name,case",
        filter_examples_by_support_level(SupportLevel.NOT_SUPPORTED_YET).items(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None:
        """Check that exporting a NOT_SUPPORTED_YET case raises ``Unsupported``."""
        model = case.model
        # Prepare the inputs outside the assertRaises block so that only the
        # export call itself is allowed to raise the expected exception.
        inputs = normalize_inputs(case.example_inputs)
        # pyre-ignore
        with self.assertRaises(torchdynamo.exc.Unsupported):
            export(
                model,
                inputs.args,
                inputs.kwargs,
                dynamic_shapes=case.dynamic_shapes,
            )

    @parametrize(
        "name,rewrite_case",
        [
            (name, rewrite_case)
            for name, case in filter_examples_by_support_level(
                SupportLevel.NOT_SUPPORTED_YET
            ).items()
            for rewrite_case in get_rewrite_cases(case)
        ],
        name_fn=lambda name, rewrite_case: f"case_{name}_{rewrite_case.name}",
    )
    def test_exportdb_not_supported_rewrite(
        self, name: str, rewrite_case: ExportCase
    ) -> None:
        """Check that each rewrite of a NOT_SUPPORTED_YET case exports.

        The test passes as long as ``export`` returns without raising; the
        exported result is not inspected.
        """
        # pyre-ignore
        inputs = normalize_inputs(rewrite_case.example_inputs)
        export(
            rewrite_case.model,
            inputs.args,
            inputs.kwargs,
            dynamic_shapes=rewrite_case.dynamic_shapes,
        )
|
|
|
|
|
|
# Hand ExampleTests to the parametrization helper once the class is fully
# defined, so its @parametrize-decorated methods are processed.
instantiate_parametrized_tests(ExampleTests)


# Script entry point: run this module's tests when executed directly.
if __name__ == "__main__":
    run_tests()
|