Currently (after https://github.com/pytorch/pytorch/pull/114407), the user has must pass the original user ``model`` to APIs such as ``ONNXProgram.__call__``, ``ONNXProgram.adapt_torch_inputs_to_onnx`` and ``ONNXProgram.adapt_torch_outputs_to_onnx`` APIs.
This was needed because when the model is fakefied, a version of the non-fakefied model is needed so that the Initializers, buffers and constants can be extracted from a real model (and used as input to the ONNX model).
That approach brings an unnecessary usability burden to the user when the model is not fakefied, because the model that was already passed to ``torch.onnx.dynamo_export`` could be used to extract ``state_dict``.
This PR adds ``ONNXProgram._model_torch`` attribute to store the user model and demote ``model`` argument of the aforementioned APIs to optional, only (as opposed to required).
As a result, for the fakefied model scenario, the user still need to pass the required model, but for non fakefied models, the persisted model is implicitly used to extract the model state_dict, making it easier to use.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115281
Approved by: https://github.com/BowenBao
ghstack dependencies: #114407
Currently, the ONNX exporter using torch.nn.Module as input can support
FakeTensor because the ONNX model stores all initializers
When using torch.export.ExportedProgram as input, the initializers are
lifted as inputs. In order to execute the ONNX model, we need to pass a
reference to the non-fake model to the
ONNXProgram.adapt_torch_inputs_to_onnx API, so that initializers can be
fetched from the model and fed to the ONNX model as input
ps: https://github.com/pytorch/pytorch/issues/115461 will track the API revision for the cases where additional `model_with_state_dict` are required to produce complete ONNX files exported with fake support. This is also tracked by the umbrella fake tensor issue https://github.com/pytorch/pytorch/issues/105464 FYI @BowenBao
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114407
Approved by: https://github.com/BowenBao
This PR covers `ExportedProgram` to `test_fx_op_consistency.py`, which helps us identify the necessary but missing io_steps.
Next, we should refactor the tests to actually cover all ops supported by registry.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114886
Approved by: https://github.com/thiagocrepaldi
Previous to this PR, op level debug mismatches whenever it comes to complex dtype matching, because in ONNX, we support real representation. This PR makes sure we use real representation to compare the results.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114885
Approved by: https://github.com/BowenBao
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
This PR adds a unit test that leverages `torch.export.ExportedProgram` models that mutates registered buffers. Although the exporter already works out of the box in such scenario, the GraphModule and the exported ONNX model have extra outputs containing the mutated buffers. On future runs of the ONNX model, the mutated buffers are used as input to the model.
The aforementioned extra inputs and outputs are by design and the `ONNXProgram.model_signature` can be used to fetch detailed input/output schema for the exported model.
However, when we want to compare pytorch output to ONNX's, there is a mismatch between the schema because pytorch output does not include the mutated buffers present on the ONNX output.
This PR extends `onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)` so that the mutated buffers are prepended to the Pytorch output, matching the ONNX schema.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112272
Approved by: https://github.com/titaiwangms, https://github.com/BowenBao
This PR adds a unit test that leverages `torch.export.ExportedProgram` models that mutates registered buffers. Although the exporter already works out of the box in such scenario, the GraphModule and the exported ONNX model have extra outputs containing the mutated buffers. On future runs of the ONNX model, the mutated buffers are used as input to the model.
The aforementioned extra inputs and outputs are by design and the `ONNXProgram.model_signature` can be used to fetch detailed input/output schema for the exported model.
However, when we want to compare pytorch output to ONNX's, there is a mismatch between the schema because pytorch output does not include the mutated buffers present on the ONNX output.
This PR extends `onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)` so that the mutated buffers are prepended to the Pytorch output, matching the ONNX schema.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112272
Approved by: https://github.com/titaiwangms, https://github.com/BowenBao
Although opmath is the right thing to do to retain on-par precision, it inserts
upcasts everywhere in the graph. This is particularly hard for backend to optimize
since there is no way to differentiate between inserted upcasts and model code
casts. Hence we consolidate the input dtype to the result dtype to avoid this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113780
Approved by: https://github.com/titaiwangms, https://github.com/justinchuby
Currently the user can use torch.onnx.dynamo_export to export the model.
to ONNX.
```python
import torch
class Model(torch.nn.Module):
def forward(self, x):
return x + 1.0
onnx_program = torch.onnx.dynamo_export(
Model(),
torch.randn(1, 1, 2, dtype=torch.float),
)
```
The next step would be instantiating a ONNX runtime to execute it.
```python
import onnxruntime # type: ignore[import]
onnx_input = self.adapt_torch_inputs_to_onnx(*args, **kwargs)
options = options or {}
providers = options.get("providers", onnxruntime.get_available_providers())
onnx_model = self.model_proto.SerializeToString()
ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)
def to_numpy(tensor):
return (
tensor.detach().cpu().numpy()
if tensor.requires_grad
else tensor.cpu().numpy()
)
onnxruntime_input = {
k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
}
return ort_session.run(None, onnxruntime_input)
```
This PR provides the `ONNXProgram.__call__` method as facilitator to use ONNX Runtime under the hood, similar to how `torch.export.ExportedProgram.__call__` which allows the underlying `torch.fx.GraphModule` to be executed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113495
Approved by: https://github.com/titaiwangms
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
In cases like #113444, users usually stop at UnsupportedNodeAnalysis with unsupported nodes information. Although in SARIF, they can clearly see it's due to lack of COMPLEX support, in screen error message, it's only showing original FX node name, such as `aten.mul.Tensor`. ~~This PR catches the information from diagnostic messages and reveal it to users.~~
The root cause is that UnsupportedNodeAnalysis is leveraging on `onnxfunction_dispatcher.get_function_overloads()` to decide if an ATen is supported or not. However, in `onnxfunction_dispatcher.get_function_overloads()`, lacking of complex function support is considered unsupported. This PR defines Unsupported FX nodes as not in registry.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113785
Approved by: https://github.com/thiagocrepaldi
When the ONNX model is exported from a torch.export.ExportedProgram, a
torch.export.ExportedGraphSignature is available with the specification
of the model inputs and outputs.
ExportedGraphSignature includes information such as the mapping between
the exported input/buffer/output ONNX name to the original pytorch input/buffer/output name.
It also specifies the kind of the input, such as user_input, parameter,
buffer or constant_tensor. Outputs kind can be user_output, loss_output,
buffer_mutation, etc
Such information can be useful to understand what the ONNX model expects
as inputs and how the output will look like when the ONNX input/output
differs from the original PyTorch input/output schema.
When the ONNX model is exported from a Callable or regular
torch.nn.MOdule, such information is not available and
ONNXProgram.model_signature will yield NOne
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113477
Approved by: https://github.com/BowenBao
After #84624, aten::linalg_vector_norm started being used instead of aten::norm. In the ONNX exporter, the latter leveraged Reduce{L1,L2} when p={1,2}, which resulted in more optimized code in the ONNX Runtime
This PR extends aten::linal_vector_norm to also use Reduce{L1,L2} when ord={1,2}, producing an equivalent ONNX subgraph
This PR is a WIP. Pending work include checking argument equivalence between `aten::norm` and `aten::linalg_vector_norm` and maybe re-enable tests disabled by #84624
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113173
Approved by: https://github.com/justinchuby
Since PyTorch 2.1, torch.export API was introduced and the term "export"
got overloaded due to the already existing torch.onnx.export API.
The torch.onnx.dynamo_export API was introduced on pyTorch 2.0 and it
exposed a torch.onnx.ExportOutput which now can be confused with
torch.export.export output
To prevent such ambiguity and standardize names around the new
torch.export.ExportedProgram, this PR renames torch.onnx.ExportOutput to
torch.onnx.ONNXProgram
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112263
Approved by: https://github.com/BowenBao
ghstack dependencies: #112444
### **Description**:
The problem is that the graph was cast to `fp32` at a certain point but never reverted to `fp16`, causing the rest of the graph to run on `fp32`. This change aims to fix that issue and improve performance.
### **Changes Made**:
- Modified the ONNX exporter code to ensure that the graph is correctly cast back to `fp16` after a necessary cast to `fp32`.
### **Why This Change is Necessary**:
This change is necessary to ensure that the exported ONNX graph remains in `fp16` where appropriate, leading to significant gains in performance and memory savings. Without this fix, the graph would run entirely in `fp32`, causing suboptimal performance.
### **Testing**:
- Performed extensive testing with various models and scenarios to validate the correctness of the changes.
### **Benchmarking Results**:
Experiments Ran on:
8 GPUS - Tesla V100 - 32GB
**Before Fix: ort + 4 hidden layers + without fix**
- **Train Runtime**: 78.7088 seconds
- **Train Samples per Second**: 10.164
- **Train Steps per Second**: 1.271
- **Train Loss**: 5.624655108451844
- **Epoch**: 0.3
**After Fix: ort + 4 hidden layers + with fix**
- **Train Runtime**: 72.5636 seconds
- **Train Samples per Second**: 11.025
- **Train Steps per Second**: 1.378
- **Train Loss**: 5.6252727746963505
- **Epoch**: 0.3
We can see 7.79% perf gain after this fix.
- I only ran it on 4 hidden layers due to GPU constraints, the perf gain is going to be much higher on the full model.
- You could see the gain on other models that uses _attention_scale as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112554
Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
This PR introduces the ability to produce GraphModules with Core ATen IR only through decompositions. It also removes `None` from user inputs as ONNX does not supports them
Tests for these features will be executed when #112289 is merged, but for reference, they are as below:
```python
def test_log_sigmoid(self):
# This produces op as `torch.ops.aten.log_sigmoid_forward`, instead of the more
# conventional `torch.ops.aten.log_sigmoid`.
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.m = torch.nn.LogSigmoid()
def forward(self, x):
return self.m(x)
input = torch.randn(2)
self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
Model(), (input,), model_type=self.model_type
)
def test_none_input(self):
class NoneInputModel(torch.nn.Module):
def forward(
self, x: torch.Tensor, y: Optional[torch.Tensor], z: torch.Tensor
):
if y is None:
return x + z
return x + y + z
self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
NoneInputModel(),
(torch.randn(1, 2), None, torch.randn(1, 2)),
model_type=self.model_type,
)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112444
Approved by: https://github.com/BowenBao
Summary
- faster than previous try-catch.
- more stable than previous try-catch. In some circumstances serializing models > 2GB into a single protobuf file ends up with a corrupted file without raising an exception.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111984
Approved by: https://github.com/justinchuby
Fixes#109889
This PR adds `torch.export.export` as another `FXGraphExtractor` implementation. `torch.onnx.dynamo_export` automatically uses this new FX tracer when a `torch.export.ExportedProgram` is specified as `model`
Implementation is back compatible, thus non `ExportedProgram` models are handled the exact same way as before
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111497
Approved by: https://github.com/BowenBao
Fixes#109889
This PR adds `torch.export.export` as another `FXGraphExtractor` implementation. `torch.onnx.dynamo_export` automatically uses this new FX tracer when a `torch.export.ExportedProgram` is specified as `model`
Implementation is back compatible, thus non `ExportedProgram` models are handled the exact same way as before
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111497
Approved by: https://github.com/BowenBao
Did some easy fixes from enabling TRY200. Most of these seem like oversights instead of intentional. The proper way to silence intentional errors is with `from None` to note that you thought about whether it should contain the cause and decided against it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111496
Approved by: https://github.com/malfet
Fixes#110597
Summary:
* Generic code: The `torch._C.Value.node().mustBeNone()` is encapsulated into the high-level API `JitScalarType.from_value` ; `_is_none` was also extended to allow either `None` or `torch._C.Value.node.mustBeNone()`, so users don't manually call into TorchScript API when implementing operators
* Specific to `new_zeros` (and ops of ` *_like` and `new_*`): When checking `dtype`, we always must use ` _is_none`, which will call proposed by #110935
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110956
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
When users define customized `attention mask` using `dtype=torch.float16`, e.g.
```
from torch.nn import functional as F
float_min = torch.finfo(torch.float16).min
attention_mask_fp16 = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(torch.float16)
attn_output = F.scaled_dot_product_attention(
query_layer_, key_layer_, value_layer_, attention_mask_fp16, 0.0, is_causal=False
)
```
the onnx graph cannot be exported.
When q, k ,v have the fp16 type, we can support this `attn_mask` to be `fp16` type, by adding
```
elif (
_type_utils.JitScalarType.from_value(attn_mask)
== _type_utils.JitScalarType.FLOAT
in (_type_utils.JitScalarType.FLOAT, _type_utils.JitScalarType.HALF)
```
This can export `.onnx` graph.
Fixes#109336
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110306
Approved by: https://github.com/titaiwangms