[ONNX] Squeeze operator should give an error when trying to apply to a dimension with shape > 1 (#38476)

Summary:
The ONNX spec for the Squeeze operator:

> Remove single-dimensional entries from the shape of a tensor. Takes a parameter axes with a list of axes to squeeze. If axes is not provided, all the single dimensions will be removed from the shape. If an axis is selected with shape entry not equal to one, an error is raised.

Currently, as explained in https://github.com/pytorch/pytorch/issues/36796, a model that squeezes a dimension whose size is greater than one can still be exported to ONNX, and running the exported model then raises an exception in ONNX Runtime. This change updates the squeeze symbolics: for opset 11, the exporter emits an If node that checks the dimension's size at runtime and squeezes only when it equals one; for opset 9, where the decision must be made at export time, the exporter warns and either keeps or drops the Squeeze node based on the traced input's shape.
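
For context, a minimal reproduction in the spirit of the issue (the output file name is illustrative):

    import torch

    class Model(torch.nn.Module):
        def forward(self, x):
            # dim 1 has size 2 here, so squeeze is a no-op in PyTorch...
            return torch.squeeze(x, dim=1)

    # ...but before this fix the exporter still emitted an ONNX Squeeze node on
    # axis 1, which ONNX Runtime rejects because that axis does not have size 1.
    torch.onnx.export(Model(), torch.randn(2, 2, 3), "repro_squeeze.onnx", opset_version=9)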

Fixes https://github.com/pytorch/pytorch/issues/36796.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/38476

Reviewed By: hl475

Differential Revision: D22158024

Pulled By: houseroad

fbshipit-source-id: bed625f3c626eabcbfb2ea83ec2f992963defa19
Author: Yael Dekel, 2020-08-17 17:40:11 -07:00 (committed by Facebook GitHub Bot)
Parent: cd96dfd44b
Commit: 3c5e3966f4
7 changed files with 140 additions and 33 deletions

test/onnx/test_pytorch_onnx_onnxruntime.py

@@ -686,13 +686,64 @@ class TestONNXRuntime(unittest.TestCase):
         self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
         self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)
 
-    def test_squeeze(self):
+    def squeeze_model_tests(self, d, x1, x2):
         class Squeeze(torch.nn.Module):
             def forward(self, x):
-                return torch.torch.squeeze(x, dim=-2)
+                if d is not None:
+                    return torch.squeeze(x, dim=d)
+                else:
+                    return torch.squeeze(x)
 
+        x2 = [] if x2 is None else [x2]
+        self.run_test(Squeeze(), x1, input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}}, test_with_inputs=x2)
+
+    def test_squeeze_without_no_op(self):
         x = torch.randn(2, 1, 4)
-        self.run_test(Squeeze(), x)
+        self.squeeze_model_tests(1, x, None)
 
+    @skipIfUnsupportedMinOpsetVersion(11)
+    def test_squeeze(self):
+        x_squeeze = torch.randn(2, 1, 4)
+        x_noop = torch.randn(2, 2, 3)
+        self.squeeze_model_tests(1, x_squeeze, x_noop)
+
+    def test_squeeze_neg_without_no_op(self):
+        x = torch.randn(2, 1, 4)
+        self.squeeze_model_tests(-2, x, None)
+
+    @skipIfUnsupportedMinOpsetVersion(11)
+    def test_squeeze_neg(self):
+        x_squeeze = torch.randn(2, 1, 4)
+        x_noop = torch.randn(2, 2, 3)
+        self.squeeze_model_tests(-2, x_squeeze, x_noop)
+
+    def test_squeeze_all_dims(self):
+        x_squeeze = torch.randn(2, 1, 4)
+        x_noop = torch.randn(2, 2, 3)
+        self.squeeze_model_tests(None, x_squeeze, x_noop)
+
+    @skipIfUnsupportedMinOpsetVersion(11)
+    def test_squeeze_no_op(self):
+        x_noop = torch.randn(2, 1, 4)
+        x_squeeze = torch.randn(2, 2, 1)
+        self.squeeze_model_tests(2, x_noop, x_squeeze)
+
+    def test_squeeze_no_op_without_additional_inputs(self):
+        x_noop = torch.randn(2, 1, 4)
+        self.squeeze_model_tests(2, x_noop, None)
+
+    @skipIfUnsupportedMinOpsetVersion(11)
+    def test_squeeze_runtime_dim(self):
+        class Squeeze(torch.nn.Module):
+            def forward(self, d1, d2):
+                t = torch.zeros(d1[0], d2[0])
+                return t.squeeze(0)
+
+        d1 = torch.tensor([1])
+        d3 = torch.tensor([3])
+        d4 = torch.tensor([4])
+        self.run_test(Squeeze(), (d1, d4), test_with_inputs=[(d3, d4)])
+        self.run_test(Squeeze(), (d3, d4), test_with_inputs=[(d1, d3)])
+
     def test_unsqueeze(self):
         class Unsqueeze(torch.nn.Module):
@@ -1568,6 +1619,22 @@ class TestONNXRuntime(unittest.TestCase):
         x = torch.randn(3, 4, 5, requires_grad=True)
         self.run_test(IndexCopyModel(), x)
 
+    def test_select(self):
+        class Select(torch.nn.Module):
+            def forward(self, x):
+                return x[:, 1]
+
+        x = torch.randn(3, 4)
+        self.run_test(Select(), x)
+
+    def test_select_negative_index(self):
+        class Select(torch.nn.Module):
+            def forward(self, x):
+                return x[:, -1]
+
+        x = torch.randn(3, 4)
+        self.run_test(Select(), x)
+
     # TODO: enable for opset 10 when ONNXRuntime version will be updated
     def test_index_select_constant_scaler_index(self):
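
For reference, run_test in this suite exports the module and compares PyTorch against ONNX Runtime; a simplified sketch of what it does (the real helper also handles options such as input_names, dynamic_axes, and test_with_inputs):

    import io
    import numpy as np
    import onnxruntime
    import torch

    def run_test_sketch(model, inputs, rtol=1e-3, atol=1e-7):
        # Export to an in-memory buffer, then feed the same inputs to ONNX Runtime.
        f = io.BytesIO()
        torch.onnx.export(model, inputs, f)
        sess = onnxruntime.InferenceSession(f.getvalue())
        ort_inputs = {i.name: t.numpy() for i, t in zip(sess.get_inputs(), inputs)}
        ort_out = sess.run(None, ort_inputs)[0]
        expected = model(*inputs)
        np.testing.assert_allclose(expected.detach().numpy(), ort_out, rtol=rtol, atol=atol)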

torch/csrc/jit/passes/onnx/helper.cpp

@@ -49,5 +49,15 @@ void buildParamsMapFromValueToParamsMap(
     paramsDict.insert(nameTensorParamPair.second);
   }
 }
+
+Node* addNodeToBlock(Block* block, Value* input, Symbol kind) {
+  auto new_node = block->appendNode(block->owningGraph()->create(kind));
+  auto new_input = new_node->addInput(input);
+  for (size_t i = 0; i < new_node->outputs().size(); i++) {
+    auto output = new_node->outputs()[i];
+    block->registerOutput(output);
+  }
+  return new_node;
+}
 } // namespace jit
 } // namespace torch

torch/csrc/jit/passes/onnx/helper.h

@@ -26,6 +26,7 @@ void eraseUnusedBlockInputs(Block* b);
 void buildParamsMapFromValueToParamsMap(
     const ValueToParamPairMap& valsToParamsMap,
     ParamMap& paramsDict);
+Node* addNodeToBlock(Block* block, Value* input, Symbol kind);
 
 } // namespace jit
 } // namespace torch

torch/csrc/jit/python/python_ir.cpp

@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/ir/alias_analysis.h>
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/passes/canonicalize.h>
+#include <torch/csrc/jit/passes/onnx/helper.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/python/pybind.h>
 #include <torch/csrc/jit/python/python_tracer.h>
@@ -13,7 +14,6 @@
 #include <torch/csrc/python_headers.h>
 #include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/utils/python_strings.h>
-#include <iostream>
 
 #include <sstream>
@@ -467,7 +467,10 @@ void initPythonIRBindings(PyObject* module_) {
           return py::make_iterator(b.outputs().begin(), b.outputs().end());
         })
         .def("returnNode", [](Block& b) { return b.return_node(); })
-        .def("paramNode", [](Block& b) { return b.param_node(); });
+        .def("paramNode", [](Block& b) { return b.param_node(); })
+        .def("addNode", [](Block& b, Value& input, const char* str) {
+          return addNodeToBlock(&b, &input, Symbol::fromQualString(str));
+        });
 
 #define NS(name) def(#name, &Node ::name)
   py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")

torch/onnx/symbolic_opset11.py

@@ -497,14 +497,21 @@ def size(g, self, dim=None):
 def squeeze(g, self, dim=None):
     if dim is None:
-        dims = []
-        for i, size in enumerate(self.type().sizes()):
-            if size == 1:
-                dims.append(i)
-    else:
-        dims = [sym_help._get_const(dim, 'i', 'dim')]
-    return g.op("Squeeze", self, axes_i=dims)
+        return g.op("Squeeze", self)
+
+    dim = sym_help._get_const(dim, 'i', 'dim')
+
+    # create 'cond' node (condition is shape[i]==1)
+    dim_constant = g.op("Constant", value_t=torch.tensor([dim]))
+    size = sym_help._size_helper(g, self, dim_constant)
+    const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
+    cond = g.op("Equal", size, const_one)
+
+    # create the 'If' node and add the 'then' and 'else' blocks to it.
+    if_node_outputs = g.op("If", cond)
+    if_node = if_node_outputs.node()
+    torch.onnx.utils._add_block(if_node, self, "onnx::Squeeze", axes_i=[dim])
+    torch.onnx.utils._add_block(if_node, self, "onnx::Identity")
+    return if_node_outputs
 
 @parse_args('v', 'i')
 def unsqueeze(g, self, dim):
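
With this change, the opset 11 export defers the decision to runtime: the If node's 'then' branch squeezes and its 'else' branch is an identity. A rough Python model of the exported graph's semantics (illustrative, not exporter code):

    import numpy as np

    def squeeze_if_semantics(x, dim):
        # If(shape(x)[dim] == 1): Squeeze(x, axes=[dim]); else: Identity(x)
        if x.shape[dim] == 1:
            return np.squeeze(x, axis=dim)
        return x

    assert squeeze_if_semantics(np.zeros((2, 1, 4)), 1).shape == (2, 4)
    assert squeeze_if_semantics(np.zeros((2, 2, 3)), 1).shape == (2, 2, 3)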

torch/onnx/symbolic_opset9.py

@@ -559,29 +559,42 @@ def select(g, self, dim, index):
 def squeeze(g, self, dim=None):
     if dim is None:
-        dims = []
-        for i, size in enumerate(self.type().sizes()):
-            if size == 1:
-                dims.append(i)
-    else:
-        dims = [sym_help._get_const(dim, 'i', 'dim')]
-        # Handle negative dims
-        for i, dim in enumerate(dims):
-            if dim < 0:
-                rank = self.type().dim()
-                if rank:
-                    warnings.warn("ONNX export squeeze with negative axis " + str(dim) +
-                                  " might cause the onnx model to be incorrect. " +
-                                  "Negative axis is not supported in ONNX. " +
-                                  "Axis is converted to " + str(dim + rank) +
-                                  " based on input shape at export time. " +
-                                  "Passing an tensor of different rank in execution will be incorrect.")
-                    dims[i] += rank
-                else:
-                    return _unimplemented('squeeze', 'negative axis with unknown input rank')
-    return g.op("Squeeze", self, axes_i=dims)
+        return g.op("Squeeze", self)
+
+    squeeze_dim = sym_help._get_const(dim, 'i', 'dim')
+    # Handle negative dims
+    if squeeze_dim < 0:
+        rank = self.type().dim()
+        if rank:
+            warnings.warn("ONNX export squeeze with negative axis " + str(squeeze_dim) +
+                          " might cause the onnx model to be incorrect. " +
+                          "Negative axis is not supported in ONNX. " +
+                          "Axis is converted to " + str(squeeze_dim + rank) +
+                          " based on input shape at export time. " +
+                          "Passing a tensor of different rank in execution will be incorrect.")
+            squeeze_dim += rank
+        else:
+            return _unimplemented('squeeze', 'negative axis with unknown input rank')
+
+    input_shape = self.type().sizes()
+    if input_shape is None:
+        warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + " on an input " +
+                      "with unknown shape. Note that if the size of dimension " + str(squeeze_dim) + " of the input " +
+                      "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " +
+                      "non-singleton dimensions, it is recommended to export this model using opset " +
+                      "version 11 or higher.")
+        return g.op("Squeeze", self, axes_i=[squeeze_dim])
+
+    if input_shape[squeeze_dim] > 1:
+        warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + ". The size of " +
+                      "this dimension in the given input is " + str(input_shape[squeeze_dim]) + ". The model will " +
+                      "be exported without the squeeze node. If the model is intended to be used with dynamic " +
+                      "input shapes, please use opset version 11 to " +
+                      "export the model.")
+        return self
+
+    warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + ". If the model is " +
+                  "intended to be used with dynamic input shapes, please use opset version 11 to export the model.")
+    return g.op("Squeeze", self, axes_i=[squeeze_dim])
 
 def prelu(g, self, weight):
     if self.isCompleteTensor():
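
By contrast, under opset 9 the choice is frozen at export time from the traced input's shape, so the same module can export to two different graphs; a sketch (output file names are illustrative):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.squeeze(x, dim=1)

    # dim 1 has size 1: a Squeeze node with axes_i=[1] is emitted (with a warning
    # about dynamic input shapes).
    torch.onnx.export(M(), torch.randn(2, 1, 4), "squeeze_static.onnx", opset_version=9)

    # dim 1 has size 2: the squeeze is dropped from the graph (with a warning), since
    # a Squeeze on a non-singleton axis would make ONNX Runtime reject every input.
    torch.onnx.export(M(), torch.randn(2, 2, 3), "squeeze_noop.onnx", opset_version=9)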

torch/onnx/utils.py

@@ -986,6 +986,12 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
             value_dict[x] = str(key) + '_dynamic_axes_' + str(i + 1)
         dynamic_axes[key] = value_dict
 
+
+def _add_block(node, input_node, op_name, **kwargs):
+    new_block = node.addBlock()
+    new_node = new_block.addNode(input_node, op_name)
+    for k, v in kwargs.items():
+        _add_attribute(new_node, k, v, False)
+
 torch._C.Graph.op = _graph_op
 torch._C.Graph.at = _graph_at
 torch._C.Graph.constant = _graph_constant
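
A quick way to inspect the resulting opset 11 structure (a sketch; exact node names vary across torch and onnx versions):

    import io
    import onnx
    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.squeeze(x, dim=1)

    f = io.BytesIO()
    torch.onnx.export(M(), torch.randn(2, 1, 4), f, opset_version=11,
                      input_names=['input'],
                      dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}})
    model = onnx.load_from_string(f.getvalue())
    # Expect a shape-size check feeding Equal -> If, with Squeeze in the 'then'
    # branch and Identity in the 'else' branch.
    print(onnx.helper.printable_graph(model.graph))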