pytorch/test/export/test_db.py
ydwu4 6abb8c382c [export] add kwargs support for export. (#105337)
Solving #105242.

During export, the exported function's signature changes multiple times. Suppose we'd like to export f as shown in the following example:
```python
def f(arg1, arg2, kw1, kw2):
  pass

args = (arg1, arg2)
kwargs = {"kw2": arg3, "kw1": arg4}

torch.export(f, args, kwargs)
```
The signature changes multiple times during the export process, in the following order:
1. **gm_torch_level = dynamo.export(f, *args, \*\*kwargs)**. In this step, we turn all kinds of parameters, such as **positional_only**, **var_positional**, **kw_only**, and **var_kwargs**, into **positional_or_kw**. It also preserves the positional and keyword argument names of the original function (i.e. f in this example) [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/export.py#L546C13-L546C27). The order of kwargs will be the **key order** of the kwargs dict (since Python 3.7, dicts preserve the insertion order of keys) instead of the order in the original function signature, and this order is baked into the _orig_args variable of gm_torch_level's pytree info. So we'll have:
```python
def gm_torch_level(arg1, arg2, kw2, kw1)
```
This difference is acceptable, as it is transparent to users of export.
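The key-order point above can be illustrated with plain Python (a minimal sketch; the names f, kw1, kw2 mirror the example and are not from the actual export code):

```python
def f(arg1, arg2, *, kw1, kw2):
    return (arg1, arg2, kw1, kw2)

kwargs = {"kw2": 30, "kw1": 40}
# dicts preserve insertion order (guaranteed since Python 3.7), so
# iterating kwargs yields kw2 before kw1 -- this ordering, not f's
# signature order, is what gets baked into the traced module's
# pytree info.
order = list(kwargs)
assert order == ["kw2", "kw1"]
```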

2. **gm_aot_export = aot_export_module(gm_torch_level, pos_or_kw_args)**. In this step, we need to turn the kwargs into positional args in the order gm_torch_level expects, which is stored in _orig_args. The returned gm_aot_export takes the flattened arguments, i.e. flat_args, in_spec = pytree.tree_flatten(pos_or_kw_args):
```python
flat_args, _ = pytree.tree_flatten(pos_or_kw_args)
def gm_aot_export(*flat_args)
```
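As a rough stdlib-only sketch of what the flattening does (a simplified stand-in for torch.utils._pytree.tree_flatten; the real API also returns a TreeSpec describing the structure, which is omitted here):

```python
# Hypothetical stand-in for pytree.tree_flatten, restricted to the
# (tuple-of-args, dict-of-kwargs) case used in this PR description.
def tree_flatten(tree):
    if isinstance(tree, (tuple, list)):
        flat = []
        for item in tree:
            flat.extend(tree_flatten(item))
        return flat
    if isinstance(tree, dict):
        flat = []
        for key in tree:  # dict iteration follows insertion order
            flat.extend(tree_flatten(tree[key]))
        return flat
    return [tree]  # leaf value

args = (1, 2)
kwargs = {"kw2": 3, "kw1": 4}
flat_args = tree_flatten((args, kwargs))
assert flat_args == [1, 2, 3, 4]
```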

3. **exported_program(*args, \*\*kwargs)**. The exported artifact is exported_program, a wrapper over gm_aot_export with the same calling convention as the original function "f". To achieve this, we need to 1. specialize the order of kwargs into pos_or_kw_args and 2. flatten the pos_or_kw_args into what gm_aot_export expects. We can combine the two steps into one with:
```python
_, in_spec = pytree.tree_flatten((args, kwargs))

# Then during exported_program.__call__(*args, **kwargs)
flat_args  = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
```
Here, kwargs is treated as a normal pytree whose key order is preserved in in_spec.
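A sketch of the idea, using a toy spec in place of the real TreeSpec (capture_spec and flatten_with_spec are hypothetical names, not the actual pytree API): the key order captured at export time governs flattening at call time, so a caller may pass kwargs in any order.

```python
# Record the structure of the export-time inputs: positional count
# plus kwargs key order (a stand-in for pytree's in_spec).
def capture_spec(args, kwargs):
    return (len(args), list(kwargs))

# Flatten call-time inputs following the recorded spec, analogous to
# fx_pytree.tree_flatten_spec((args, kwargs), in_spec).
def flatten_with_spec(args, kwargs, spec):
    n_args, key_order = spec
    assert len(args) == n_args and set(kwargs) == set(key_order)
    return list(args) + [kwargs[k] for k in key_order]

in_spec = capture_spec((1, 2), {"kw2": 3, "kw1": 4})
# The caller uses a different key order; the spec restores the
# export-time order before the flat args reach gm_aot_export.
flat = flatten_with_spec((1, 2), {"kw1": 40, "kw2": 30}, in_spec)
assert flat == [1, 2, 30, 40]
```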

Implementation-wise, we treat _orig_args in the dynamo-exported graph module as the single source of truth and order kwargs following it.

Test plan:
See added tests in test_export.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105337
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
2023-07-20 19:53:08 +00:00


# Owner(s): ["module: dynamo"]
import unittest

import torch._dynamo as torchdynamo
from torch._export import export
from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel
from torch._export.db.examples import (
    filter_examples_by_support_level,
    get_rewrite_cases,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class ExampleTests(TestCase):
    # TODO Maybe we should make this tests actually show up in a file?
    @parametrize(
        "name,case",
        filter_examples_by_support_level(SupportLevel.SUPPORTED).items(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
        model = case.model
        inputs = normalize_inputs(case.example_inputs)
        exported_program = export(
            model,
            inputs.args,
            inputs.kwargs,
            constraints=case.constraints,
        )
        exported_program.graph_module.print_readable()
        self.assertEqual(
            exported_program(*inputs.args, **inputs.kwargs),
            model(*inputs.args, **inputs.kwargs),
        )

        if case.extra_inputs is not None:
            inputs = normalize_inputs(case.extra_inputs)
            self.assertEqual(
                exported_program(*inputs.args, **inputs.kwargs),
                model(*inputs.args, **inputs.kwargs),
            )

    @parametrize(
        "name,case",
        filter_examples_by_support_level(SupportLevel.NOT_SUPPORTED_YET).items(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None:
        model = case.model
        # pyre-ignore
        with self.assertRaises(torchdynamo.exc.Unsupported):
            inputs = normalize_inputs(case.example_inputs)
            exported_model = export(
                model,
                inputs.args,
                inputs.kwargs,
                constraints=case.constraints,
            )

    @parametrize(
        "name,rewrite_case",
        [
            (name, rewrite_case)
            for name, case in filter_examples_by_support_level(
                SupportLevel.NOT_SUPPORTED_YET
            ).items()
            for rewrite_case in get_rewrite_cases(case)
        ],
        name_fn=lambda name, case: f"case_{name}_{case.name}",
    )
    def test_exportdb_not_supported_rewrite(
        self, name: str, rewrite_case: ExportCase
    ) -> None:
        # pyre-ignore
        inputs = normalize_inputs(rewrite_case.example_inputs)
        exported_model = export(
            rewrite_case.model,
            inputs.args,
            inputs.kwargs,
            constraints=rewrite_case.constraints,
        )


instantiate_parametrized_tests(ExampleTests)

if __name__ == "__main__":
    run_tests()