Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46308

This PR adds a hand-optimized version of the DeepAndWide model with the goal of estimating the overhead of the static runtime. While the static runtime is currently much faster than the existing JIT interpreter, it is useful to understand how close we are to a truly zero-overhead system. Currently, this "ideal" implementation is 2x faster than the static runtime at batch size 1.

Full benchmark results:
```
Running build/bin/static_runtime_bench
Run on (24 X 2394.71 MHz CPU s)
CPU Caches:
  L1 Data 32K (x24)
  L1 Instruction 32K (x24)
  L2 Unified 4096K (x24)
  L3 Unified 16384K (x24)
------------------------------------------------------------------------------
Benchmark                                     Time           CPU   Iterations
------------------------------------------------------------------------------
BM_deep_wide_base/1                       59518 ns      59500 ns        10909
BM_deep_wide_base/8                       74635 ns      74632 ns         9317
BM_deep_wide_base/20                      82186 ns      82147 ns         9119
BM_deep_wide_fast/1                       13851 ns      13851 ns        49825  << new
BM_deep_wide_fast/8                       22497 ns      22497 ns        32089  << new
BM_deep_wide_fast/20                      23868 ns      23841 ns        31184  << new
BM_deep_wide_jit_graph_executor/1         62786 ns      62786 ns        10835
BM_deep_wide_jit_graph_executor/8         76730 ns      76718 ns         7529
BM_deep_wide_jit_graph_executor/20        78886 ns      78883 ns         8769
BM_deep_wide_jit_profiling_executor/1     69504 ns      69490 ns        10309
BM_deep_wide_jit_profiling_executor/8     75718 ns      75715 ns         9199
BM_deep_wide_jit_profiling_executor/20    75364 ns      75364 ns         9010
BM_deep_wide_static/1                     40324 ns      40318 ns        17232
BM_deep_wide_static/8                     50327 ns      50319 ns        13335
BM_deep_wide_static/20                    53075 ns      53071 ns        12855
BM_deep_wide_static_threaded/threads:8     6258 ns      49873 ns        14008
```
PS: The implementation could probably be optimized even more.

Differential Revision: D24300702

Test Plan: Imported from OSS

Reviewed By: dzhulgakov

Pulled By: ZolotukhinM

fbshipit-source-id: 7870bdef127c39d11bcaa4f03a60eb80a46be58e
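For orientation, here is a minimal sketch of how a StaticRuntime is typically constructed and driven, based on the constructors and run() overloads in the file below. The model path and input shapes are placeholders, not taken from the benchmark source:
```
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/script.h>

int main() {
  // Load a scripted model (hypothetical path).
  torch::jit::Module module = torch::jit::load("deep_wide.pt");

  // The Module constructor freezes a copy of the module, lowers its forward
  // graph, and pre-assigns a register to every value (see impl.cpp below).
  torch::jit::StaticRuntime runtime(module);

  // Placeholder inputs; the real benchmark feeds batched embedding tensors.
  std::vector<at::Tensor> inputs = {
      torch::randn({1, 1, 32}),
      torch::randn({1, 1, 32}),
      torch::randn({1, 50})};

  // run() copies the inputs into registers and executes each node in order.
  std::vector<at::Tensor> outputs = runtime.run(inputs);
  return 0;
}
```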
343 lines
10 KiB
C++
#include <torch/csrc/jit/runtime/static/impl.h>

#include <ATen/core/interned_strings.h>
#include <caffe2/core/timer.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

#include <algorithm>
#include <iomanip>
#include <iostream>

namespace torch {
namespace jit {

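// PrepareForStaticRuntime (graph overload): lower a TorchScript graph into
// the form the static runtime expects. Inlining, constant propagation, and
// mutation removal leave a flat, functional graph whose values can each be
// bound to a fixed register.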
std::shared_ptr<torch::jit::Graph> PrepareForStaticRuntime(
    std::shared_ptr<torch::jit::Graph> g) {
  Inline(*g);
  ConstantPropagation(g);
  Canonicalize(g);
  ConstantPropagation(g);
  RemoveTensorMutation(g);
  ConstantPropagation(g);

  // Freezing should have eliminated all attribute lookups; a remaining
  // prim::GetAttr means the graph still depends on module state.
  for (auto n : g->nodes()) {
    if (n->kind() == c10::Symbol::fromQualString("prim::GetAttr")) {
      throw std::runtime_error("Cannot accelerate unfrozen graphs");
    }
  }

  // Remove the now-unused `self` argument (input 0) from the graph.
  if (g->inputs().at(0)->type()->is_module()) {
    TORCH_CHECK(!g->inputs().at(0)->hasUses());
    g->eraseInput(0);
  }

  return g;
}

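// PrepareForStaticRuntime (module overload): freeze a copy of the module in
// eval mode, then lower its forward graph with the overload above.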
std::shared_ptr<torch::jit::Graph> PrepareForStaticRuntime(
    const torch::jit::Module& m) {
  auto module = m.copy();
  module.eval();
  module = freeze_module(module);
  auto g = module.get_method("forward").graph();
  return PrepareForStaticRuntime(g);
}

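// Convenience constructors. Constructing from a bare graph leaves schema_
// unset, so run() cannot check or normalize keyword arguments in that case.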
StaticRuntime::StaticRuntime(std::shared_ptr<torch::jit::Graph> g)
    : StaticRuntime(g, c10::nullopt) {}

StaticRuntime::StaticRuntime(const torch::jit::Module& m)
    : StaticRuntime(PrepareForStaticRuntime(m), m) {}

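// Main constructor: perform a simple register allocation. Every Value in the
// graph (graph inputs, node outputs, constants) gets a dedicated slot in
// reg_, so execution needs no allocation or map lookups on the hot path.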
StaticRuntime::StaticRuntime(
    std::shared_ptr<torch::jit::Graph> g,
    c10::optional<torch::jit::Module> m)
    : graph_(g) {
  // assign a register to each Value*
  std::unordered_map<Value*, size_t> value_to_reg;
  for (Value* input : g->inputs()) {
    TORCH_CHECK(value_to_reg.count(input) == 0);
    size_t index = value_to_reg.size();
    value_to_reg[input] = index;
    input_regs_.push_back(index);
  }
  for (Node* node : graph_->nodes()) {
    for (Value* input : node->inputs()) {
      TORCH_CHECK(value_to_reg.count(input) > 0);
    }
    for (Value* output : node->outputs()) {
      TORCH_CHECK(
          value_to_reg.count(output) == 0, "the graph needs to be in SSA form");
      size_t index = value_to_reg.size();
      value_to_reg[output] = index;
    }
  }

  TORCH_CHECK(g->outputs().size() > 0);
  for (Value* output : g->outputs()) {
    TORCH_CHECK(value_to_reg.count(output) > 0);
    output_regs_.push_back(value_to_reg[output]);
  }

  // initialize registers
  reg_.resize(value_to_reg.size());

  // fill reg_ with constants; all other nodes become ProcessedNodes
  for (Node* node : graph_->nodes()) {
    if (node->kind() == prim::Constant) {
      TORCH_CHECK(node->output()->type()->kind() != FunctionType::Kind);
      reg_[value_to_reg[node->output()]] = toIValue(node->output()).value();
    } else {
      std::vector<size_t> input_regs, output_regs;
      for (Value* input : node->inputs()) {
        input_regs.push_back(value_to_reg[input]);
      }
      for (Value* output : node->outputs()) {
        output_regs.push_back(value_to_reg[output]);
      }
      nodes_.emplace_back(node, std::move(input_regs), std::move(output_regs));
    }
  }

  if (m) {
    Method method = m->get_method("forward");
    const c10::FunctionSchema& schema = method.function().getSchema();

    // remove "self" from the function schema to match the lowered graph
    TORCH_CHECK(
        schema.arguments().size() >= 1 &&
        schema.arguments()[0].name() == "self");
    std::vector<Argument> args(
        {schema.arguments().begin() + 1, schema.arguments().end()});
    schema_ =
        std::make_unique<c10::FunctionSchema>(schema.cloneWithArguments(args));
  }
}

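// Tensor-only entry point: copy the inputs into their registers, run each
// ProcessedNode in graph order, and flatten tuple outputs into a flat vector
// of tensors.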
std::vector<at::Tensor> StaticRuntime::run(
    const std::vector<at::Tensor>& inps) const {
  for (size_t i = 0; i < inps.size(); i++) {
    Input(i) = inps[i];
  }

  for (const auto& n : nodes_) {
    n.run(reg_);
  }

  std::vector<at::Tensor> out;
  for (size_t i = 0; i < graph_->outputs().size(); i++) {
    const IValue& v = Output(i);
    if (v.isTuple()) {
      auto t = v.toTuple();
      for (const auto& el : t->elements()) {
        out.emplace_back(el.toTensor());
      }
    } else {
      out.emplace_back(v.toTensor());
    }
  }
  return out;
}

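// IValue entry point. Keyword arguments are supported only when a schema was
// captured from the module's forward method at construction time.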
c10::IValue StaticRuntime::run(
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs) const {
  std::vector<IValue> stack(args);
  if (!kwargs.empty()) {
    // Not ideal: kwargs can only be normalized when a schema was captured.
    TORCH_CHECK(
        schema_ != nullptr,
        "Schema is not available. Consider creating the Static Runtime "
        "with StaticRuntime(const torch::jit::Module& m) instead.");
    schema_->checkAndNormalizeInputs(stack, kwargs);
  }
  for (size_t i = 0; i < stack.size(); i++) {
    Input(i) = stack[i];
  }

  for (const auto& n : nodes_) {
    n.run(reg_);
  }

  return Output(0);
}

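// Print an overall ms/iter figure, then per-node and per-node-type
// breakdowns (sorted by total time) from benchmark_individual_ops().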
void StaticRuntime::benchmark(
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs,
    const int warmup_runs,
    const int main_runs) const {
  float time_per_iter = benchmark_model(args, kwargs, warmup_runs, main_runs);
  std::cout << "Static runtime ms per iter: " << time_per_iter
            << ". Iters per second: " << 1000.0 / time_per_iter << std::endl;

  IndividualMetrics results =
      benchmark_individual_ops(args, kwargs, warmup_runs, main_runs);
  std::cout << "Setting up took " << results.setup_time << " ms" << std::endl;

  for (size_t i = 0; i < nodes_.size(); i++) {
    const Node* node = nodes_[i].get_node();
    std::cout << "Node #" << i << ": " << results.time_per_node[i]
              << " ms/iter, ";
    node->print(std::cout, 0, nullptr, false);
  }

  // sort node types by descending total time
  std::vector<std::pair<std::string, double>> time_per_node_type_vec{
      results.time_per_node_type.begin(), results.time_per_node_type.end()};
  std::sort(
      time_per_node_type_vec.begin(),
      time_per_node_type_vec.end(),
      [](auto& left, auto& right) { return left.second > right.second; });

  std::cout << "Time per node type:" << std::endl;
  for (const auto& p : time_per_node_type_vec) {
    const std::string& kind = p.first;
    const double ms = p.second;
    std::cout << std::setw(15) << ms << " ms. " << std::setw(10)
              << results.percent_per_node_type[kind] << "%. " << kind << " ("
              << results.instances_per_node_type[kind] << " nodes)"
              << std::endl;
  }
  std::cout << std::setw(15) << results.total_time << " ms. in Total"
            << std::endl;
}

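// End-to-end timing: warmup_runs untimed iterations, then main_runs timed
// iterations under a single timer; returns the mean ms per iteration.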
float StaticRuntime::benchmark_model(
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs,
    const int warmup_runs,
    const int main_runs) const {
  TORCH_CHECK(warmup_runs >= 0 && main_runs >= 1);

  for (int i = 0; i < warmup_runs; i++) {
    run(args, kwargs);
  }
  caffe2::Timer timer;
  for (int i = 0; i < main_runs; i++) {
    run(args, kwargs);
  }
  float millis = timer.MilliSeconds();
  return millis / static_cast<float>(main_runs);
}

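// Per-op timing: same warmup/main structure, but every node is timed
// individually and the results are aggregated per node and per node type.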
StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs,
    const int warmup_runs,
    const int main_runs) const {
  TORCH_CHECK(warmup_runs >= 0 && main_runs >= 1);

  IndividualMetrics results;
  results.total_time = 0.0;
  results.time_per_node.resize(nodes_.size(), 0);

  // setup time: input normalization and register initialization
  caffe2::Timer timer;
  std::vector<IValue> stack(args);
  if (!kwargs.empty()) {
    // Not ideal: kwargs can only be normalized when a schema was captured.
    TORCH_CHECK(
        schema_ != nullptr,
        "Schema is not available. Consider creating the Static Runtime "
        "with StaticRuntime(const torch::jit::Module& m) instead.");
    schema_->checkAndNormalizeInputs(stack, kwargs);
  }
  for (size_t i = 0; i < stack.size(); i++) {
    Input(i) = stack[i];
  }
  results.setup_time = timer.MilliSeconds();

  // warmup runs
  for (int i = 0; i < warmup_runs; i++) {
    run(args, kwargs);
  }

  // main runs: time each node individually
  for (int i = 0; i < main_runs; i++) {
    for (size_t j = 0; j < nodes_.size(); j++) {
      timer.Start();
      nodes_[j].run(reg_);
      float millis = timer.MilliSeconds();
      results.time_per_node[j] += millis;
    }
  }

  // post processing: average over main_runs and aggregate by node type
  for (size_t i = 0; i < nodes_.size(); i++) {
    const Node* node = nodes_[i].get_node();
    std::string kind = std::string(node->kind().toQualString());
    results.time_per_node[i] /= static_cast<float>(main_runs);
    results.time_per_node_type[kind] += results.time_per_node[i];
    results.instances_per_node_type[kind]++;
    results.total_time += results.time_per_node[i];
  }
  for (const auto& p : results.time_per_node_type) {
    const std::string& kind = p.first;
    results.percent_per_node_type[kind] = p.second / results.total_time * 100;
  }
  return results;
}

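// ProcessedNode caches what is needed to execute one node: the register
// indices of its inputs and outputs and, for ordinary ops, the boxed
// Operation. ListConstruct/TupleConstruct/ListUnpack are special-cased in
// run() instead.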
ProcessedNode::ProcessedNode(
    Node* node,
    std::vector<size_t>&& input_regs,
    std::vector<size_t>&& output_regs)
    : node_(node),
      input_regs_(std::move(input_regs)),
      output_regs_(std::move(output_regs)) {
  if (node->kind() != prim::ListConstruct &&
      node->kind() != prim::TupleConstruct &&
      node->kind() != prim::ListUnpack) {
    const Operator& op = node->getOperator();
    TORCH_CHECK(op.hasOperation());
    op_ = op.getOperation(node);
  }
}

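// Execute one node: marshal inputs from registers onto a local stack, invoke
// the cached operation (or one of the special-cased construct/unpack
// helpers), and move the results into the output registers. When fn_ is set
// (assigned outside this file), it is called directly with the registers,
// bypassing the stack entirely.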
void ProcessedNode::run(std::vector<IValue>& reg) const {
  if (!fn_) {
    // Copy inputs from their registers onto a local stack.
    std::vector<IValue> stack;
    const size_t size = node_->inputs().size();
    stack.reserve(size);
    for (size_t i = 0; i < size; i++) {
      stack.emplace_back(Input(i, reg));
    }
    if (op_) {
      op_->operator()(&stack);
    } else {
      if (node_->kind() == prim::ListConstruct) {
        listConstruct(
            stack,
            node_->output()->type()->expect<ListType>(),
            node_->inputs().size());
      } else if (node_->kind() == prim::TupleConstruct) {
        bool named =
            node_->output()->type()->expect<TupleType>()->name().has_value();
        if (named) {
          namedTupleConstruct(
              stack,
              node_->output()->type()->expect<TupleType>(),
              node_->inputs().size());
        } else {
          tupleConstruct(stack, node_->inputs().size());
        }
      } else if (node_->kind() == prim::ListUnpack) {
        size_t num_outputs = node_->outputs().size();
        listUnpack(stack, num_outputs);
      } else {
        TORCH_CHECK(0, "Unhandled operation!", node_->kind().toQualString());
      }
    }
    // Move the results from the stack back into the output registers.
    DCHECK_EQ(stack.size(), node_->outputs().size());
    for (size_t i = 0; i < node_->outputs().size(); i++) {
      Output(i, reg) = std::move(stack[i]);
    }
  } else {
    fn_->operator()(this, reg);
  }
}

} // namespace jit
} // namespace torch