Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46219
- Refactor StaticRuntime and group common data structures, the jit graph, and the script module into a separate struct `InferenceModule`:
```
struct InferenceModule {
explicit InferenceModule(const torch::jit::Module& m);
explicit InferenceModule(std::shared_ptr<torch::jit::Graph> g);
torch::jit::Module module;
std::shared_ptr<torch::jit::Graph> graph;
std::unique_ptr<c10::FunctionSchema> schema;
std::unordered_map<Value*, size_t> value_to_reg;
std::vector<size_t> input_regs; // inputs to the graph
std::vector<size_t> output_regs; // outputs of the graph
std::vector<size_t> internals;
};
```
which is stored in the PyTorchPredictor, as well as the static runtime, and shared across threads. Then this is what's left inside the Static Runtime:
```
mutable std::vector<IValue> reg_;
// The nodes we need to run
std::vector<ProcessedNode> nodes_;
```
`reg_` holds all the weights and activations, which is different across threads during running. `nodes_` holds the op nodes and input/output registers, and is the same across threads for now. We could potentially put other stateful data structures in it, so I kept it inside the static runtime. It could be easily moved into the `InferenceModule` if we decide not to anything else into `ProcessedNode`.
- Added StaticRuntimeOptions so we can toggle certain optimizations on/off, for testing and benchmarking. `cleanup_activations` is an example.
- Integration with PyTorchPredictor. Added a lockfree stack in the PyTorchPredictor to hold all the static runtime instances. Benchmark shows that the `push` and `pop` combo takes about 80 ns, which is quite acceptable.
This diff focuses on threading model only. Benchmarks will be separate.
Reviewed By: bwasti
Differential Revision: D24237078
fbshipit-source-id: fd0d6347f02b4526ac17dec1f731db48424bade1
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46308
This PR adds a hand optimized version of DeepAndWide model with the goal
of estimating overheads of static runtime. While static runtime is
currently much faster than the existing JIT interpreter, it would be
useful to understand how close we are to an absolutely 0-overhead
system. Currently, this "ideal" implementation is 2x faster than the
static runtime on batchsize=1.
Full benchmark results:
```
Running build/bin/static_runtime_bench
Run on (24 X 2394.71 MHz CPU s)
CPU Caches:
L1 Data 32K (x24)
L1 Instruction 32K (x24)
L2 Unified 4096K (x24)
L3 Unified 16384K (x24)
------------------------------------------------------------------------------
Benchmark Time CPU Iterations
------------------------------------------------------------------------------
BM_deep_wide_base/1 59518 ns 59500 ns 10909
BM_deep_wide_base/8 74635 ns 74632 ns 9317
BM_deep_wide_base/20 82186 ns 82147 ns 9119
BM_deep_wide_fast/1 13851 ns 13851 ns 49825 << new
BM_deep_wide_fast/8 22497 ns 22497 ns 32089 << new
BM_deep_wide_fast/20 23868 ns 23841 ns 31184 << new
BM_deep_wide_jit_graph_executor/1 62786 ns 62786 ns 10835
BM_deep_wide_jit_graph_executor/8 76730 ns 76718 ns 7529
BM_deep_wide_jit_graph_executor/20 78886 ns 78883 ns 8769
BM_deep_wide_jit_profiling_executor/1 69504 ns 69490 ns 10309
BM_deep_wide_jit_profiling_executor/8 75718 ns 75715 ns 9199
BM_deep_wide_jit_profiling_executor/20 75364 ns 75364 ns 9010
BM_deep_wide_static/1 40324 ns 40318 ns 17232
BM_deep_wide_static/8 50327 ns 50319 ns 13335
BM_deep_wide_static/20 53075 ns 53071 ns 12855
BM_deep_wide_static_threaded/threads:8 6258 ns 49873 ns 14008
```
PS: The implementation could probably be optimized even more.
Differential Revision: D24300702
Test Plan: Imported from OSS
Reviewed By: dzhulgakov
Pulled By: ZolotukhinM
fbshipit-source-id: 7870bdef127c39d11bcaa4f03a60eb80a46be58e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43647
Nothing fancy, just a basic implementation of the graph executor without using stack machine.
Reviewed By: bwasti
Differential Revision: D23208413
fbshipit-source-id: e483bb6ad7ba8591bbe1767e669654d82f42c356