mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
ONNX Export Narrow op
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17550 Differential Revision: D14350401 Pulled By: houseroad fbshipit-source-id: 4d88079bb7a8bbd270b0272009826eb3b202cc33
This commit is contained in:
parent
3230404645
commit
3dba1285ab
61
test/onnx/expect/TestOperators.test_narrow.expect
Normal file
61
test/onnx/expect/TestOperators.test_narrow.expect
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.1"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
output: "1"
|
||||
op_type: "Slice"
|
||||
attribute {
|
||||
name: "axes"
|
||||
ints: 0
|
||||
type: INTS
|
||||
}
|
||||
attribute {
|
||||
name: "ends"
|
||||
ints: 2
|
||||
type: INTS
|
||||
}
|
||||
attribute {
|
||||
name: "starts"
|
||||
ints: 0
|
||||
type: INTS
|
||||
}
|
||||
}
|
||||
name: "torch-jit-export"
|
||||
input {
|
||||
name: "0"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "1"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 10
|
||||
}
|
||||
|
|
@ -410,6 +410,10 @@ class TestOperators(TestCase):
|
|||
x = torch.rand(3, 4, requires_grad=True)
|
||||
self.assertONNX(lambda x: x[:, 1:2], x)
|
||||
|
||||
def test_narrow(self):
|
||||
x = torch.randn(3, 3, requires_grad=True)
|
||||
self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x)
|
||||
|
||||
def test_atan(self):
|
||||
x = torch.randn(3, 4, requires_grad=True)
|
||||
self.assertONNX(lambda x: x.atan(), x)
|
||||
|
|
|
|||
|
|
@ -1107,6 +1107,14 @@ class TestCaffe2Backend(unittest.TestCase):
|
|||
x = torch.randn(2, 3, requires_grad=True)
|
||||
self.run_model_test(ReshapeAsModel(), train=False, input=x, batch_size=BATCH_SIZE)
|
||||
|
||||
def test_narrow(self):
|
||||
class NarrowModel(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
return torch.narrow(input, 0, 0, 2)
|
||||
|
||||
x = torch.randn(3, 3, requires_grad=True)
|
||||
self.run_model_test(NarrowModel(), train=False, input=x, batch_size=BATCH_SIZE)
|
||||
|
||||
# a bit of metaprogramming to set up all the rnn tests
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1626,6 +1626,11 @@ def nonzero(g, input):
|
|||
return g.op('NonZero', input)
|
||||
|
||||
|
||||
@parse_args('v', 'i', 'i', 'i')
|
||||
def narrow(g, input, dim, start, length):
|
||||
return g.op("Slice", input, axes_i=[dim], starts_i=[start], ends_i=[start + length])
|
||||
|
||||
|
||||
@parse_args('v', 'i', 'i')
|
||||
def _argmax(g, input, dim, keepdim):
|
||||
return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user