Use `torch.export` to get dynamic shapes for a JIT-converted graph. I just realized we can retrace a converted JIT graph with `torch.export` and thereby produce dynamic shapes.
- **Prior:** The exporter would **silently produce a static graph** even when `dynamic_shapes` was provided.
- **Proposed:** When `dynamic_shapes` is provided and the strategy is able to handle it, the export succeeds with those shapes applied.
## Why are we still keeping the JIT strategy?
It is useful when users want to convert JIT modules or `.pt` files into ONNX via the new path. It is also sometimes useful when an `nn.Module` contains JIT-scripted submodules.
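A rough sketch of what this enables (assuming a ScriptModule routed through `dynamo=True` falls back to the JIT strategy; the module here is illustrative):
```python
import torch


class M(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1


scripted = torch.jit.script(M())
# The JIT strategy converts the scripted module and retraces it with
# torch.export, so the dynamic_shapes below can actually take effect.
onnx_program = torch.onnx.export(
    scripted,
    (torch.randn(2, 3),),
    dynamic_shapes={"x": {0: torch.export.Dim("batch")}},
    dynamo=True,
)
```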
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148627
Approved by: https://github.com/titaiwangms
Use a simple try/except to handle ONNX Runtime errors in the verification interpreter when they happen. One example: ORT will sometimes produce a list of `None` for some nodes; I am not sure yet how that happens.
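The shape of the fix, as a minimal sketch (names are illustrative, not the interpreter's actual internals):
```python
def _run_node_safely(run_node, node):
    """Evaluate one node with ONNX Runtime; record failures instead of raising."""
    try:
        return run_node(node), None
    except Exception as e:
        # e.g. ORT occasionally produces a list of None for some nodes,
        # which would otherwise crash the whole verification run.
        return None, f"{node}: {e}"
```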
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148730
Approved by: https://github.com/titaiwangms
ghstack dependencies: #148706
I realized we can just extend `verify_onnx_program` to return intermediate values. There is no need for us to expose the VerificationInterpreter to users.
I added a `compare_intermediates` option to `verify_onnx_program`.
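Assumed usage (the option name comes from this PR; the printed record format is whatever `verify_onnx_program` returns):
```python
import torch


class M(torch.nn.Module):
    def forward(self, x):
        return x.sin() + x.cos()


onnx_program = torch.onnx.export(M(), (torch.randn(2, 3),), dynamo=True)
# With compare_intermediates=True, verification also reports how closely
# each intermediate value matches, not just the final outputs.
for info in torch.onnx.verify_onnx_program(onnx_program, compare_intermediates=True):
    print(info)
```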
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148706
Approved by: https://github.com/titaiwangms
Prior to this PR, if the export ran `run_decompositions()`, the report still showed the exported program from before decomposition, which made it hard for users to inspect the exported program that was actually used to translate to the ONNX graph.
The following example shows what the report looked like before this PR:
```
# PyTorch ONNX Conversion Report
✅ Obtain model graph with `torch.export.export(..., strict=False)`
⚪ Obtain model graph with `torch.export.export(..., strict=True)`
⚪ Obtain model graph with `torch.jit.trace`
✅ Decompose operators for ONNX compatibility
❌ Translate the graph into ONNX
⚪ Run `onnx.checker` on the ONNX model
⚪ Execute the model with ONNX Runtime
⚪ Validate model output accuracy
```
## Error messages
```pytb
Traceback (most recent call last):
  File "/home/titaiwang/pytorch/torch/onnx/_internal/exporter/_core.py", line 707, in _translate_fx_graph
    _handle_call_function_node_with_lowering(
  File "/home/titaiwang/pytorch/torch/onnx/_internal/exporter/_core.py", line 486, in _handle_call_function_node_with_lowering
    raise _errors.DispatchError(
torch.onnx._internal.exporter._errors.DispatchError: No ONNX function found for <OpOverload(op='aten.slice', overload='Tensor')>. Failure message: No decompositions registered for the complex-valued input

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/titaiwang/pytorch/torch/onnx/_internal/exporter/_core.py", line 1371, in export
    onnx_program = _exported_program_to_onnx_program(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/onnx/_internal/exporter/_core.py", line 1007, in _exported_program_to_onnx_program
    values = _translate_fx_graph(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/onnx/_internal/exporter/_core.py", line 733, in _translate_fx_graph
    raise _errors.ConversionError(
torch.onnx._internal.exporter._errors.ConversionError: Error when translating node %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%_to_copy, 0, 0, 9223372036854775807), kwargs = {}). See the stack trace for more information.
```
## Exported program
```python
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 4]"):
            # File: /home/titaiwang/pytorch/test_slice_complex.py:6 in forward, code: x_complex = x.to(torch.complex64)
            to: "c64[3, 4]" = torch.ops.aten.to.dtype(x, torch.complex64);  x = None

            # File: /home/titaiwang/pytorch/test_slice_complex.py:8 in forward, code: return x_complex[:, :2]
            slice_1: "c64[3, 4]" = torch.ops.aten.slice.Tensor(to, 0, 0, 9223372036854775807);  to = None
            slice_2: "c64[3, 2]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 2);  slice_1 = None
            return (slice_2,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_2'), target=None)])
Range constraints: {}
```
## Analysis
PyTorch ONNX Conversion Analysis
## Model Information
The model has 0 parameters and 0 buffers (non-trainable parameters).
Number of parameters per dtype:
```python
defaultdict(<class 'int'>, {})
```
Number of buffers per dtype:
```python
defaultdict(<class 'int'>, {})
```
Inputs:
- `x`: `TensorMetadata(shape=torch.Size([3, 4]), dtype=torch.float32, requires_grad=False, stride=(4, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})`
Outputs:
- `slice_2`: `TensorMetadata(shape=torch.Size([3, 2]), dtype=torch.complex64, requires_grad=False, stride=(4, 1), memory_format=None, is_quantized=False, qparams={})`
The FX graph has 5 nodes in total. Number of FX nodes per op:
- `placeholder`: 1
- `call_function`: 3
- `output`: 1
Of the call_function nodes, the counts of operators used are:
- `aten.slice.Tensor`: 2
- `aten.to.dtype`: 1
## ONNX Conversion Information
The model contains operators the dispatcher could not find registered ONNX decompositions for. This may be due to missing implementations, decompositions not registered correctly, or a bug in the dispatcher.
Errors grouped by operator:
- `aten.to.dtype`: No decompositions registered for the real-valued input. Example node: `%to : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%x, torch.complex64), kwargs = {})`. All nodes: `[to]`
- `aten.slice.Tensor`: No decompositions registered for the complex-valued input. Example node: `%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%to, 0, 0, 9223372036854775807), kwargs = {})`. All nodes: `[slice_1, slice_2]`
## Decomposition comparison
Ops exist only in the ExportedProgram before decomposition: `['aten.to.dtype']`
Ops exist only in the ExportedProgram after decomposition: `['aten._to_copy.default']`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148617
Approved by: https://github.com/justinchuby
Previously, the comparison of complex numbers was not supported when `verify=True`.
NOTE: This PR can be extended to support more complex-comparison cases if other places in the ONNX codebase need to change.
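One plausible way to support the comparison (whether the PR does exactly this is an assumption) is to view complex tensors as real before comparing:
```python
import torch

expected = torch.randn(3, 2, dtype=torch.complex64)
actual = expected.clone()

# view_as_real maps complex64 -> float32 with a trailing (real, imag) axis,
# so the existing real-valued tolerance checks apply unchanged.
torch.testing.assert_close(
    torch.view_as_real(actual),
    torch.view_as_real(expected),
)
```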
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148619
Approved by: https://github.com/justinchuby
Previously, the strategy used to obtain the exported program was not asserted. This led to silent errors when torch.export broke something and a fallback strategy was used. This change adds a `_capture_strategy` field to `ONNXProgram` and enables unit tests to assert which strategy was used, preventing silent fallbacks.
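A unit-test-style sketch (the field name `_capture_strategy` is from this change; the expected value string is an assumption):
```python
import torch


class M(torch.nn.Module):
    def forward(self, x):
        return x * 2


onnx_program = torch.onnx.export(M(), (torch.randn(2),), dynamo=True)
# Fail loudly if torch.export broke and the exporter silently fell back
# to another capture strategy (the expected value string is a guess).
assert onnx_program._capture_strategy == "TorchExportNonStrictStrategy"
```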
Fixes #147674
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148348
Approved by: https://github.com/titaiwangms, https://github.com/shubhambhokare1
Use onnxscript APIs for 2.7.
Remove references to `torchlib_opset()` and `torchlib_opset_version()`, which were removed from the onnxscript APIs for 2.7. These APIs were removed because torchlib in onnxscript will always stay on opset 18; future opset version bumps will happen in PyTorch core after the migration of torchlib.
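In effect the exporter can hardcode the torchlib baseline instead of querying onnxscript; a minimal sketch of the idea:
```python
# torchlib in onnxscript stays on opset 18; later opset version bumps
# happen in PyTorch core, so a constant can replace torchlib_opset_version().
_TORCHLIB_OPSET_VERSION = 18
```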
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148453
Approved by: https://github.com/titaiwangms, https://github.com/shubhambhokare1
Reference: https://docs.astral.sh/ruff/formatter/black/#assert-statements
> Unlike Black, Ruff prefers breaking the message over breaking the assertion, similar to how both Ruff and Black prefer breaking the assignment value over breaking the assignment target:
>
> ```python
> # Input
> assert (
>     len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
>
> # Black
> assert (
>     len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
> # Ruff
> assert len(policy_types) >= priority + num_duplicates, (
>     f"This tests needs at least {priority + num_duplicates} many types."
> )
> ```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144546
Approved by: https://github.com/malfet
In `_any_str_or_dim_in_dynamic_shapes`, we strictly guard `dynamic_shapes` to make sure the flattened shapes are valid, but the code failed to consider that `None` could appear in the shapes.
NOTE: Found while benchmarking with Olive.
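A sketch of the corrected guard (the helper name and exact spec types here are illustrative, not the real implementation):
```python
from typing import Any


def _any_str_or_dim(flat_dynamic_shapes: list[Any]) -> bool:
    """Return True if any flattened axis spec is a str or a _Dim/_DimHint."""
    return any(
        spec is not None  # None marks a static axis and is a valid entry
        and (isinstance(spec, str) or type(spec).__name__ in ("_Dim", "_DimHint"))
        for spec in flat_dynamic_shapes
    )
```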
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148025
Approved by: https://github.com/justinchuby
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
This PR sets up the registry to accept ONNX decomp functions as they are moved into PyTorch (https://github.com/pytorch/pytorch/issues/139301).
The ops from onnxscript are currently appended to the registry. When ops are moved into PyTorch, the moved ops take precedence because they appear first in the registry list (see the sketch below).
After the migration, the hooks for loading ops from onnxscript will be removed.
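A sketch of the ordering rule (data structures simplified from the real registry):
```python
from collections import defaultdict

# Each target maps to an ordered list of candidate decomps; the dispatcher
# takes the first match, so whatever appears first in the list wins.
_registry: defaultdict[str, list] = defaultdict(list)


def register(target: str, fn, from_onnxscript: bool) -> None:
    if from_onnxscript:
        _registry[target].append(fn)  # onnxscript ops are appended
    else:
        _registry[target].insert(0, fn)  # in-tree ops take precedence


def dispatch(target: str):
    return _registry[target][0]
```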
1. Use a private field `_pt_onnx_signature` to store function signatures to avoid conflicts
2. Update the registry to record the signature in `OnnxDecompMeta` and update the dispatcher to leverage that data structure
3. Update the registry to prepare for ONNX op registration, and update the `onnx_impl` decorator to support a `no_compile` option
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147396
Approved by: https://github.com/titaiwangms
Found in `_check_dynamic_shapes` that `int` and `None` are valid inputs of `dynamic_shapes`.
This PR adds support for these two types and adds tests to keep the ONNX flattening logic in sync with the one in export.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147407
Approved by: https://github.com/justinchuby
Fixes #143443
This PR aims to support custom dynamic axis naming through `dynamic_shapes`. Currently, `_Dim` and `_DimHint` do not support dynamic axis naming (#144273).
1. **The original dynamic_shapes guarantee**
The axis renaming is applied only when the given dynamic shapes contain strings, rather than only `_Dim` and `_DimHint`. Thus, there is no inconsistent behavior relative to `torch.export.export` when the given dynamic shapes follow the `torch.export.export` format.
2. `_DimHint.AUTO` is applied to the axes that are specified with custom names to avoid exporter crashes (`_DimHint.DYNAMIC` crashes when the export fails).
3. There's no need to handle cases where kwargs are out of order with the model signature,
as torch.export.export supports dynamism only when kwargs and dynamic_shapes are provided in order.
49082f9dba/torch/export/_trace.py (L2034)
4. If `torch.export` finds that axes share the same constraints, they will share the same symbolic name in the `ExportedProgram` (e.g. `s0`, `s1`, ...). Therefore, even if ONNX users specify different custom names for such axes, those names won't be respected.
Example model:
```python
class NestedModel(torch.nn.Module):
    def forward(
        self,
        x: torch.Tensor,
        ys: list[torch.Tensor],
        zs: dict[str, torch.Tensor],
        c: torch.Tensor,
    ):
        y = ys[0] + ys[1] + zs["a"] + zs["b"]
        w = 5
        if x.shape[0] < 3 and c.shape[0] != 4:
            return x + w, x + y, c
        else:
            return x - w, x - y, c


input = (
    torch.ones(5),
    [torch.zeros(5), torch.ones(5)],
    {"a": torch.zeros(5), "b": torch.ones(5)},
    torch.ones(6),
)

dynamic_shapes = (
    {0: torch.export.Dim("dim_x", min=3)},  # _Dim
    [("custom_name_axis_ys_0",), (torch.export.Dim.AUTO,)],  # custom name
    {
        "a": {0: torch.export.Dim.AUTO},
        "b": ("custom_name_axis_zs_b_0",),
    },  # _DimHint
    {0: "custom_name_axis_c_0"},  # custom name
)
```
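A possible invocation with the definitions above (not part of the original example; assumes the `dynamo=True` path):
```python
onnx_program = torch.onnx.export(
    NestedModel(), input, dynamic_shapes=dynamic_shapes, dynamo=True
)
```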
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146321
Approved by: https://github.com/justinchuby
Adjust and add deprecation messages for torch.onnx utilities and verification methods, because they relate only to TorchScript and are obsolete.
Removed the unused `_exporter_states.py` and the internal deprecation module in favor of the `typing_extensions` `deprecated` decorator.
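The replacement pattern, sketched (the decorated function and message are illustrative):
```python
from typing_extensions import deprecated


@deprecated(
    "The TorchScript-based verification utilities are deprecated; "
    "use torch.onnx.export(..., dynamo=True) and its verification APIs instead."
)
def verify(*args, **kwargs):  # hypothetical legacy entry point
    ...
```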
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146639
Approved by: https://github.com/titaiwangms
Reland #146003
Deprecation of `torch.onnx.dynamo_export`:
* [`torch/onnx/_internal/_exporter_legacy.py`]: Added deprecation warnings to the `OnnxRegistry`, `ExportOptions`, `ONNXRuntimeOptions`, and `dynamo_export` functions, indicating that `torch.onnx.dynamo_export` is deprecated since version 2.6.0 and should be replaced with `torch.onnx.export(..., dynamo=True)`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146425
Approved by: https://github.com/titaiwangms, https://github.com/atalman
Deprecation of `torch.onnx.dynamo_export`:
* `torch/onnx/_internal/_exporter_legacy.py`: Added deprecation warnings to the `OnnxRegistry`, `ExportOptions`, `ONNXRuntimeOptions`, and `dynamo_export` functions, indicating that `torch.onnx.dynamo_export` is deprecated since version 2.6.0 and should be replaced with `torch.onnx.export(..., dynamo=True)`.
This PR also removed the `**_` kwarg on `torch.onnx.export` so that users get an error when they supply an unexpected argument.
Updated to emit `DeprecationWarning` because it is more appropriate: https://docs.python.org/3/library/exceptions.html#DeprecationWarning
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146003
Approved by: https://github.com/titaiwangms
Basically, this function brings more cons than pros.
It was nice to have automation that converts the top-level keys of dynamic_shapes to argument names for users. However, the function is buggy when the number of model inputs coincidentally matches the length of dynamic_shapes:
```python
input_names
# 'input_ids', 'past_key_values.0.key', 'past_key_values.0.value', 'past_key_values.1.key', 'past_key_values.1.value', 'past_key_values.2.key', 'past_key_values.2.value', 'past_key_values.3.key', 'past_key_values.3.value', 'past_key_values.4.key', 'past_key_values.4.value', 'attention_mask', 'position_ids'
inspect.signature(model.forward).parameters
# mappingproxy(OrderedDict([('input_ids', <Parameter "input_ids: Optional[torch.LongTensor] = None">), ('past_key_values', <Parameter "past_key_values: Union[transformers.cache_utils.Cache, Tuple[Tuple[torch.Tensor]], NoneType] = None">), ('attention_mask', <Parameter "attention_mask: Optional[torch.FloatTensor] = None">), ('token_type_ids', <Parameter "token_type_ids: Optional[torch.LongTensor] = None">), ('position_ids', <Parameter "position_ids: Optional[torch.LongTensor] = None">), ('head_mask', <Parameter "head_mask: Optional[torch.FloatTensor] = None">), ('inputs_embeds', <Parameter "inputs_embeds: Optional[torch.FloatTensor] = None">), ('labels', <Parameter "labels: Optional[torch.LongTensor] = None">), ('use_cache', <Parameter "use_cache: Optional[bool] = None">), ('output_attentions', <Parameter "output_attentions: Optional[bool] = None">), ('output_hidden_states', <Parameter "output_hidden_states: Optional[bool] = None">), ('return_dict', <Parameter "return_dict: Optional[bool] = None">), ('cache_position', <Parameter "cache_position: Optional[torch.LongTensor] = None">)]))
```
In the above case, the given input_names follows the ONNX graph, while it coincidentally has the same length as the torch model's forward signature. Cases like this are difficult to detect and automate for users.
On the other hand, the error message from torch.export.export is informative enough that I believe users will know where to go from there:
```python
import torch


class Model(torch.nn.Module):
    def forward(self, x=None, y=None):
        return x + y


dim = torch.export.Dim("x", min=1, max=6)
exported_program = torch.export.export(
    Model(),
    (),
    kwargs={"x": torch.randn(2, 3), "y": torch.randn(2, 3)},
    dynamic_shapes={"custom_input_x": {0: dim}, "custom_input_y": {0: dim}},
)
# torch._dynamo.exc.UserError: When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['x', 'y'] of `inputs`, but here they are ['custom_input_x', 'custom_input_y']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146002
Approved by: https://github.com/justinchuby
Fixes #144976
Using approach ① `IO[bytes]`, but could also try with a protocol.
## Notes:
- moved `torch.serialization.FILE_LIKE` to `torch.types.FileLike` (see the sketch after this list)
- Use `FileLike` annotation where it makes sense
- made sure those functions also support `os.PathLike`
- Replaced `isinstance(x, io.BytesIO)` with `isinstance(x, (io.IOBase, IO))` where appropriate.
- Replaced `BinaryIO` with `IO[bytes]` (the two ABCs are almost identical, the only difference is that `BinaryIO` allows `bytearray` input to `write`, whereas `IO[bytes]` only `bytes`)
- needed to make `torch.serialization._opener` generic to avoid LSP violations.
- skipped `torch/onnx/verification` for now (functions use `BytesIO.getvalue`, which is not part of the `IO[bytes]` ABC; this seems redundant anyway, since e.g. `onnx.load` supports `str | PathLike[str] | IO[bytes]` directly)
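A sketch of the alias (the exact definition in `torch.types` may differ slightly):
```python
import os
from typing import IO, Union

# Anything accepted where a file is expected: a path or a binary stream.
FileLike = Union[str, os.PathLike, IO[bytes]]
```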
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144994
Approved by: https://github.com/ezyang, https://github.com/Skylion007
Fix #143118
Use the python dispatcher in the type promotion pass to preserve symbolic shapes, according to @angelayi's suggestions. (Thanks!)
Tested locally; I wasn't able to create a minimal repro short of using the full model.
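A minimal sketch of the mechanism (the helper is hypothetical; the pass wires this into its fake-tensor re-evaluation):
```python
from torch._dispatch.python import enable_python_dispatcher


def _run_with_symbolic_shapes(fn, *args, **kwargs):
    # Under the python dispatcher, the SymInt-aware Python implementations
    # run, so symbolic shapes survive instead of being specialized away.
    with enable_python_dispatcher():
        return fn(*args, **kwargs)
```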
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144801
Approved by: https://github.com/titaiwangms