[ONNX] Fix gather squeeze axis in constant folding (#63588) (#64379)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64379

* Fix gather squeeze axis in constant folding

* mypy

* fix indent

* address comments

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D30919604

Pulled By: malfet

fbshipit-source-id: 90edb054491433a0da2fe82324ac7c12f1ef062b
This commit is contained in:
BowenBao 2021-09-30 21:05:46 -07:00 committed by Facebook GitHub Bot
parent 41bdfe3919
commit 7e15f2ddaa
2 changed files with 16 additions and 2 deletions

View File

@ -3571,7 +3571,6 @@ class TestONNXRuntime(unittest.TestCase):
self.run_test(GatherModel(), input=(input, indices))
@disableScriptTest() # RuntimeError: Python type cannot be used as a value
@skipIfUnsupportedMinOpsetVersion(11)
def test_gather_constant_fold(self):
class GatherModule(torch.nn.Module):
def __init__(self):
@ -3602,6 +3601,21 @@ class TestONNXRuntime(unittest.TestCase):
x = torch.randn(1, 3, 2)
self.run_test(GatherModule(), (x,))
class GatherModule(torch.nn.Module):
def __init__(self):
super(GatherModule, self).__init__()
self.register_buffer("rb", torch.randn(1, 1, 3, 1, 1))
def forward(self, x):
x += self.rb[0]
return x
x = torch.randn(1, 3, 224, 224)
self.run_test(GatherModule(), (x,),
dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"},
"output": {0: "batch", 1: "class", 2: "height", 3: "width"}},
input_names=['input'], output_names=['output'])
@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(9)
def test_expand(self):

View File

@ -476,7 +476,7 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
// If rank of indices is 0, rank of output tensor should be
// rank_of_input - 1.
if (q < 1) {
updated_val = updated_val.squeeze();
updated_val = updated_val.squeeze(axis);
}
return c10::optional<at::Tensor>(updated_val);
} else if (node->kind() == onnx::Range) {