mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Improve Expand shape inference (#69264)
Extend shape inference support for `Expand`, when value of argument `shape` is unknown. Infer the rank of the output of `Expand`, and set shape to dynamic, if shape of argument `shape` is known. Without this, shape inference aborts, and falls back to the static shape provided by tracer, which is incorrect in many cases. Co-authored-by: BowenBao <bowbaomicrosoft.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/72985
This commit is contained in:
parent
08510ba5e4
commit
a6517c20cf
|
|
@ -114,5 +114,14 @@ class TestONNXShapeInference(unittest.TestCase):
|
|||
slice = g.op("Slice", input, start_input, end, axis, step)
|
||||
self.run_test(g, slice.node(), expect_tensor(None, shape=(None, None)))
|
||||
|
||||
def test_expand(self):
|
||||
g = self.create_empty_graph()
|
||||
input = g.addInput()
|
||||
constant = self.insert_tensor_constant(g, torch.ones(2, 4))
|
||||
input.setType(constant.type().with_sizes([None, None]))
|
||||
shape = g.op("Shape", input)
|
||||
expand = g.op("Expand", constant, shape)
|
||||
self.run_test(g, expand.node(), expect_tensor("Float", shape=(None, None)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1374,6 +1374,8 @@ void ComputeConstant(Node* n, int opset_version) {
|
|||
if (input0_shape_size.has_value()) {
|
||||
auto input0_shape_value = input0_shape_size.value();
|
||||
if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
|
||||
// When value of `shape` is statically known,
|
||||
// output shape can be computed.
|
||||
auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
|
||||
n->input(1)->debugName());
|
||||
auto final_shape =
|
||||
|
|
@ -1381,6 +1383,23 @@ void ComputeConstant(Node* n, int opset_version) {
|
|||
if (final_shape.has_value()) {
|
||||
UpdateShape(n->output(), final_shape.value());
|
||||
}
|
||||
} else if (
|
||||
auto expand_shape =
|
||||
ConstantValueMap::GetShapeInto1DInt64VectorWithOneUnknown(
|
||||
n->input(1)->debugName())) {
|
||||
// When shape of `shape` is statically known,
|
||||
// output rank can be computed.
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
expand_shape.value().size() == 1,
|
||||
"`Shape` input to `Expand` should be a 1-D tensor. Instead got rank ",
|
||||
expand_shape.value().size());
|
||||
if (expand_shape.value()[0] > 0) {
|
||||
std::vector<c10::ShapeSymbol> final_shape;
|
||||
for (const auto i : c10::irange(expand_shape.value()[0])) {
|
||||
final_shape.emplace_back(c10::ShapeSymbol::newSymbol());
|
||||
}
|
||||
UpdateShape(n->output(), c10::SymbolicShape(final_shape));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user