mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix: convert Onnx DynamicSlice operator with 4 inputs to caffe2 fa… (#20846)
Summary: I reported issue https://github.com/pytorch/pytorch/issues/20743 and made this pull request to fix it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20846 Reviewed By: zrphercule Differential Revision: D15569135 Pulled By: houseroad fbshipit-source-id: 96a2c818ef666a7d79b96decfa347d7154b34d5c
This commit is contained in:
parent
4b1df5c1f5
commit
34536e207a
|
|
@ -1081,7 +1081,7 @@ Caffe2Ops Caffe2Backend::CreateDynamicSlice(
|
|||
// Axes tensor will be used to populate the fully-specified starts and ends
|
||||
// arguments to the caffe2 Slice operator.
|
||||
std::string axes_tensor;
|
||||
if (onnx_node->node.input_size() > 2) {
|
||||
if (onnx_node->node.input_size() > 3) {
|
||||
axes_tensor = onnx_node->node.input(3);
|
||||
} else {
|
||||
axes_tensor = dummy_->NewDummyName();
|
||||
|
|
|
|||
|
|
@ -1018,9 +1018,9 @@ output: [ 4 6 8 10 12 14 16]
|
|||
.Input(
|
||||
0,
|
||||
"start",
|
||||
"(*Tensor*): [OPTIONAL] scalar tensor containing the start of the interval (inclusive) (default=0)")
|
||||
.Input(1, "stop", "(*Tensor*): scalar tensor containing the end of the interval (exclusive)")
|
||||
.Input(2, "step", "(*Tensor*): [OPTIONAL] scalar tensor specifying the spacing between values (default=1)")
|
||||
"(*Tensor*): [OPTIONAL] scalar or 1-element tensor containing the start of the interval (inclusive) (default=0)")
|
||||
.Input(1, "stop", "(*Tensor*): scalar or 1-element tensor containing the end of the interval (exclusive)")
|
||||
.Input(2, "step", "(*Tensor*): [OPTIONAL] scalar or 1-element tensor specifying the spacing between values (default=1)")
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
|
|
|
|||
|
|
@ -1376,7 +1376,7 @@ class RangeOp : public Operator<Context> {
|
|||
T step = 1;
|
||||
|
||||
for (int i = 0; i < InputSize(); ++i) {
|
||||
CAFFE_ENFORCE_EQ(Input(0).dim(), 0, "All inputs must be scalar.");
|
||||
CAFFE_ENFORCE_EQ(Input(i).numel(), 1, "All inputs must be scalar/1D tensor.");
|
||||
}
|
||||
|
||||
switch (InputSize()) {
|
||||
|
|
|
|||
|
|
@ -49,6 +49,46 @@ class TestCaffe2Basic(TestCase):
|
|||
"Don't know how to map unexpected argument (foo|bar)"):
|
||||
b2.convert_node(bad_node_def.SerializeToString())
|
||||
|
||||
def test_dynamicslice_3inputs_graph(self):
|
||||
node_def = make_node(
|
||||
"DynamicSlice", ["X1", "X2", "X3"], ["Y"])
|
||||
|
||||
graph_def = make_graph(
|
||||
[node_def],
|
||||
name="test",
|
||||
inputs=[make_tensor_value_info("X1", onnx.TensorProto.FLOAT, (2, 4)),
|
||||
make_tensor_value_info("X2", onnx.TensorProto.INT32, (1, 2)),
|
||||
make_tensor_value_info("X3", onnx.TensorProto.INT32, (1, 2))],
|
||||
outputs=[make_tensor_value_info("Y", onnx.TensorProto.FLOAT, (1, 2))])
|
||||
model_def = make_model(graph_def, producer_name='caffe2-ref-test')
|
||||
|
||||
x = [[1,2,3,4],[5,6,7,8]]
|
||||
start = [0, 0]
|
||||
end = [-1, 4]
|
||||
prepared = c2.prepare(model_def)
|
||||
output = prepared.run(inputs=[np.array(x), np.array(start), np.array(end)])
|
||||
self.assertSameOutputs(output[0], np.array(x)[0:-1, 0:4])
|
||||
|
||||
def test_dynamicslice_4inputs_graph(self):
|
||||
node_def = make_node(
|
||||
"DynamicSlice", ["X1", "X2", "X3", "axes"], ["Y"])
|
||||
graph_def = make_graph(
|
||||
[node_def],
|
||||
name="test",
|
||||
inputs=[make_tensor_value_info("X1", onnx.TensorProto.FLOAT, (2, 4)),
|
||||
make_tensor_value_info("X2", onnx.TensorProto.INT32, (1, 2)),
|
||||
make_tensor_value_info("X3", onnx.TensorProto.INT32, (1, 2)),
|
||||
make_tensor_value_info("axes", onnx.TensorProto.INT32, (1, 2))],
|
||||
outputs=[make_tensor_value_info("Y", onnx.TensorProto.FLOAT, (1, 2))])
|
||||
model_def = make_model(graph_def, producer_name='caffe2-ref-test')
|
||||
x = [[1,2,3,4],[5,6,7,8]]
|
||||
start = [0, 1]
|
||||
end = [4, 5]
|
||||
axes = [1, 0]
|
||||
prepared = c2.prepare(model_def)
|
||||
output = prepared.run(inputs=[np.array(x), np.array(start), np.array(end), np.array(axes)])
|
||||
self.assertSameOutputs(output[0], np.array(x)[1:5, 0:4])
|
||||
|
||||
def test_relu_graph(self):
|
||||
X = np.random.randn(3, 2).astype(np.float32)
|
||||
Y_ref = np.clip(X, 0, np.inf)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user