Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
The ONNX custom ops registration API.
## Design
1. Create a `custom_translation_table: dict[Callable, Sequence[Callable] | Callable]` parameter for specifying extra translation functions.
2. Use a callable as the key to support all possible `call_function` targets in the FX graph.
3. Allow either a callable or a sequence of callables as the value.
   - When there is a single callable, it is the translation function for the op.
   - When there is a sequence of callables, the exporter's dispatcher dispatches to them in order based on input dtypes.
   - The translation functions can be plain Python functions that call onnxscript ops (traced), or onnxscript functions.
   - Complex input support: we create special type annotations for the real representations of complex inputs. These are needed to handle complex computation in the ONNX graph, because ONNX has no ops that handle complex inputs. The dispatcher will know about these new type annotations and dispatch correctly, and the complex functions will be in the same overload pool as the real functions.
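To illustrate why a real representation is needed, here is a minimal sketch in plain Python, assuming the same layout as `torch.view_as_real` (each complex value stored as a `[real, imag]` pair in a trailing dimension of size 2). Complex multiplication is expressed using only real multiplies, adds, and subtracts, which is the kind of arithmetic a complex-aware translation function would have to emit with real-valued ONNX ops. The helper names are hypothetical and plain lists stand in for tensors; this is not the exporter's implementation.

```py
def view_as_real(values: list[complex]) -> list[list[float]]:
    """Store each complex number as a [real, imag] pair (trailing dim of size 2)."""
    return [[z.real, z.imag] for z in values]


def view_as_complex(pairs: list[list[float]]) -> list[complex]:
    """Inverse of view_as_real, used here only to check the result."""
    return [complex(re, im) for re, im in pairs]


def complex_mul_real(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    """(a_re + i*a_im) * (b_re + i*b_im) computed with real arithmetic only."""
    out = []
    for (a_re, a_im), (b_re, b_im) in zip(a, b):
        out.append([a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re])
    return out


x = [1 + 2j, 3 - 1j]
y = [2 - 1j, 0.5 + 0.5j]
result = view_as_complex(complex_mul_real(view_as_real(x), view_as_real(y)))
print(result)  # [(4+3j), (2+1j)]
```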
```py
# `model`, `args`, `overload1`, `overload2`, and `sym_not_onnx` are placeholders.
torch.onnx.export(
    model,
    args,
    dynamo=True,
    custom_translation_table={
        torch.ops.aten.add: [overload1, overload2],
        torch.sym_not: sym_not_onnx,
    },
)
```
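The table semantics can be modeled in plain Python. This is a hypothetical sketch, not the exporter's dispatcher: `supports` is an invented decorator standing in for the type annotations the real dispatcher reads, `FakeTensor` stands in for an FX node input, and `operator.add` / `operator.not_` stand in for callable targets such as `torch.ops.aten.add` and `torch.sym_not`. A single callable is normalized to a one-element list, and a sequence is tried in the given order until an overload accepts every input dtype.

```py
import operator
from dataclasses import dataclass
from typing import Callable, Sequence, Union


@dataclass
class FakeTensor:
    """Hypothetical stand-in for an FX node input; only its dtype matters here."""
    dtype: str


def supports(*dtypes: str):
    """Hypothetical decorator recording which input dtypes an overload accepts."""
    def wrap(fn: Callable) -> Callable:
        fn.supported_dtypes = frozenset(dtypes)
        return fn
    return wrap


TranslationValue = Union[Callable, Sequence[Callable]]


def normalize(table: dict[Callable, TranslationValue]) -> dict[Callable, list[Callable]]:
    """Wrap single callables in a list so both value forms are handled alike."""
    return {
        target: [value] if callable(value) else list(value)
        for target, value in table.items()
    }


def dispatch(overloads: list[Callable], inputs: list[FakeTensor]) -> Callable:
    """Return the first overload, in the user's order, that accepts every input dtype."""
    for fn in overloads:
        allowed = getattr(fn, "supported_dtypes", None)
        # An overload without a dtype declaration accepts any input.
        if allowed is None or all(t.dtype in allowed for t in inputs):
            return fn
    raise LookupError("no overload accepts these input dtypes")


@supports("float32", "float16")
def add_float(x, y):
    return "float path"


@supports("int64", "int32")
def add_int(x, y):
    return "int path"


def sym_not_onnx(x):
    return "not"


table = normalize({
    operator.add: [add_float, add_int],  # ordered overloads
    operator.not_: sym_not_onnx,         # single translation function
})
chosen = dispatch(table[operator.add], [FakeTensor("int64"), FakeTensor("int64")])
print(chosen.__name__)  # add_int
```

Putting the order in the user's hands keeps the rule simple: the first matching overload wins, so more specific overloads should be listed first.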
Support for functions that handle complex inputs will be in separate PRs.
Fixes https://github.com/pytorch/pytorch/issues/138391
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135403
Approved by: https://github.com/titaiwangms
| Name |
|---|
| assets |
| dynamo |
| expect |
| exporter |
| internal |
| model_defs |
| torch_export |
| autograd_helper.py |
| error_reproduction.py |
| onnx_test_common.py |
| pytorch_test_common.py |
| test_autograd_funs.py |
| test_custom_ops.py |
| test_fx_passes.py |
| test_fx_to_onnx_decomp_skip.py |
| test_fx_to_onnx.py |
| test_fx_type_promotion.py |
| test_lazy_import.py |
| test_models_onnxruntime.py |
| test_models_quantized_onnxruntime.py |
| test_models.py |
| test_onnx_opset.py |
| test_onnxscript_no_runtime.py |
| test_onnxscript_runtime.py |
| test_op_consistency.py |
| test_pytorch_jit_onnx.py |
| test_pytorch_onnx_no_runtime.py |
| test_pytorch_onnx_onnxruntime_cuda.py |
| test_pytorch_onnx_onnxruntime.py |
| test_pytorch_onnx_shape_inference.py |
| test_symbolic_helper.py |
| test_utility_funs.py |
| test_verification.py |
| verify.py |