ONNX Export Narrow op

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17550

Differential Revision: D14350401

Pulled By: houseroad

fbshipit-source-id: 4d88079bb7a8bbd270b0272009826eb3b202cc33
This commit is contained in:
Lara Haidar 2019-03-06 22:35:12 -08:00 committed by Facebook Github Bot
parent 3230404645
commit 3dba1285ab
4 changed files with 78 additions and 0 deletions

View File

@ -0,0 +1,61 @@
ir_version: 4
producer_name: "pytorch"
producer_version: "1.1"
graph {
node {
input: "0"
output: "1"
op_type: "Slice"
attribute {
name: "axes"
ints: 0
type: INTS
}
attribute {
name: "ends"
ints: 2
type: INTS
}
attribute {
name: "starts"
ints: 0
type: INTS
}
}
name: "torch-jit-export"
input {
name: "0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 10
}

View File

@ -410,6 +410,10 @@ class TestOperators(TestCase):
x = torch.rand(3, 4, requires_grad=True)
self.assertONNX(lambda x: x[:, 1:2], x)
def test_narrow(self):
x = torch.randn(3, 3, requires_grad=True)
self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x)
def test_atan(self):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: x.atan(), x)

View File

@ -1107,6 +1107,14 @@ class TestCaffe2Backend(unittest.TestCase):
x = torch.randn(2, 3, requires_grad=True)
self.run_model_test(ReshapeAsModel(), train=False, input=x, batch_size=BATCH_SIZE)
def test_narrow(self):
class NarrowModel(torch.nn.Module):
def forward(self, input):
return torch.narrow(input, 0, 0, 2)
x = torch.randn(3, 3, requires_grad=True)
self.run_model_test(NarrowModel(), train=False, input=x, batch_size=BATCH_SIZE)
# a bit of metaprogramming to set up all the rnn tests

View File

@ -1626,6 +1626,11 @@ def nonzero(g, input):
return g.op('NonZero', input)
@parse_args('v', 'i', 'i', 'i')
def narrow(g, input, dim, start, length):
return g.op("Slice", input, axes_i=[dim], starts_i=[start], ends_i=[start + length])
@parse_args('v', 'i', 'i')
def _argmax(g, input, dim, keepdim):
return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim)