pytorch/benchmarks/static_runtime
Hao Lu 1a3ea46dbf [StaticRuntime] Threading model (#46219)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46219

- Refactor StaticRuntime and group common data structures, the jit graph, and the script module into a separate struct `InferenceModule`:
```
struct InferenceModule {
  explicit InferenceModule(const torch::jit::Module& m);
  explicit InferenceModule(std::shared_ptr<torch::jit::Graph> g);
  torch::jit::Module module;
  std::shared_ptr<torch::jit::Graph> graph;
  std::unique_ptr<c10::FunctionSchema> schema;

  std::unordered_map<Value*, size_t> value_to_reg;
  std::vector<size_t> input_regs; // inputs to the graph
  std::vector<size_t> output_regs; // outputs of the graph
  std::vector<size_t> internals;
};
```
This struct is stored in both the PyTorchPredictor and the static runtime, and is shared across threads. What remains inside the static runtime is:
```
  mutable std::vector<IValue> reg_;
  // The nodes we need to run
  std::vector<ProcessedNode> nodes_;
```
`reg_` holds all the weights and activations, so its contents differ across threads at run time. `nodes_` holds the op nodes and their input/output registers, and is the same across threads for now. We could potentially put other stateful data structures in it, so I kept it inside the static runtime. It could easily be moved into `InferenceModule` if we decide not to put anything else into `ProcessedNode`.
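
The split described above can be illustrated with a minimal, self-contained sketch. It does not use the actual PyTorch classes: `SharedModule`, `Runtime`, and `run_on_threads` are hypothetical stand-ins, and the "graph" is a toy that doubles its input. The point is only that the read-only structure is shared by all threads while each runtime owns its own registers.

```cpp
#include <cstddef>
#include <memory>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical stand-in for InferenceModule: read-only after construction,
// so it can be shared across threads without synchronization.
struct SharedModule {
  std::size_t num_regs;                  // total number of registers
  std::vector<std::size_t> input_regs;   // register indices of graph inputs
  std::vector<std::size_t> output_regs;  // register indices of graph outputs
};

// Hypothetical stand-in for StaticRuntime: owns the mutable registers.
class Runtime {
 public:
  explicit Runtime(std::shared_ptr<const SharedModule> m)
      : module_(std::move(m)), reg_(module_->num_regs, 0) {}

  // Toy "graph": output = 2 * input. Only this instance's registers change.
  int run(int input) const {
    reg_[module_->input_regs[0]] = input;
    reg_[module_->output_regs[0]] = 2 * reg_[module_->input_regs[0]];
    return reg_[module_->output_regs[0]];
  }

 private:
  std::shared_ptr<const SharedModule> module_;  // shared, never mutated
  mutable std::vector<int> reg_;                // per-instance state
};

// Each thread builds its own Runtime on top of the same SharedModule.
std::vector<int> run_on_threads(int num_threads) {
  std::shared_ptr<const SharedModule> module =
      std::make_shared<SharedModule>(SharedModule{2, {0}, {1}});
  std::vector<int> results(num_threads, 0);
  std::vector<std::thread> threads;
  for (int t = 0; t < num_threads; ++t) {
    threads.emplace_back([&results, module, t] {
      Runtime rt(module);
      results[t] = rt.run(t);  // each thread writes a distinct element
    });
  }
  for (auto& th : threads) th.join();
  return results;
}
```

Because `SharedModule` is only ever read, the threads need no locking; all writes go to registers owned by a single `Runtime`.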

- Added StaticRuntimeOptions so we can toggle certain optimizations on/off, for testing and benchmarking. `cleanup_activations` is an example.
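
As a rough illustration of how such a toggle might be consumed, here is a sketch with a hypothetical `ToyOptions` struct and `ToyRuntime` class. Only the field name `cleanup_activations` comes from this commit message; its default value and the cleanup behavior shown here are assumptions made for the example.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical options struct; the default value is an assumption.
struct ToyOptions {
  bool cleanup_activations = true;
};

// Hypothetical runtime with three registers: input, intermediate, output.
class ToyRuntime {
 public:
  explicit ToyRuntime(ToyOptions opts) : opts_(opts), reg_(3) {}

  float run(float x) {
    reg_[0] = {x};                  // input
    reg_[1] = {reg_[0][0] + 1.0f};  // intermediate activation
    reg_[2] = {reg_[1][0] * 2.0f};  // output
    const float out = reg_[2][0];
    if (opts_.cleanup_activations) {
      // Release activation memory between runs.
      for (auto& r : reg_) {
        r.clear();
        r.shrink_to_fit();
      }
    }
    return out;
  }

  // Number of registers that still hold data after the last run.
  std::size_t live_registers() const {
    std::size_t n = 0;
    for (const auto& r : reg_) {
      if (!r.empty()) ++n;
    }
    return n;
  }

 private:
  ToyOptions opts_;
  std::vector<std::vector<float>> reg_;
};
```

With the flag off, the registers keep their contents after `run`, which makes it possible to compare memory use and timing with and without the optimization.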

- Integration with PyTorchPredictor. Added a lockfree stack in the PyTorchPredictor to hold all the static runtime instances. Benchmark shows that the `push` and `pop` combo takes about 80 ns, which is quite acceptable.
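
The check-out / check-in pattern can be sketched as follows. Note that this sketch deliberately swaps in a mutex-guarded vector for the lockfree stack, to keep the example short and obviously correct; it is not the implementation in this diff. `PooledRuntime`, `RuntimePool`, and `predict` are hypothetical names.

```cpp
#include <cstddef>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

// Hypothetical runtime instance; stands in for a static runtime.
struct PooledRuntime {
  int runs = 0;
  int run(int x) {
    ++runs;
    return x + 1;
  }
};

// Mutex-guarded stand-in for the lockfree stack of runtime instances.
class RuntimePool {
 public:
  void push(std::unique_ptr<PooledRuntime> rt) {
    std::lock_guard<std::mutex> guard(mu_);
    free_.push_back(std::move(rt));
  }

  // Returns nullptr when no idle instance is available.
  std::unique_ptr<PooledRuntime> pop() {
    std::lock_guard<std::mutex> guard(mu_);
    if (free_.empty()) return nullptr;
    std::unique_ptr<PooledRuntime> rt = std::move(free_.back());
    free_.pop_back();
    return rt;
  }

  std::size_t size() const {
    std::lock_guard<std::mutex> guard(mu_);
    return free_.size();
  }

 private:
  mutable std::mutex mu_;
  std::vector<std::unique_ptr<PooledRuntime>> free_;
};

// One request: take an idle instance (or create one), run, and return it.
int predict(RuntimePool& pool, int x) {
  std::unique_ptr<PooledRuntime> rt = pool.pop();
  if (!rt) rt.reset(new PooledRuntime());
  const int out = rt->run(x);
  pool.push(std::move(rt));
  return out;
}
```

Each request holds an instance exclusively between `pop` and `push`, so the per-instance registers are never touched by two threads at once; the pool only grows when all existing instances are busy.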

This diff focuses on threading model only. Benchmarks will be separate.

Reviewed By: bwasti

Differential Revision: D24237078

fbshipit-source-id: fd0d6347f02b4526ac17dec1f731db48424bade1
2020-10-20 14:37:30 -07:00
CMakeLists.txt [static runtime] Add _out variants and reuse memory (#44128) 2020-09-25 11:03:06 -07:00
deep_wide_pt_bench.cc [StaticRuntime] Threading model (#46219) 2020-10-20 14:37:30 -07:00
deep_wide_pt.cc [jit][static] Basic executor (#43647) 2020-08-28 23:20:07 -07:00
deep_wide_pt.h [StaticRuntime] Add a 'speed of light' benchmark. (#46308) 2020-10-19 23:35:55 -07:00
test_static_runtime.cc [StaticRuntime] Threading model (#46219) 2020-10-20 14:37:30 -07:00