mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67805 Also fix Reduce ops on binary_cross_entropy_with_logits The graph says the output is a scalar but with `keepdims=1` (the default), the output should be a tensor of rank 1. We set keep `keepdims=0` to make it clear that we want a scalar output. This previously went unnoticed because ONNX Runtime does not strictly enforce shape inference mismatches if the model is not using the latest opset version. Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32181304 Pulled By: malfet fbshipit-source-id: 1462d8a313daae782013097ebf6341a4d1632e2c Co-authored-by: Bowen Bao <bowbao@microsoft.com>
This commit is contained in:
parent
ead59b5ff3
commit
37688148ae
|
|
@ -80,7 +80,7 @@ fi
|
|||
|
||||
if [[ "$BUILD_ENVIRONMENT" == *ort_test2* || "${SHARD_NUMBER}" == "2" ]]; then
|
||||
# Update the loop for new opsets
|
||||
for i in $(seq 10 14); do
|
||||
for i in $(seq 10 15); do
|
||||
pytest "${args[@]}" \
|
||||
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i"
|
||||
done
|
||||
|
|
|
|||
|
|
@ -10324,5 +10324,13 @@ TestONNXRuntime_opset14 = type(str("TestONNXRuntime_opset14"),
|
|||
keep_initializers_as_inputs=False,
|
||||
onnx_shape_inference=True))
|
||||
|
||||
# opset 15 tests
|
||||
TestONNXRuntime_opset15 = type(str("TestONNXRuntime_opset15"),
|
||||
(unittest.TestCase,),
|
||||
dict(TestONNXRuntime.__dict__, opset_version=15,
|
||||
keep_initializers_as_inputs=False,
|
||||
onnx_shape_inference=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -842,8 +842,8 @@ def _handle_reduce_dim_none(g, self, op_name):
|
|||
|
||||
|
||||
_default_onnx_opset_version = 9
|
||||
_onnx_main_opset = 14
|
||||
_onnx_stable_opsets = [7, 8, 9, 10, 11, 12, 13]
|
||||
_onnx_main_opset = 15
|
||||
_onnx_stable_opsets = [7, 8, 9, 10, 11, 12, 13, 14]
|
||||
_export_onnx_opset_version = _default_onnx_opset_version
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -119,9 +119,9 @@ def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduc
|
|||
if reduction == 0:
|
||||
return output
|
||||
elif reduction == 1:
|
||||
return g.op("ReduceMean", output)
|
||||
return g.op("ReduceMean", output, keepdims_i=0)
|
||||
elif reduction == 2:
|
||||
return g.op("ReduceSum", output)
|
||||
return g.op("ReduceSum", output, keepdims_i=0)
|
||||
else:
|
||||
return sym_help._onnx_unsupported("binary_cross_entropy_with_logits with reduction other than none, mean, or sum")
|
||||
|
||||
|
|
|
|||
25
torch/onnx/symbolic_opset15.py
Normal file
25
torch/onnx/symbolic_opset15.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
|
||||
# This file exports ONNX ops for opset 15
|
||||
|
||||
# Note [ONNX operators that are added/updated in opset 15]
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set
|
||||
# New operators:
|
||||
# Bernoulli
|
||||
# CastLike
|
||||
# Optional
|
||||
# OptionalGetElement
|
||||
# OptionalHasElement
|
||||
#
|
||||
# Updated operators:
|
||||
# BatchNormalization https://github.com/onnx/onnx/pull/3545
|
||||
# Backwards compatible
|
||||
# TODO: test coverage for mixed types inputs.
|
||||
# Pow https://github.com/onnx/onnx/pull/3412
|
||||
# Backwards compatible
|
||||
# TODO: bfloat16 support.
|
||||
# Shape https://github.com/onnx/onnx/pull/3580
|
||||
# Backwards compatible
|
||||
# TODO: optional start/end attribute.
|
||||
Loading…
Reference in New Issue
Block a user