pytorch/torch/csrc/jit/passes/onnx/helper.h
BowenBao a6c8730045 [ONNX] Add preprocess pass for onnx export (#41832)
Summary:
In `_jit_pass_onnx`, a symbolic function is called to convert each node. However, some nodes cannot be converted without additional context. For example, the number of outputs of a split (and whether that number is static or dynamic) is unknown until the list is unpacked by a listUnpack node. This pass preprocesses the graph and prepares such nodes so that the symbolic function receives enough context.
* After preprocessing, `_jit_pass_onnx` should have enough context to produce valid ONNX nodes, instead of half-baked nodes that rely on fixes from later post-passes.
* `_jit_pass_onnx_peephole` should be a pass that does ONNX specific optimizations instead of ONNX specific fixes.
* Producing more valid ONNX nodes in `_jit_pass_onnx` enables better use of ONNX shape inference; see https://github.com/pytorch/pytorch/issues/40628.
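
The split/listUnpack case above can be illustrated with a toy graph. This is only a minimal sketch under stated assumptions: `ToyNode` and `annotateSplitOutputs` are hypothetical names invented for illustration, not part of torch::jit, and the real pass may prepare nodes differently. It shows the general idea of copying context from the consumer (the unpack) onto the producer (the split) before per-node conversion runs.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a JIT graph node; NOT the real torch::jit::Node.
struct ToyNode {
  std::string kind;             // e.g. "aten::split" or "prim::ListUnpack"
  int input = -1;               // index of the producing node, -1 if none
  std::size_t num_outputs = 1;  // number of outputs this node produces
  int static_outputs = -1;      // context added by preprocessing, -1 = unknown
};

// For every split whose result feeds a ListUnpack, copy the unpack's output
// count onto the split, so that a later node-by-node conversion can read it
// from the split itself without looking ahead in the graph.
inline void annotateSplitOutputs(std::vector<ToyNode>& graph) {
  for (const ToyNode& user : graph) {
    if (user.kind != "prim::ListUnpack" || user.input < 0) {
      continue;
    }
    const std::size_t producer_index = static_cast<std::size_t>(user.input);
    if (producer_index >= graph.size()) {
      continue;  // malformed toy graph, ignore
    }
    ToyNode& producer = graph[producer_index];
    if (producer.kind == "aten::split") {
      producer.static_outputs = static_cast<int>(user.num_outputs);
    }
  }
}
```

After `annotateSplitOutputs` runs, a converter that visits the split in isolation can tell how many outputs to emit, which is the kind of context the summary says the symbolic function is otherwise missing.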

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41832

Reviewed By: ZolotukhinM

Differential Revision: D22968334

Pulled By: bzinodev

fbshipit-source-id: 8226f03c5b29968e8197d242ca8e620c6e1d42a5
2020-08-06 20:34:12 -07:00

28 lines
695 B
C++

#pragma once
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <map>
#include <memory>
#include <string>
#include <utility>
namespace torch {
namespace jit {
namespace onnx {
static const int OPSET_VERSION_1 = 1;
static const int OPSET_VERSION_9 = 9;
static const int OPSET_VERSION_10 = 10;
static const int OPSET_VERSION_11 = 11;
static const int OPSET_VERSION_12 = 12;
} // namespace onnx
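// Maps a graph Value to the (name, value) pair of the parameter it refers to.
using ValueToParamPairMap = std::map<Value*, std::pair<std::string, IValue>>;
// Maps a parameter name to its value.
using ParamMap = std::map<std::string, IValue>;
// Builds the Value-to-parameter map for block `b` from `paramsDict`.
ValueToParamPairMap buildValueToParamsMap(Block* b, const ParamMap& paramsDict);
// Erases unused values from `valsToParamsMap` in place.
void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap);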
} // namespace jit
} // namespace torch
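
The header only declares `buildValueToParamsMap` and `eraseUnusedValuesFromMap`, so the following is a sketch of one plausible behavior inferred from their names and signatures, not the actual implementation. All types and names here (`ToyValue`, `ToyParamMap`, `buildToyValueToParamsMap`, `eraseUnusedToyValues`) are hypothetical stand-ins: a `double` replaces `IValue`, a plain vector of inputs replaces `Block*`, and matching inputs to parameters by name is an assumption.

```cpp
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for torch::jit::Value.
struct ToyValue {
  std::string name;   // assumed to be the key used to look up a parameter
  int use_count = 0;  // number of nodes consuming this value
};

// Stand-ins for ParamMap and ValueToParamPairMap, with double replacing IValue.
using ToyParamMap = std::map<std::string, double>;
using ToyValueToParamPairMap =
    std::map<ToyValue*, std::pair<std::string, double>>;

// Pairs each block input whose name appears in `params` with that parameter.
inline ToyValueToParamPairMap buildToyValueToParamsMap(
    const std::vector<ToyValue*>& blockInputs,
    const ToyParamMap& params) {
  ToyValueToParamPairMap result;
  for (ToyValue* input : blockInputs) {
    ToyParamMap::const_iterator it = params.find(input->name);
    if (it != params.end()) {
      result.emplace(input, std::make_pair(it->first, it->second));
    }
  }
  return result;
}

// Drops entries whose value has no remaining uses. std::map::erase returns
// the iterator following the erased element, so the loop stays valid.
inline void eraseUnusedToyValues(ToyValueToParamPairMap& valsToParams) {
  for (auto it = valsToParams.begin(); it != valsToParams.end();) {
    if (it->first->use_count == 0) {
      it = valsToParams.erase(it);
    } else {
      ++it;
    }
  }
}
```

Keying the map by pointer mirrors the `Value*` key in `ValueToParamPairMap`: it ties each parameter to one specific graph value rather than to a name that could be reused.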