pytorch/torch/csrc/jit/python
neginraoof 1de3525ca8 [ONNX] Handle PackedParams inputs for _propagate_and_assign_input_shapes (#56449) (#57079)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57079

While testing the ONNX 1.9 release, we found that an old bug is triggered by the Caffe2 test:
`pytest test/onnx/test_pytorch_onnx_caffe2_quantized.py::TestQuantizedOps::test_small_model`
This is because the graph inputs
```python
graph(%x.1 : Tensor,
      %conv1._packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase,
      %conv2._packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase,
      %fc.bias : Float(10, strides=[1], requires_grad=0, device=cpu),
      %fc.weight : Float(10, 72, strides=[72, 1], requires_grad=0, device=cpu)):
```
contain `Conv2dPackedParamsBase`, which is a PackedParams type.
When the inputs are flattened, each PackedParams input expands into several tensors, so shape inference for the inputs that follow it becomes misaligned.
This PR records how many tensors each PackedParams input was flattened into and advances by that number rather than by 1, after which the unit test passes.
Note that the tuple case still follows the original logic.
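The index-alignment idea can be illustrated with a small Python sketch. This is not the actual fix, which lives in C++ in `script_init.cpp`. The names `GraphInput`, `num_flattened`, and `assign_input_shapes` are hypothetical and are not PyTorch APIs. The sketch also assumes, purely for illustration, that each packed-params input flattens into two tensors, and it does not model the tuple case.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

Shape = Tuple[int, ...]


@dataclass
class GraphInput:
    """Hypothetical stand-in for a graph input."""
    name: str
    # How many flattened tensors this input expands into:
    # 1 for a plain Tensor, more than 1 for a packed-params object.
    num_flattened: int = 1
    shape: Optional[Shape] = None


def assign_input_shapes(inputs: List[GraphInput],
                        flat_shapes: List[Shape]) -> List[GraphInput]:
    """Map shapes from a flattened list back onto graph inputs.

    The cursor advances by each input's flattened count instead of
    always by 1, so inputs after a packed-params object stay aligned.
    """
    cursor = 0
    for inp in inputs:
        if cursor + inp.num_flattened > len(flat_shapes):
            raise ValueError(f"not enough flattened shapes for {inp.name}")
        if inp.num_flattened == 1:
            # Plain tensor input: it owns exactly one flattened slot.
            inp.shape = flat_shapes[cursor]
        # Packed params get no single shape; their slots are skipped.
        cursor += inp.num_flattened
    if cursor != len(flat_shapes):
        raise ValueError("flattened shapes do not match graph inputs")
    return inputs


# Inputs modeled on the graph above (packed-params counts are illustrative).
graph_inputs = [
    GraphInput("x.1"),
    GraphInput("conv1._packed_params", num_flattened=2),
    GraphInput("conv2._packed_params", num_flattened=2),
    GraphInput("fc.bias"),
    GraphInput("fc.weight"),
]
flat = [(1, 3, 10, 10), (4, 3, 3, 3), (4,), (8, 4, 3, 3), (8,), (10,), (10, 72)]
for inp in assign_input_shapes(graph_inputs, flat):
    print(inp.name, inp.shape)
```

With a fixed step of 1, `fc.bias` would instead pick up one of the convolution tensors' shapes. That is the kind of misalignment described above.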

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D28393949

Pulled By: malfet

fbshipit-source-id: 98d48aad27e5ca03fb10d260f8e625478d996ee2

Co-authored-by: David <jiafa@microsoft.com>
2021-05-12 15:20:26 -07:00
init.cpp Add pybind type caster for c10::Device (#57292) 2021-05-01 16:11:10 -07:00
init.h
module_python.h
pybind_utils.cpp Revert D27448156: irange for size_t 2021-04-03 19:14:00 -07:00
pybind_utils.h Pass reference to parent future in callbacks (#57635) 2021-05-07 03:59:18 -07:00
pybind.h Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_arg_flatten.cpp Replace all direct cdata access with THPVariable_Unpack (#55799) 2021-04-15 08:57:04 -07:00
python_arg_flatten.h
python_custom_class.cpp Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_custom_class.h
python_interpreter.cpp Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_ir.cpp Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_ir.h Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_ivalue.h Make DataPtr extraction in CUDAFuture faster for Python values (#56918) 2021-05-06 01:12:53 -07:00
python_sugared_value.cpp Add cuda device synchronization support in JIT (#55469) 2021-04-14 09:13:07 -07:00
python_sugared_value.h Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_tracer.cpp [Usability] Capture argument names for traced functions and modules (#51775) 2021-02-10 18:28:08 -08:00
python_tracer.h [Usability] Capture argument names for traced functions and modules (#51775) 2021-02-10 18:28:08 -08:00
python_tree_views.cpp Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_tree_views.h
script_init.cpp [ONNX] Handle PackedParams inputs for _propagate_and_assign_input_shapes (#56449) (#57079) 2021-05-12 15:20:26 -07:00
script_init.h
update_graph_executor_opt.cpp Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
update_graph_executor_opt.h