Part of #79263
Previously, all quantized PyTorch tensors are all casted to the dtypes which comply with ONNX's definition, i.e. `scale` is casted to `double`, and `zero_point` is casted to `int64`. These casts lead to inconsistent dtypes when comparing PyTorch's outputs and ONNX runtime's outputs.
Now, `cast_onnx_accepted` argument is added to `unpack_quantized_tensor` function. When making example inputs for ONNX, we cast them to the ONNX compliant dtypes; otherwise, they are casted to PyTorch default types for quantization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79690
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
When `TrainingMode.PRESERVE` is set for export, the exporter used to change the model's training mode based on some logic. Now we respect the option and not touch the model's training state.
- Previously `_set_training_mode`'s behavior doesn't match what the global variable expects. This PR removes the deprecated `_set_training_mode` and makes the type correct.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78583
Approved by: https://github.com/BowenBao
A graph is exported for each set of inputs. The exported graphs are then compared
to each other, and discrepancies are reported. This function first checks the jit
graph, and then the onnx graph.
Unless otherwise specified, the jit/ONNX graph is expected to be the same, regardless
of the inputs it used for exporting. A discrepancy would imply the graph exported is
not accurate when running with other set of inputs, which will typically results in
runtime error or output mismatches.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78323
Approved by: https://github.com/justinchuby, https://github.com/garymm
- Add quantization support for `interpolate`, `avgpool`, `sigmoid` and `add_relu`
- Return the inputs to ListUnpack if the previous node is ListConstruct so that `ListConstruct` and `ListUnpack` are canceled and removed in the jit passes. ONNX doesn't support them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78103
Approved by: https://github.com/garymm