Solving #105242.

During export, the exported function's signature changes multiple times. Suppose we'd like to export `f` as shown in the following example:

```python
def f(arg1, arg2, kw1, kw2):
    pass

args = (arg1, arg2)
kwargs = {"kw2": arg3, "kw1": arg4}
torch.export(f, args, kwargs)
```

The signature changes multiple times during the export process, in the following order:

1. **gm_torch_level = dynamo.export(f, \*args, \*\*kwargs)**. In this step, we turn all kinds of parameters, such as **positional_only**, **var_positional**, **kw_only**, and **var_kwargs**, into **positional_or_kw**. It also preserves the positional and keyword argument names of the original function (i.e. `f` in this example) [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/export.py#L546C13-L546C27). The order of kwargs will be the **key order** of kwargs (since Python 3.6, this is the insertion order of the keys) rather than the order in the original function signature, and this order is baked into the `_orig_args` variable of `gm_torch_level`'s pytree info. So we'll have:

```python
def gm_torch_level(arg1, arg2, kw2, kw1)
```

Such a difference is acceptable, as it's transparent to users of export.

2. **gm_aot_export = aot_export_module(gm_torch_level, pos_or_kw_args)**. In this step, we need to turn kwargs into positional args in the order `gm_torch_level` expects, which is stored in `_orig_args`. The returned `gm_aot_export` has the graph signature of `flat_args, in_spec = pytree.tree_flatten(pos_or_kw_args)`:

```python
flat_args, _ = pytree.tree_flatten(pos_or_kw_args)
def gm_aot_export(*flat_args)
```

3. **exported_program(\*args, \*\*kwargs)**. The exported artifact is `exported_program`, which is a wrapper over `gm_aot_export` and has the same calling convention as the original function `f`. To achieve this, we need to (1) specialize the order of kwargs into `pos_or_kw_args`, and (2) flatten `pos_or_kw_args` into what `gm_aot_export` expects. We can combine the two steps into one:

```python
_, in_spec = pytree.tree_flatten((args, kwargs))

# Then during exported_program.__call__(*args, **kwargs)
flat_args = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
```

where kwargs is treated as a normal pytree whose key order is preserved in `in_spec`.

Implementation-wise, we treat `_orig_args` in the dynamo-exported graph module as the single source of truth, and kwargs are ordered following it.

Test plan: see the added tests in test_export.py.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105337
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
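As a rough illustration of the signature normalization in step 1 (this is a sketch of the described behavior using `inspect`, not the actual dynamo.export code; the variant of `f` with keyword-only parameters is one hypothetical input shape):

```python
# A minimal sketch, assuming only the behavior described above: every
# parameter kind is re-bound as POSITIONAL_OR_KEYWORD, with kwargs
# appended in their key (insertion) order rather than signature order.
import inspect

def f(arg1, arg2, *, kw1, kw2):  # kw_only params as one possible input
    pass

kwargs = {"kw2": 3, "kw1": 4}  # insertion order: kw2 before kw1

new_params = [
    inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    for name in ("arg1", "arg2", *kwargs)  # args first, then kwargs keys
]
print(inspect.Signature(new_params))  # (arg1, arg2, kw2, kw1)
```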
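To make steps 2 and 3 concrete, here is a minimal, self-contained sketch of the pytree behavior the PR relies on: `pytree.tree_flatten` records the kwargs key order in `in_spec`, and `fx_pytree.tree_flatten_spec` replays that order at call time. The tensor values are placeholders; this illustrates the pytree mechanics rather than the export code itself.

```python
import torch
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree

args = (torch.tensor(1.0), torch.tensor(2.0))
kwargs = {"kw2": torch.tensor(3.0), "kw1": torch.tensor(4.0)}  # kw2 first

# Flattening (args, kwargs) bakes the kwargs insertion order into in_spec.
flat_args, in_spec = pytree.tree_flatten((args, kwargs))
# flat_args == [arg1, arg2, kw2's value, kw1's value]

# At call time the caller may pass kwargs in a different key order; the
# recorded spec reorders the values back to the export-time order.
later_kwargs = {"kw1": torch.tensor(4.0), "kw2": torch.tensor(3.0)}
replayed = fx_pytree.tree_flatten_spec((args, later_kwargs), in_spec)
assert [t.item() for t in replayed] == [t.item() for t in flat_args]
```

Because the spec stores the key order, any later call with the same set of kwargs flattens to the same positional layout that `gm_aot_export` was built against.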