mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64379 * Fix gather squeeze axis in constant folding * mypy * fix indent * address comments Test Plan: Imported from OSS Reviewed By: jansel Differential Revision: D30919604 Pulled By: malfet fbshipit-source-id: 90edb054491433a0da2fe82324ac7c12f1ef062b
This commit is contained in:
parent
41bdfe3919
commit
7e15f2ddaa
|
|
@ -3571,7 +3571,6 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
self.run_test(GatherModel(), input=(input, indices))
|
||||
|
||||
@disableScriptTest() # RuntimeError: Python type cannot be used as a value
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
def test_gather_constant_fold(self):
|
||||
class GatherModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -3602,6 +3601,21 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
x = torch.randn(1, 3, 2)
|
||||
self.run_test(GatherModule(), (x,))
|
||||
|
||||
class GatherModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(GatherModule, self).__init__()
|
||||
self.register_buffer("rb", torch.randn(1, 1, 3, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x += self.rb[0]
|
||||
return x
|
||||
|
||||
x = torch.randn(1, 3, 224, 224)
|
||||
self.run_test(GatherModule(), (x,),
|
||||
dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"},
|
||||
"output": {0: "batch", 1: "class", 2: "height", 3: "width"}},
|
||||
input_names=['input'], output_names=['output'])
|
||||
|
||||
@skipIfUnsupportedOpsetVersion([13])
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
|
||||
def test_expand(self):
|
||||
|
|
|
|||
|
|
@ -476,7 +476,7 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
|
|||
// If rank of indices is 0, rank of output tensor should be
|
||||
// rank_of_input - 1.
|
||||
if (q < 1) {
|
||||
updated_val = updated_val.squeeze();
|
||||
updated_val = updated_val.squeeze(axis);
|
||||
}
|
||||
return c10::optional<at::Tensor>(updated_val);
|
||||
} else if (node->kind() == onnx::Range) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user