mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Squeeze operator should give an error when trying to apply to a dimension with shape > 1 (#38476)
Summary: The ONNX spec for the Squeeze operator: > Remove single-dimensional entries from the shape of a tensor. Takes a parameter axes with a list of axes to squeeze. If axes is not provided, all the single dimensions will be removed from the shape. If an axis is selected with shape entry not equal to one, an error is raised. Currently, as explained in issue https://github.com/pytorch/pytorch/issues/36796, it is possible to export such a model to ONNX, and this results in an exception from ONNX runtime. Fixes https://github.com/pytorch/pytorch/issues/36796. Pull Request resolved: https://github.com/pytorch/pytorch/pull/38476 Reviewed By: hl475 Differential Revision: D22158024 Pulled By: houseroad fbshipit-source-id: bed625f3c626eabcbfb2ea83ec2f992963defa19
This commit is contained in:
parent
cd96dfd44b
commit
3c5e3966f4
|
|
@ -686,13 +686,64 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
|
||||
self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)
|
||||
|
||||
def test_squeeze(self):
|
||||
def squeeze_model_tests(self, d, x1, x2):
|
||||
class Squeeze(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.torch.squeeze(x, dim=-2)
|
||||
if d is not None:
|
||||
return torch.squeeze(x, dim=d)
|
||||
else:
|
||||
return torch.squeeze(x)
|
||||
|
||||
x2 = [] if x2 is None else [x2]
|
||||
self.run_test(Squeeze(), x1, input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}}, test_with_inputs=x2)
|
||||
|
||||
def test_squeeze_without_no_op(self):
|
||||
x = torch.randn(2, 1, 4)
|
||||
self.run_test(Squeeze(), x)
|
||||
self.squeeze_model_tests(1, x, None)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
def test_squeeze(self):
|
||||
x_squeeze = torch.randn(2, 1, 4)
|
||||
x_noop = torch.randn(2, 2, 3)
|
||||
self.squeeze_model_tests(1, x_squeeze, x_noop)
|
||||
|
||||
def test_squeeze_neg_without_no_op(self):
|
||||
x = torch.randn(2, 1, 4)
|
||||
self.squeeze_model_tests(-2, x, None)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
def test_squeeze_neg(self):
|
||||
x_squeeze = torch.randn(2, 1, 4)
|
||||
x_noop = torch.randn(2, 2, 3)
|
||||
self.squeeze_model_tests(-2, x_squeeze, x_noop)
|
||||
|
||||
def test_squeeze_all_dims(self):
|
||||
x_squeeze = torch.randn(2, 1, 4)
|
||||
x_noop = torch.randn(2, 2, 3)
|
||||
self.squeeze_model_tests(None, x_squeeze, x_noop)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
def test_squeeze_no_op(self):
|
||||
x_noop = torch.randn(2, 1, 4)
|
||||
x_squeeze = torch.randn(2, 2, 1)
|
||||
self.squeeze_model_tests(2, x_noop, x_squeeze)
|
||||
|
||||
def test_squeeze_no_op_without_additional_inputs(self):
|
||||
x_noop = torch.randn(2, 1, 4)
|
||||
self.squeeze_model_tests(2, x_noop, None)
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
def test_squeeze_runtime_dim(self):
|
||||
class Squeeze(torch.nn.Module):
|
||||
def forward(self, d1, d2):
|
||||
t = torch.zeros(d1[0], d2[0])
|
||||
return t.squeeze(0)
|
||||
|
||||
d1 = torch.tensor([1])
|
||||
d3 = torch.tensor([3])
|
||||
d4 = torch.tensor([4])
|
||||
self.run_test(Squeeze(), (d1, d4), test_with_inputs=[(d3, d4)])
|
||||
self.run_test(Squeeze(), (d3, d4), test_with_inputs=[(d1, d3)])
|
||||
|
||||
def test_unsqueeze(self):
|
||||
class Unsqueeze(torch.nn.Module):
|
||||
|
|
@ -1568,6 +1619,22 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
x = torch.randn(3, 4, 5, requires_grad=True)
|
||||
self.run_test(IndexCopyModel(), x)
|
||||
|
||||
def test_select(self):
|
||||
class Select(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x[:, 1]
|
||||
|
||||
x = torch.randn(3, 4)
|
||||
self.run_test(Select(), x)
|
||||
|
||||
def test_select_negative_index(self):
|
||||
class Select(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x[:, -1]
|
||||
|
||||
x = torch.randn(3, 4)
|
||||
self.run_test(Select(), x)
|
||||
|
||||
# TODO: enable for opset 10 when ONNXRuntime version will be updated
|
||||
|
||||
def test_index_select_constant_scaler_index(self):
|
||||
|
|
|
|||
|
|
@ -49,5 +49,15 @@ void buildParamsMapFromValueToParamsMap(
|
|||
paramsDict.insert(nameTensorParamPair.second);
|
||||
}
|
||||
}
|
||||
|
||||
Node* addNodeToBlock(Block* block, Value* input, Symbol kind) {
|
||||
auto new_node = block->appendNode(block->owningGraph()->create(kind));
|
||||
auto new_input = new_node->addInput(input);
|
||||
for (size_t i = 0; i < new_node->outputs().size(); i++) {
|
||||
auto output = new_node->outputs()[i];
|
||||
block->registerOutput(output);
|
||||
}
|
||||
return new_node;
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ void eraseUnusedBlockInputs(Block* b);
|
|||
void buildParamsMapFromValueToParamsMap(
|
||||
const ValueToParamPairMap& valsToParamsMap,
|
||||
ParamMap& paramsDict);
|
||||
Node* addNodeToBlock(Block* block, Value* input, Symbol kind);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include <torch/csrc/jit/ir/alias_analysis.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/passes/canonicalize.h>
|
||||
#include <torch/csrc/jit/passes/onnx/helper.h>
|
||||
#include <torch/csrc/jit/passes/shape_analysis.h>
|
||||
#include <torch/csrc/jit/python/pybind.h>
|
||||
#include <torch/csrc/jit/python/python_tracer.h>
|
||||
|
|
@ -13,7 +14,6 @@
|
|||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
|
|
@ -467,7 +467,10 @@ void initPythonIRBindings(PyObject* module_) {
|
|||
return py::make_iterator(b.outputs().begin(), b.outputs().end());
|
||||
})
|
||||
.def("returnNode", [](Block& b) { return b.return_node(); })
|
||||
.def("paramNode", [](Block& b) { return b.param_node(); });
|
||||
.def("paramNode", [](Block& b) { return b.param_node(); })
|
||||
.def("addNode", [](Block& b, Value& input, const char* str) {
|
||||
return addNodeToBlock(&b, &input, Symbol::fromQualString(str));
|
||||
});
|
||||
|
||||
#define NS(name) def(#name, &Node ::name)
|
||||
py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")
|
||||
|
|
|
|||
|
|
@ -497,14 +497,21 @@ def size(g, self, dim=None):
|
|||
|
||||
def squeeze(g, self, dim=None):
|
||||
if dim is None:
|
||||
dims = []
|
||||
for i, size in enumerate(self.type().sizes()):
|
||||
if size == 1:
|
||||
dims.append(i)
|
||||
else:
|
||||
dims = [sym_help._get_const(dim, 'i', 'dim')]
|
||||
return g.op("Squeeze", self, axes_i=dims)
|
||||
return g.op("Squeeze", self)
|
||||
|
||||
dim = sym_help._get_const(dim, 'i', 'dim')
|
||||
|
||||
# create 'cond' node (condition is shape[i]==1)
|
||||
dim_constant = g.op("Constant", value_t=torch.tensor([dim]))
|
||||
size = sym_help._size_helper(g, self, dim_constant)
|
||||
const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
|
||||
cond = g.op("Equal", size, const_one)
|
||||
# create the 'If' node and add the 'then' and 'else' blocks to it.
|
||||
if_node_outputs = g.op("If", cond)
|
||||
if_node = if_node_outputs.node()
|
||||
torch.onnx.utils._add_block(if_node, self, "onnx::Squeeze", axes_i=[dim])
|
||||
torch.onnx.utils._add_block(if_node, self, "onnx::Identity")
|
||||
return if_node_outputs
|
||||
|
||||
@parse_args('v', 'i')
|
||||
def unsqueeze(g, self, dim):
|
||||
|
|
|
|||
|
|
@ -559,29 +559,42 @@ def select(g, self, dim, index):
|
|||
|
||||
def squeeze(g, self, dim=None):
|
||||
if dim is None:
|
||||
dims = []
|
||||
for i, size in enumerate(self.type().sizes()):
|
||||
if size == 1:
|
||||
dims.append(i)
|
||||
else:
|
||||
dims = [sym_help._get_const(dim, 'i', 'dim')]
|
||||
# Handle negative dims
|
||||
for i, dim in enumerate(dims):
|
||||
if dim < 0:
|
||||
rank = self.type().dim()
|
||||
if rank:
|
||||
warnings.warn("ONNX export squeeze with negative axis " + str(dim) +
|
||||
" might cause the onnx model to be incorrect. " +
|
||||
"Negative axis is not supported in ONNX. " +
|
||||
"Axis is converted to " + str(dim + rank) +
|
||||
" based on input shape at export time. " +
|
||||
"Passing an tensor of different rank in execution will be incorrect.")
|
||||
dims[i] += rank
|
||||
else:
|
||||
return _unimplemented('squeeze', 'negative axis with unknown input rank')
|
||||
return g.op("Squeeze", self)
|
||||
|
||||
return g.op("Squeeze", self, axes_i=dims)
|
||||
squeeze_dim = sym_help._get_const(dim, 'i', 'dim')
|
||||
# Handle negative dims
|
||||
if squeeze_dim < 0:
|
||||
rank = self.type().dim()
|
||||
if rank:
|
||||
warnings.warn("ONNX export squeeze with negative axis " + str(squeeze_dim) +
|
||||
" might cause the onnx model to be incorrect. " +
|
||||
"Negative axis is not supported in ONNX. " +
|
||||
"Axis is converted to " + str(squeeze_dim + rank) +
|
||||
" based on input shape at export time. " +
|
||||
"Passing an tensor of different rank in execution will be incorrect.")
|
||||
squeeze_dim += rank
|
||||
else:
|
||||
return _unimplemented('squeeze', 'negative axis with unknown input rank')
|
||||
|
||||
input_shape = self.type().sizes()
|
||||
if input_shape is None:
|
||||
warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + " on an input " +
|
||||
"with unknown shape. Note that if the size of dimension " + str(squeeze_dim) + " of the input " +
|
||||
"is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " +
|
||||
"non-singleton dimensions, it is recommended to export this model using opset " +
|
||||
"version 11 or higher.")
|
||||
return g.op("Squeeze", self, axes_i=[squeeze_dim])
|
||||
if input_shape[squeeze_dim] > 1:
|
||||
warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + ". The size of " +
|
||||
"this dimension in the given input is " + str(input_shape[squeeze_dim]) + ". The model will " +
|
||||
"be exported without the squeeze node. If the model is intended to be used with dynamic " +
|
||||
"input shapes, please use opset version 11 to " +
|
||||
"export the model.")
|
||||
return self
|
||||
|
||||
warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + ". If the model is " +
|
||||
"intended to be used with dynamic input shapes, please use opset version 11 to export the model.")
|
||||
return g.op("Squeeze", self, axes_i=[squeeze_dim])
|
||||
|
||||
def prelu(g, self, weight):
|
||||
if self.isCompleteTensor():
|
||||
|
|
|
|||
|
|
@ -986,6 +986,12 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
|
|||
value_dict[x] = str(key) + '_dynamic_axes_' + str(i + 1)
|
||||
dynamic_axes[key] = value_dict
|
||||
|
||||
def _add_block(node, input_node, op_name, **kwargs):
|
||||
new_block = node.addBlock()
|
||||
new_node = new_block.addNode(input_node, op_name)
|
||||
for k, v in kwargs.items():
|
||||
_add_attribute(new_node, k, v, False)
|
||||
|
||||
torch._C.Graph.op = _graph_op
|
||||
torch._C.Graph.at = _graph_at
|
||||
torch._C.Graph.constant = _graph_constant
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user