Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/41606

The previous diff (D22220798 (59294fbbb9) and D22220797) was recently reverted (D22492356 (28291d3cf8), D22492355) because of a bug associated with the op AsyncIf. The AsyncIf op has net_defs as args, and the SSA rewriting didn't take that into account: it has a special path for the op If, but not for AsyncIf.

Several changes I made to fix the bug:
1) Add op AsyncIf to the special path for the If op in SSA rewriting
2) Clear inputs/outputs of the netdefs that are args in If/AsyncIf ops, because they're no longer valid
3) Revert renamed inputs/outputs in the arg netdefs that are in the external_outputs of the parent netdef

Both 2) and 3) are existing bugs in the `SsaRewrite` function that were just never exposed before. The algorithm for `RemoveOpsByType` is the same as in my previous diff D22220798 (59294fbbb9). The only new changes in this diff are in `onnx::SsaRewrite` and a few newly added unit tests.

(Note: this ignores all push blocking failures!)

Reviewed By: yinghai

Differential Revision: D22588652

fbshipit-source-id: ebb68ecd1662ea2bae14d4be8f61a75cd8b7e3e6
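The commit message refers to SSA rewriting of a NetDef. As a rough illustration only (this is not the code of `onnx::SsaRewrite`), the sketch below renames every blob write in a flat list of ops to a fresh versioned name, points each read at the latest version, and then restores the original names of the final versions of external outputs, similar in spirit to step 3) of the fix. `ToyOp`, `ToySsaRewrite`, and the `<name>_<n>` naming scheme are hypothetical, and subnets passed as op arguments (the If/AsyncIf case) are deliberately not handled.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for an operator in a NetDef: only blob names matter here.
struct ToyOp {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Toy SSA rewrite over a flat list of ops (no If/AsyncIf subnets).
// Assumes no original blob name already has the form "<name>_<n>".
std::vector<ToyOp> ToySsaRewrite(std::vector<ToyOp> ops,
                                 const std::set<std::string>& external_outputs) {
  std::map<std::string, int> version;  // latest version written per blob
  auto latest = [&version](const std::string& name) {
    auto it = version.find(name);
    return it == version.end() ? name : name + "_" + std::to_string(it->second);
  };
  for (ToyOp& op : ops) {
    // Reads refer to the most recent write, or keep the original name if the
    // blob was never written (e.g. external inputs and parameters).
    for (std::string& in : op.inputs) in = latest(in);
    // Each write creates a fresh, uniquely named version.
    for (std::string& out : op.outputs) {
      int v = ++version[out];
      out += "_" + std::to_string(v);
    }
  }
  // Map the final version of each external output back to its original name
  // so callers can still fetch it by the name they asked for.
  std::map<std::string, std::string> restore;
  for (const std::string& name : external_outputs) {
    restore[latest(name)] = name;
  }
  for (ToyOp& op : ops) {
    for (std::string& in : op.inputs) {
      auto it = restore.find(in);
      if (it != restore.end()) in = it->second;
    }
    for (std::string& out : op.outputs) {
      auto it = restore.find(out);
      if (it != restore.end()) out = it->second;
    }
  }
  return ops;
}
```

For example, with ops `FC(data, w) -> x`, `Relu(x) -> x`, `Copy(x) -> y` and external output `y`, the sketch yields `FC -> x_1`, `Relu(x_1) -> x_2`, `Copy(x_2) -> y`.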
27 lines
833 B
C++
```cpp
#pragma once

#include <memory>
#include <string>
#include <vector>

#include "caffe2/core/workspace.h"

namespace caffe2 {

/**
 * This struct stores information about the inference graph that defines the
 * underlying math of BlackBoxPredictor. Other parts of the predictor, such as
 * its various threading optimizations, don't belong here.
 */
struct InferenceGraph {
  std::unique_ptr<NetDef> predict_init_net_def;
  // Holding the NetDef in a shared_ptr lets it be shared with its operators on
  // each of the threads without memory replication. Note that
  // predict_init_net_def could be stored by value, as its operators are
  // discarded immediately after use (via RunNetOnce).
  std::shared_ptr<NetDef> predict_net_def;

  std::vector<std::string> input_names;
  std::vector<std::string> output_names;
  std::vector<std::string> parameter_names;

  // Whether predict_net_def has already been rewritten into SSA form.
  bool predictor_net_ssa_rewritten{false};
};
} // namespace caffe2
```
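The comments in `InferenceGraph` describe an ownership split: the init net is owned uniquely because it is used once, while the predict net sits behind a `shared_ptr` so that every per-thread predictor can reference a single copy. A minimal sketch of that pattern follows, assuming hypothetical stand-ins `ToyNetDef`, `ToyInferenceGraph`, `ToyWorker`, `RunInitNetOnce`, and `MakeWorkers` in place of the real Caffe2 types; it is not how BlackBoxPredictor is implemented, and the workers are plain objects rather than real threads to keep it small.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for caffe2::NetDef: just an ordered list of op types.
struct ToyNetDef {
  std::vector<std::string> op_types;
};

// Mirrors the ownership layout of InferenceGraph (other fields omitted).
struct ToyInferenceGraph {
  std::unique_ptr<ToyNetDef> predict_init_net_def;  // used once at startup
  std::shared_ptr<ToyNetDef> predict_net_def;       // shared by every worker
};

// Hypothetical per-thread predictor state: it references the shared predict
// net instead of holding its own copy.
struct ToyWorker {
  std::shared_ptr<const ToyNetDef> net;
};

// Hypothetical stand-in for running the init net once (e.g. via RunNetOnce):
// takes ownership, "runs" it by reporting its op count, then lets it die.
std::size_t RunInitNetOnce(ToyInferenceGraph& graph) {
  std::unique_ptr<ToyNetDef> init = std::move(graph.predict_init_net_def);
  return init ? init->op_types.size() : 0;
}  // `init` is destroyed here; the graph no longer holds the init net

std::vector<ToyWorker> MakeWorkers(const ToyInferenceGraph& graph, int n) {
  std::vector<ToyWorker> workers;
  for (int i = 0; i < n; ++i) {
    // Copying the shared_ptr bumps a reference count; the NetDef is not copied.
    workers.push_back(ToyWorker{graph.predict_net_def});
  }
  return workers;
}
```

Because `RunInitNetOnce` moves the init net out of the graph, it is freed as soon as it has run, matching the header's remark that its operators are discarded immediately after use. The predict net, by contrast, stays alive for as long as any worker still references it, and adding a worker costs one reference-count increment rather than a deep copy.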