pytorch/torch/csrc/jit/python
BowenBao 3f9c803fe8 [ONNX] Redesign onnx pass to enable shape type dependent pattern conversion - cont (#51795) (#53304)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53304

With the introduction of ONNX shape inference, shape and type are inferred on the fly as operators are converted from ATen to ONNX while running symbolic functions. This satisfies the shape/type requirements of the symbolic functions. The pre-ONNX passes, however, cannot benefit from shape inference, because at that stage the operators in the graph are still ATen operators.

This PR updates the design of the ONNX pass to add a mechanism for capturing subgraphs of ATen operators that match certain patterns and converting them later, once shape/type information for the upstream operators is available.

Under the new design, pre-ONNX passes that need shape/type information must be written in two parts: encapsulation and conversion.

    The encapsulation part finds the nodes that match a pattern, the same way pre-ONNX passes were written previously. Instead of converting those nodes, however, it encapsulates them in a sub-block of a new placeholder node. This part is called before the ONNX pass, so it runs before any symbolic functions are called.

    The conversion part is called inside the ONNX pass. In the ONNX pass, run_symbolic_func is called for each node in topological order. When it reaches a placeholder node, the conversion part is invoked and converts the nodes inside the sub-block according to the pattern. By that point, the shape/type of the upstream operators is available. After the conversion is complete, the placeholder node is removed and the nodes inside its sub-block are converted. run_symbolic_func is then called for these nodes, converting them from ATen operators to ONNX operators.
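The two-part flow described above can be sketched with a toy graph model. This is a minimal illustration under stated assumptions, not PyTorch's implementation: `Node`, `SYMBOLIC`, `encapsulate`, `convert_flatten_pattern`, `run_symbolic`, and `onnx_pass` are all hypothetical, "shape" is reduced to tensor rank, and `aten::flatten` (assumed to be `flatten(x, 1)`) stands in for an arbitrary shape-dependent pattern.

```python
# Toy model of "encapsulate now, convert later". All names are hypothetical
# illustrations; none of them are PyTorch APIs.
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple


@dataclass
class Node:
    kind: str                                           # e.g. "aten::relu", "prim::Placeholder"
    inputs: List[str]                                   # input value names
    outputs: List[str]                                  # output value names
    block: List["Node"] = field(default_factory=list)  # sub-block (placeholders only)


# Hypothetical symbolic table: ATen kind -> (ONNX kind, output-rank function).
SYMBOLIC: Dict[str, Tuple[str, Callable[[int], int]]] = {
    "aten::relu": ("onnx::Relu", lambda r: r),
    "aten::clone": ("onnx::Identity", lambda r: r),
    "aten::flatten": ("onnx::Flatten", lambda r: 2),  # assumes flatten(x, 1), rank >= 2
}


def encapsulate(graph: List[Node], pattern: str) -> List[Node]:
    """Part 1 (before the ONNX pass): wrap each matching node in a placeholder."""
    out = []
    for n in graph:
        if n.kind == pattern:
            out.append(Node("prim::Placeholder", n.inputs, n.outputs, block=[n]))
        else:
            out.append(n)
    return out


def run_symbolic(n: Node, ranks: Dict[str, int]) -> Node:
    """Convert one ATen node to ONNX and infer its output rank on the fly."""
    onnx_kind, rank_fn = SYMBOLIC[n.kind]
    ranks[n.outputs[0]] = rank_fn(ranks[n.inputs[0]])
    return Node(onnx_kind, n.inputs, n.outputs)


def convert_flatten_pattern(block: List[Node], ranks: Dict[str, int]) -> List[Node]:
    """Part 2: pattern conversion, which may read upstream shape information."""
    converted = []
    for n in block:
        if ranks[n.inputs[0]] == 2:  # flatten(x, 1) of a rank-2 tensor is a no-op
            converted.append(Node("aten::clone", n.inputs, n.outputs))
        else:
            converted.append(n)
    return converted


def onnx_pass(graph: List[Node], input_ranks: Dict[str, int]) -> List[Node]:
    ranks = dict(input_ranks)
    out = []
    for n in graph:  # graph is assumed to be in topological order
        if n.kind == "prim::Placeholder":
            # Upstream nodes are already converted, so their ranks are known here.
            for m in convert_flatten_pattern(n.block, ranks):
                out.append(run_symbolic(m, ranks))
            # The placeholder itself is dropped from the output graph.
        else:
            out.append(run_symbolic(n, ranks))
    return out


g = encapsulate([Node("aten::relu", ["x"], ["a"]),
                 Node("aten::flatten", ["a"], ["b"])], "aten::flatten")
print([n.kind for n in g])                         # ['aten::relu', 'prim::Placeholder']
print([n.kind for n in onnx_pass(g, {"x": 2})])    # ['onnx::Relu', 'onnx::Identity']
print([n.kind for n in onnx_pass(g, {"x": 4})])    # ['onnx::Relu', 'onnx::Flatten']
```

The point of the split shows in the last two calls: the same encapsulated graph converts differently depending on the rank that only becomes known once the upstream `aten::relu` has gone through its symbolic function.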

This PR includes several other fixes, listed below.
* ~~replace helper.cpp with onnx_utils.cpp for holding utility functions.~~
* fix EraseNumberTypes on Bool type; the code was outdated, written before the Bool type existed.
* ~~enable onnx shape inference in export with parameter/initializer data.~~
* other code cleanups.
* fix insertion of Identity nodes for Loop sequence outputs in opset 13.

~~PR depends on #51603~~

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D26922417

Pulled By: malfet

fbshipit-source-id: 14ed06158d539e2451c2e5e63ba1b32fb0f75095
2021-03-11 10:30:09 -08:00
init.cpp [ONNX] Redesign onnx pass to enable shape type dependent pattern conversion - cont (#51795) (#53304) 2021-03-11 10:30:09 -08:00
init.h
module_python.h
pybind_utils.cpp Graceful invalidation of Python Node/Value/Block when C++ object is deleted (#50326) 2021-02-04 01:34:46 -08:00
pybind_utils.h [WIP][FX] Fix tracing support for torchbind (#52884) 2021-03-05 23:40:16 -08:00
pybind.h Graceful invalidation of Python Node/Value/Block when C++ object is deleted (#50326) 2021-02-04 01:34:46 -08:00
python_arg_flatten.cpp [ONNX] ONNX dev branch merge 01-06-2021 (#50163) 2021-01-13 13:51:21 -08:00
python_arg_flatten.h
python_custom_class.cpp [JIT] Add static method support for TorchBind (#51177) 2021-02-13 19:41:27 -08:00
python_custom_class.h
python_interpreter.cpp
python_ir.cpp Context manager for hiding source ranges (#53188) 2021-03-04 09:11:08 -08:00
python_ir.h
python_ivalue.h Remove DataPtr extractor from CUDAFuture (#48840) 2020-12-19 11:03:45 -08:00
python_sugared_value.cpp [JIT] Update Namespace from cuda to _cuda (#53378) 2021-03-11 00:52:01 -08:00
python_sugared_value.h [JIT] Enable ModuleList non-literal indexing (#53410) 2021-03-09 16:11:34 -08:00
python_tracer.cpp [Usability] Capture argument names for traced functions and modules (#51775) 2021-02-10 18:28:08 -08:00
python_tracer.h [Usability] Capture argument names for traced functions and modules (#51775) 2021-02-10 18:28:08 -08:00
python_tree_views.cpp Add dict comprehension (#47774) 2020-12-17 15:25:30 -08:00
python_tree_views.h
script_init.cpp Remove notion of "level" from Module::dump_to_str. (#52539) 2021-03-09 05:45:57 -08:00
script_init.h
update_graph_executor_opt.cpp
update_graph_executor_opt.h