mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Currently (after https://github.com/pytorch/pytorch/pull/114407), the user must pass the original user ``model`` to APIs such as ``ONNXProgram.__call__``, ``ONNXProgram.adapt_torch_inputs_to_onnx`` and ``ONNXProgram.adapt_torch_outputs_to_onnx``. This was needed because, when the model is fakefied, a non-fakefied version of the model is required so that the initializers, buffers and constants can be extracted from a real model (and used as inputs to the ONNX model). That approach places an unnecessary usability burden on the user when the model is not fakefied, because the model already passed to ``torch.onnx.dynamo_export`` could be used to extract the ``state_dict``. This PR adds the ``ONNXProgram._model_torch`` attribute to store the user model and demotes the ``model`` argument of the aforementioned APIs from required to optional. As a result, in the fakefied model scenario the user still needs to pass the model, but for non-fakefied models the persisted model is implicitly used to extract the model ``state_dict``, making the APIs easier to use. Pull Request resolved: https://github.com/pytorch/pytorch/pull/115281 Approved by: https://github.com/BowenBao ghstack dependencies: #114407 |
||
|---|---|---|
| .. | ||
| diagnostics | ||
| fx | ||
| __init__.py | ||
| _beartype.py | ||
| exporter.py | ||
| io_adapter.py | ||
| jit_utils.py | ||
| onnx_proto_utils.py | ||
| onnxruntime.py | ||
| registration.py | ||
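
The fallback described in the commit message above can be illustrated with a minimal, library-free sketch. It does not use PyTorch, and it is not the actual ``ONNXProgram`` implementation: ``ToyModel``, ``ToyProgram`` and ``resolve_state`` are hypothetical names invented for illustration. Only the idea of persisting the model in a ``_model_torch`` attribute and making the ``model`` argument optional comes from the commit message. The sketch also assumes, for simplicity, that nothing usable is stored in the fakefied case.

```python
from typing import Dict, Optional


class ToyModel:
    """Hypothetical stand-in for a user model exposing ``state_dict()``."""

    def __init__(self, weights: Dict[str, float]) -> None:
        self._weights = dict(weights)

    def state_dict(self) -> Dict[str, float]:
        # Return a copy so callers cannot mutate the model's state.
        return dict(self._weights)


class ToyProgram:
    """Hypothetical stand-in for an exported program (not the real ONNXProgram)."""

    def __init__(self, model_torch: Optional[ToyModel] = None) -> None:
        # Assumption: a real (non-fakefied) model is persisted at export time;
        # in the fakefied case there is no real model to keep, so this is None.
        self._model_torch = model_torch

    def resolve_state(self, model: Optional[ToyModel] = None) -> Dict[str, float]:
        # An explicitly passed model wins; otherwise fall back to the stored one.
        source = model if model is not None else self._model_torch
        if source is None:
            raise ValueError("no real model available; pass `model` explicitly")
        return source.state_dict()


real = ToyModel({"weight": 1.5})

# Non-fakefied scenario: the persisted model is used implicitly.
non_fake = ToyProgram(model_torch=real)
state_implicit = non_fake.resolve_state()

# Fakefied scenario: the caller still has to supply the real model.
fake = ToyProgram(model_torch=None)
state_explicit = fake.resolve_state(model=real)
```

In this sketch, forgetting to pass ``model`` in the fakefied scenario raises a ``ValueError``. Whether the real APIs fail in the same way is not stated in the commit message.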