[jit][static] Basic executor (#43647)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43647

Nothing fancy, just a basic implementation of the graph executor without using the stack machine.
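At a high level, StaticRuntime preprocesses the scripted module's graph once in its constructor, records a flat list of (Node*, Operation) pairs plus a constant table, and then executes those pairs directly against a Value*-to-IValue workspace instead of interpreting bytecode. A minimal usage sketch, mirroring the C++ test added below:

  torch::jit::Module mod = getTrivialScriptModel();
  torch::jit::StaticRuntime runtime(mod); // graph is frozen and preprocessed once
  std::vector<at::Tensor> inputs = {
      torch::randn({2, 2}), torch::randn({2, 2}), torch::randn({2, 2})};
  at::Tensor out = runtime.run(inputs)[0]; // nodes execute directly, no stack machine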

Reviewed By: bwasti

Differential Revision: D23208413

fbshipit-source-id: e483bb6ad7ba8591bbe1767e669654d82f42c356
Hao Lu 2020-08-28 23:17:17 -07:00 committed by Facebook GitHub Bot
parent 6aaae3b08b
commit 8538a79bfe
7 changed files with 290 additions and 126 deletions


@@ -0,0 +1,83 @@
#include "deep_wide_pt.h"

#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/script.h>

namespace {
// No ReplaceNaN (this removes the constant in the model)
const std::string deep_wide_pt = R"JIT(
class DeepAndWide(Module):
  __parameters__ = ["_mu", "_sigma", "_fc_w", "_fc_b", ]
  __buffers__ = []
  _mu : Tensor
  _sigma : Tensor
  _fc_w : Tensor
  _fc_b : Tensor
  training : bool
  def forward(self: __torch__.DeepAndWide,
    ad_emb_packed: Tensor,
    user_emb: Tensor,
    wide: Tensor) -> Tensor:
    _0 = self._fc_b
    _1 = self._fc_w
    _2 = self._sigma
    wide_offset = torch.add(wide, self._mu, alpha=1)
    wide_normalized = torch.mul(wide_offset, _2)
    wide_preproc = torch.clamp(wide_normalized, 0., 10.)
    user_emb_t = torch.transpose(user_emb, 1, 2)
    dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t)
    dp = torch.flatten(dp_unflatten, 1, -1)
    input = torch.cat([dp, wide_preproc], 1)
    fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1)
    return torch.sigmoid(fc1)
)JIT";

const std::string trivial_model_1 = R"JIT(
  def forward(self, a, b, c):
      s = torch.tensor([[3, 3], [3, 3]])
      return a + b * c + s
)JIT";

void import_libs(
    std::shared_ptr<at::CompilationUnit> cu,
    const std::string& class_name,
    const std::shared_ptr<torch::jit::Source>& src,
    const std::vector<at::IValue>& tensor_table) {
  torch::jit::SourceImporter si(
      cu,
      &tensor_table,
      [&](const std::string& /* unused */)
          -> std::shared_ptr<torch::jit::Source> { return src; },
      /*version=*/2);
  si.loadType(c10::QualifiedName(class_name));
}
} // namespace

torch::jit::Module getDeepAndWideSciptModel(int num_features) {
  auto cu = std::make_shared<at::CompilationUnit>();
  std::vector<at::IValue> constantTable;
  import_libs(
      cu,
      "__torch__.DeepAndWide",
      std::make_shared<torch::jit::Source>(deep_wide_pt),
      constantTable);
  c10::QualifiedName base("__torch__");
  auto clstype = cu->get_class(c10::QualifiedName(base, "DeepAndWide"));
  torch::jit::Module mod(cu, clstype);
  mod.register_parameter("_mu", torch::randn({1, num_features}), false);
  mod.register_parameter("_sigma", torch::randn({1, num_features}), false);
  mod.register_parameter("_fc_w", torch::randn({1, num_features + 1}), false);
  mod.register_parameter("_fc_b", torch::randn({1}), false);
  // mod.dump(true, true, true);
  return mod;
}

torch::jit::Module getTrivialScriptModel() {
  torch::jit::Module module("m");
  module.define(trivial_model_1);
  return module;
}
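For orientation, a hedged sketch of how this helper might be driven; the shapes are inferred from the TorchScript forward() above and the embedding_size constant in the benchmark file below, not taken from this commit:

  // Hypothetical invocation; batch size of 1 is illustrative.
  auto mod = getDeepAndWideSciptModel(/*num_features=*/50);
  auto ad_emb_packed = torch::randn({1, 1, 32}); // (batch, 1, embedding_size)
  auto user_emb = torch::randn({1, 1, 32});      // (batch, 1, embedding_size)
  auto wide = torch::randn({1, 50});             // (batch, num_features)
  // forward() returns a (batch, 1) tensor of sigmoid scores
  auto out = mod.forward({ad_emb_packed, user_emb, wide}).toTensor();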


@@ -1,7 +1,5 @@
#pragma once
-#include <torch/csrc/jit/serialization/import_source.h>
-#include <torch/script.h>
#include <torch/torch.h>
struct DeepAndWide : torch::nn::Module {
@@ -33,69 +31,6 @@ struct DeepAndWide : torch::nn::Module {
  torch::Tensor mu_, sigma_, fc_w_, fc_b_;
};
-namespace {
-// No ReplaceNaN (this removes the constant in the model)
-const std::string deep_wide_pt = R"JIT(
-class DeepAndWide(Module):
-  __parameters__ = ["_mu", "_sigma", "_fc_w", "_fc_b", ]
-  __buffers__ = []
-  _mu : Tensor
-  _sigma : Tensor
-  _fc_w : Tensor
-  _fc_b : Tensor
-  training : bool
-  def forward(self: __torch__.DeepAndWide,
-    ad_emb_packed: Tensor,
-    user_emb: Tensor,
-    wide: Tensor) -> Tensor:
-    _0 = self._fc_b
-    _1 = self._fc_w
-    _2 = self._sigma
-    wide_offset = torch.add(wide, self._mu, alpha=1)
-    wide_normalized = torch.mul(wide_offset, _2)
-    wide_preproc = torch.clamp(wide_normalized, 0., 10.)
-    user_emb_t = torch.transpose(user_emb, 1, 2)
-    dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t)
-    dp = torch.flatten(dp_unflatten, 1, -1)
-    input = torch.cat([dp, wide_preproc], 1)
-    fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1)
-    return torch.sigmoid(fc1)
-)JIT";
torch::jit::Module getDeepAndWideSciptModel(int num_features = 50);
-void import_libs(
-    std::shared_ptr<at::CompilationUnit> cu,
-    const std::string& class_name,
-    const std::shared_ptr<torch::jit::Source>& src,
-    const std::vector<at::IValue>& tensor_table) {
-  torch::jit::SourceImporter si(
-      cu,
-      &tensor_table,
-      [&](const std::string& name) -> std::shared_ptr<torch::jit::Source> {
-        return src;
-      },
-      /*version=*/2);
-  si.loadType(c10::QualifiedName(class_name));
-}
-} // namespace
-inline torch::jit::Module getDeepAndWideSciptModel(int num_features = 50) {
-  auto cu = std::make_shared<at::CompilationUnit>();
-  std::vector<at::IValue> constantTable;
-  import_libs(
-      cu,
-      "__torch__.DeepAndWide",
-      std::make_shared<torch::jit::Source>(deep_wide_pt),
-      constantTable);
-  c10::QualifiedName base("__torch__");
-  auto clstype = cu->get_class(c10::QualifiedName(base, "DeepAndWide"));
-  torch::jit::Module mod(cu, clstype);
-  mod.register_parameter("_mu", torch::randn({1, num_features}), false);
-  mod.register_parameter("_sigma", torch::randn({1, num_features}), false);
-  mod.register_parameter("_fc_w", torch::randn({1, num_features + 1}), false);
-  mod.register_parameter("_fc_b", torch::randn({1}), false);
-  // mod.dump(true, true, true);
-  return mod;
-}
torch::jit::Module getTrivialScriptModel();


@@ -1,6 +1,5 @@
#include <benchmark/benchmark.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include "deep_wide_pt.h"
const int embedding_size = 32;


@@ -0,0 +1,20 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include "deep_wide_pt.h"

TEST(StaticRuntime, TrivialModel) {
  torch::jit::Module mod = getTrivialScriptModel();
  auto a = torch::randn({2, 2});
  auto b = torch::randn({2, 2});
  auto c = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({a, b, c});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<at::Tensor> input_tensors({a, b, c});
  torch::jit::StaticRuntime runtime(mod);
  at::Tensor output_2 = runtime.run(input_tensors)[0];
  EXPECT_TRUE(output_1.equal(output_2));
}


@@ -2,6 +2,8 @@ import torch
from torch import nn
import numpy as np
from torch.testing._internal.common_utils import TestCase, run_tests
class StaticRuntime:
    def __init__(self, scripted):
@@ -90,50 +92,56 @@ def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s

class TestStaticRuntime(TestCase):
    def test_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticRuntime(attention)
        o_test = attention_a(src, src, src, src_mask)
        for a, b in zip(o_ref, o_test):
            torch.testing.assert_allclose(a, b)

    def test_mlp(self):
        # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
        ln_bot = [512, 512, 64]
        sigmoid_bot = -1
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticRuntime(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticRuntime(top_l)
        bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
        top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
        ref_bot = bot_l(bot_inp)
        acc_bot = bot_l_acc(bot_inp)[0]
        torch.testing.assert_allclose(acc_bot, ref_bot)
        ref_top = top_l(top_inp)
        acc_top = top_l_acc(top_inp)[0]
        torch.testing.assert_allclose(acc_top, ref_top)

    # def test_trivial_graph(self):
    #     s = torch.full((2, 2), 2)
    #     tg = torch.jit.script(trivial_graph)
    #     o_ref = tg(s, s, s)
    #     tg_a = StaticRuntime(tg)
    #     o_test = tg_a(s, s, s)[0]
    #     torch.testing.assert_allclose(o_ref, o_test)
if __name__ == "__main__":
-    HID_DIM = 256
-    QUERY_LEN = 8
-    BATCH_SIZE = 128
-    LAYERS = 3
-    HEADS = 8
-    DROPOUT = 0.1
-    device = torch.device("cpu")
-    attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
-    src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
-    src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
-    attention.eval()
-    attention = torch.jit.script(attention)
-    attention.eval()
-    o_ref = attention(src, src, src, src_mask)
-    attention_a = StaticRuntime(attention)
-    o_test = attention_a(src, src, src, src_mask)
-    for a, b in zip(o_ref, o_test):
-        torch.testing.assert_allclose(a, b)
-    s = torch.full((2, 2), 2)
-    tg = torch.jit.script(trivial_graph)
-    o_ref = tg(s, s, s)
-    tg_a = StaticRuntime(tg)
-    o_test = tg_a(s, s, s)[0]
-    torch.testing.assert_allclose(o_ref, o_test)
-    # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
-    ln_bot = [512, 512, 64]
-    sigmoid_bot = -1
-    ln_top = [100, 1024, 1024, 1024, 1]
-    sigmoid_top = 3
-    bot_l = create_mlp(ln_bot, sigmoid_bot)
-    bot_l_acc = StaticRuntime(bot_l)
-    top_l = create_mlp(ln_top, sigmoid_top)
-    top_l_acc = StaticRuntime(top_l)
-    bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
-    top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
-    ref_bot = bot_l(bot_inp)
-    acc_bot = bot_l_acc(bot_inp)[0]
-    torch.testing.assert_allclose(acc_bot, ref_bot)
-    ref_top = top_l(top_inp)
-    acc_top = top_l_acc(top_inp)[0]
-    torch.testing.assert_allclose(acc_top, ref_top)
    run_tests()


@@ -4,6 +4,7 @@
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
namespace torch {
namespace jit {
@@ -95,23 +96,122 @@ StaticRuntime::StaticRuntime(const torch::jit::Module& m)
    %r = static::mul(%y, %s)
    return (%r))IR");
  sr.runOnGraph(graph_);
-  code_ = std::make_unique<Code>(graph_, "");
-  interp_ = std::make_unique<InterpreterState>(*code_);

  // remove unused input 0 from graph
  if (graph_->inputs().at(0)->type()->is_module()) {
    if (!graph_->inputs().at(0)->hasUses()) {
      graph_->eraseInput(0);
    }
  }

  // fill constant_table_ and operator_table_
  for (Node* node : graph_->nodes()) {
    switch (node->kind()) {
      case prim::Constant:
        CHECK(node->output()->type()->kind() != FunctionType::Kind);
        constant_table_[node->output()] = toIValue(node->output()).value();
        break;
      case prim::ListConstruct:
        nodes_.emplace_back(node, nullptr);
        break;
      case prim::TupleConstruct:
        nodes_.emplace_back(node, nullptr);
        break;
      default: {
        const Operator& op = node->getOperator();
        CHECK(op.hasOperation());
        nodes_.emplace_back(node, op.getOperation(node));
      }
    }
  }
}
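
// Gathers a node's inputs in order, checking the constant table first and
// falling back to the per-run workspace for dynamically produced values.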
void StaticRuntime::getInputIValues(
    Node* node,
    const ConstantMap& ws,
    std::vector<IValue>& stack) const {
  const size_t size = node->inputs().size();
  stack.reserve(size);
  for (size_t i = 0; i < size; i++) {
    Value* v = node->inputs()[i];
    auto f = constant_table_.find(v);
    if (f == constant_table_.end()) {
      auto f_ws = ws.find(v);
      TORCH_CHECK(
          f_ws != ws.end(),
          "Workspace does not contain Value ",
          v->debugName());
      stack.emplace_back(f_ws->second);
    } else {
      stack.emplace_back(f->second);
    }
  }
}
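
// Executes the pre-collected (Node*, Operation) pairs in graph order.
// prim::ListConstruct/prim::TupleConstruct are handled inline; every other
// node dispatches through its Operation on a small reusable stack, and all
// outputs are written back into the workspace.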
void StaticRuntime::runNodes(ConstantMap& workspace) const {
  std::vector<IValue> stack;
  for (const auto& p : nodes_) {
    Node* node = p.first;
    const Operation& op = p.second;
    getInputIValues(node, workspace, stack);
    VLOG(1) << node->kind().toDisplayString();

    switch (node->kind()) {
      case prim::ListConstruct: {
        listConstruct(
            stack,
            node->output()->type()->expect<ListType>(),
            node->inputs().size());
      } break;
      case prim::TupleConstruct: {
        bool named =
            node->output()->type()->expect<TupleType>()->name().has_value();
        if (named) {
          namedTupleConstruct(
              stack,
              node->output()->type()->expect<TupleType>(),
              node->inputs().size());
        } else {
          tupleConstruct(stack, node->inputs().size());
        }
      } break;
      default: {
        DCHECK(op);
        op(&stack);
        break;
      }
    }

    DCHECK_EQ(stack.size(), node->outputs().size());
    for (auto i = 0; i < node->outputs().size(); i++) {
      workspace[node->outputs()[i]] = stack[i];
    }
    stack.clear();
  }
}
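
// Entry point: seeds the workspace with the module's self IValue (when the
// graph still takes it) and the input tensors, runs every node, then gathers
// the graph outputs, flattening tuples into a vector of tensors.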
std::vector<at::Tensor> StaticRuntime::run(
    const std::vector<at::Tensor>& inps) const {
-  std::vector<torch::jit::IValue> stack;
-  if (graph_->inputs().at(0)->type()->is_module()) {
-    stack.emplace_back(module_._ivalue());
-  }
-  for (const auto& inp : inps) {
-    stack.emplace_back(inp);
  // Container for inputs, outputs, and activations (excluding parameters)
  ConstantMap workspace_;

  int start = 0;
  if (graph_->inputs().size() != inps.size()) {
    start = 1;
    CHECK_EQ(graph_->inputs().size(), inps.size() + 1);
    CHECK((graph_->inputs().at(0)->type()->is_module()));
    workspace_.emplace(graph_->inputs()[0], module_._ivalue());
  }
-  interp_->run(stack);
  for (size_t i = 0; i < inps.size(); i++) {
    workspace_.emplace(graph_->inputs()[i + start], inps[i]);
  }

  runNodes(workspace_);

  std::vector<at::Tensor> out;
-  for (const auto& v : stack) {
  for (Value* output : graph_->outputs()) {
    const IValue& v = workspace_[output];
    if (v.isTuple()) {
      auto t = v.toTuple();
      for (const auto& el : t->elements()) {


@@ -7,6 +7,10 @@
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/inliner.h>
#ifdef FBCODE_CAFFE2
#include <folly/container/F14Map.h>
#endif
namespace torch {
namespace jit {
@@ -19,13 +23,28 @@ class TORCH_API StaticRuntime {
  std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps) const;
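
  // Value* -> IValue map used for the constant table and per-run workspaces.
  // fbcode builds use folly's F14 map, which typically outperforms
  // std::unordered_map; open-source builds fall back to std::unordered_map.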
#ifdef FBCODE_CAFFE2
  using ConstantMap = folly::F14FastMap<Value*, IValue>;
#else
  using ConstantMap = std::unordered_map<Value*, IValue>;
#endif
 private:
  torch::jit::Module module_;
  std::shared_ptr<torch::jit::Graph> graph_;
-  // Jit interpreter state
-  std::unique_ptr<torch::jit::Code> code_;
-  std::unique_ptr<torch::jit::InterpreterState> interp_;

  // Static runtime states
  // Constant table (including weights)
  ConstantMap constant_table_;
  // The nodes we need to run
  std::vector<std::pair<Node*, Operation>> nodes_;

  void getInputIValues(
      Node* node,
      const ConstantMap& ws,
      std::vector<IValue>& stack) const;

  void runNodes(ConstantMap& ws_) const;
};
} // namespace jit