[ONNX] Improve Expand shape inference (#69264)

Extend shape inference support for `Expand`, when value of argument `shape` is unknown. Infer the rank of the output of `Expand`, and set shape to dynamic, if shape of argument `shape` is known.

Without this, shape inference aborts, and falls back to the static shape provided by tracer, which is incorrect in many cases.

Co-authored-by: BowenBao <bowbaomicrosoft.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72985
This commit is contained in:
BowenBao 2022-02-16 18:00:59 -08:00 committed by PyTorch MergeBot
parent 08510ba5e4
commit a6517c20cf
2 changed files with 28 additions and 0 deletions

View File

@ -114,5 +114,14 @@ class TestONNXShapeInference(unittest.TestCase):
slice = g.op("Slice", input, start_input, end, axis, step)
self.run_test(g, slice.node(), expect_tensor(None, shape=(None, None)))
def test_expand(self):
g = self.create_empty_graph()
input = g.addInput()
constant = self.insert_tensor_constant(g, torch.ones(2, 4))
input.setType(constant.type().with_sizes([None, None]))
shape = g.op("Shape", input)
expand = g.op("Expand", constant, shape)
self.run_test(g, expand.node(), expect_tensor("Float", shape=(None, None)))
if __name__ == '__main__':
unittest.main()

View File

@ -1374,6 +1374,8 @@ void ComputeConstant(Node* n, int opset_version) {
if (input0_shape_size.has_value()) {
auto input0_shape_value = input0_shape_size.value();
if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
// When value of `shape` is statically known,
// output shape can be computed.
auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
n->input(1)->debugName());
auto final_shape =
@ -1381,6 +1383,23 @@ void ComputeConstant(Node* n, int opset_version) {
if (final_shape.has_value()) {
UpdateShape(n->output(), final_shape.value());
}
} else if (
auto expand_shape =
ConstantValueMap::GetShapeInto1DInt64VectorWithOneUnknown(
n->input(1)->debugName())) {
// When shape of `shape` is statically known,
// output rank can be computed.
TORCH_INTERNAL_ASSERT(
expand_shape.value().size() == 1,
"`Shape` input to `Expand` should be a 1-D tensor. Instead got rank ",
expand_shape.value().size());
if (expand_shape.value()[0] > 0) {
std::vector<c10::ShapeSymbol> final_shape;
for (const auto i : c10::irange(expand_shape.value()[0])) {
final_shape.emplace_back(c10::ShapeSymbol::newSymbol());
}
UpdateShape(n->output(), c10::SymbolicShape(final_shape));
}
}
}
}