mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46219

- Refactor StaticRuntime and group common data structures, the jit graph, and the script module into a separate struct `InferenceModule`:

```
struct InferenceModule {
  explicit InferenceModule(const torch::jit::Module& m);
  explicit InferenceModule(std::shared_ptr<torch::jit::Graph> g);

  torch::jit::Module module;
  std::shared_ptr<torch::jit::Graph> graph;
  std::unique_ptr<c10::FunctionSchema> schema;

  std::unordered_map<Value*, size_t> value_to_reg;
  std::vector<size_t> input_regs; // inputs to the graph
  std::vector<size_t> output_regs; // outputs of the graph
  std::vector<size_t> internals;
};
```

which is stored in the PyTorchPredictor, as well as the static runtime, and shared across threads. Then this is what's left inside the Static Runtime:

```
mutable std::vector<IValue> reg_;
// The nodes we need to run
std::vector<ProcessedNode> nodes_;
```

`reg_` holds all the weights and activations, which differ across threads at run time. `nodes_` holds the op nodes and input/output registers, and is the same across threads for now. We could potentially put other stateful data structures in it, so I kept it inside the static runtime. It could easily be moved into the `InferenceModule` if we decide not to put anything else into `ProcessedNode`.

- Added StaticRuntimeOptions so we can toggle certain optimizations on/off, for testing and benchmarking. `cleanup_activations` is an example.

- Integration with PyTorchPredictor. Added a lock-free stack in the PyTorchPredictor to hold all the static runtime instances. A benchmark shows that the `push` and `pop` combo takes about 80 ns, which is quite acceptable.

This diff focuses on the threading model only. Benchmarks will be separate.

Reviewed By: bwasti

Differential Revision: D24237078

fbshipit-source-id: fd0d6347f02b4526ac17dec1f731db48424bade1
| Name |
|---|
| CMakeLists.txt |
| deep_wide_pt_bench.cc |
| deep_wide_pt.cc |
| deep_wide_pt.h |
| test_static_runtime.cc |
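The threading model in the commit summary (one immutable structure shared by all threads, plus per-instance mutable registers handed out from a pool) can be illustrated with a minimal standalone sketch. It does not use any PyTorch API: `SharedModel`, `RuntimeOptions`, `Runtime`, `RuntimePool`, and `infer` are hypothetical names invented here as stand-ins for `InferenceModule`, `StaticRuntimeOptions`, the static runtime, and the predictor's pool. The pool below is guarded by a `std::mutex` for simplicity; it is not the lock-free stack the commit describes, and the dot-product "inference" is only a toy workload.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

// Hypothetical stand-in for InferenceModule: never modified after
// construction, so a single instance can be shared by every thread.
struct SharedModel {
  std::vector<double> weights;  // read-only model data
  std::size_t num_regs;         // registers each runtime allocates (>= 2)
};

// Hypothetical stand-in for StaticRuntimeOptions.
struct RuntimeOptions {
  bool cleanup_activations = true;
};

// Hypothetical stand-in for the static runtime: a pointer to the shared
// model plus a private, mutable register file (the analogue of reg_).
class Runtime {
 public:
  Runtime(std::shared_ptr<const SharedModel> model, RuntimeOptions opts)
      : model_(std::move(model)), opts_(opts), reg_(model_->num_regs, 0.0) {
    assert(reg_.size() >= 2);  // the toy workload uses registers 0 and 1
  }

  // Toy workload: dot product of the input with the shared weights.
  double run(const std::vector<double>& input) {
    reg_[0] = 0.0;
    for (std::size_t i = 0; i < input.size() && i < model_->weights.size(); ++i) {
      reg_[1] = input[i] * model_->weights[i];  // intermediate activation
      reg_[0] += reg_[1];
    }
    const double out = reg_[0];
    if (opts_.cleanup_activations) {
      for (double& r : reg_) r = 0.0;  // reset registers between runs
    }
    return out;
  }

  const std::vector<double>& registers() const { return reg_; }

 private:
  std::shared_ptr<const SharedModel> model_;
  RuntimeOptions opts_;
  std::vector<double> reg_;
};

// Pool of idle runtimes. A mutex-guarded vector is used here instead of a
// lock-free stack; only the push/pop usage pattern is the same.
class RuntimePool {
 public:
  RuntimePool(std::shared_ptr<const SharedModel> model, RuntimeOptions opts)
      : model_(std::move(model)), opts_(opts) {}

  // Reuse an idle runtime if one exists, otherwise create a new one.
  std::unique_ptr<Runtime> pop() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      if (!idle_.empty()) {
        std::unique_ptr<Runtime> rt = std::move(idle_.back());
        idle_.pop_back();
        return rt;
      }
    }
    return std::make_unique<Runtime>(model_, opts_);
  }

  // Return a runtime to the pool once the caller is done with it.
  void push(std::unique_ptr<Runtime> rt) {
    std::lock_guard<std::mutex> lock(mu_);
    idle_.push_back(std::move(rt));
  }

  std::size_t idle() const {
    std::lock_guard<std::mutex> lock(mu_);
    return idle_.size();
  }

 private:
  std::shared_ptr<const SharedModel> model_;
  RuntimeOptions opts_;
  mutable std::mutex mu_;
  std::vector<std::unique_ptr<Runtime>> idle_;
};

// One request: borrow a runtime, run it, hand it back.
double infer(RuntimePool& pool, const std::vector<double>& input) {
  std::unique_ptr<Runtime> rt = pool.pop();
  const double out = rt->run(input);
  pool.push(std::move(rt));
  return out;
}
```

In this sketch each request thread would call `infer` on a shared `RuntimePool`. Because a `Runtime` is owned by exactly one caller between `pop` and `push`, its registers need no further synchronization, while the `SharedModel` is only ever read.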