[jit][static] Basic executor (#43647)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43647

Nothing fancy, just a basic implementation of the graph executor without using the stack machine.

Reviewed By: bwasti

Differential Revision: D23208413

fbshipit-source-id: e483bb6ad7ba8591bbe1767e669654d82f42c356
parent 6aaae3b08b
commit 8538a79bfe
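The idea, in miniature: instead of compiling the graph to bytecode and dispatching through the interpreter's value stack, the constructor resolves each graph Node to its Operation once, and execution walks that flat list, reading each node's inputs out of a Value*-to-IValue map and writing its outputs back. The sketch below illustrates that loop with hypothetical names (run_straight_line, resolved_nodes, workspace); the commit's own version is StaticRuntime::runNodes in the diff that follows.

// Sketch only: straight-line execution over a pre-resolved node list.
// "workspace" maps every graph Value* to its current IValue, so operand
// lookup is a hash probe rather than stack push/pop bookkeeping.
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <unordered_map>
#include <utility>
#include <vector>

using torch::jit::IValue;
using torch::jit::Node;
using torch::jit::Operation;
using torch::jit::Value;

void run_straight_line(
    const std::vector<std::pair<Node*, Operation>>& resolved_nodes,
    std::unordered_map<Value*, IValue>& workspace) {
  std::vector<IValue> stack; // reused scratch buffer, not an interpreter stack
  for (const auto& pair : resolved_nodes) {
    Node* node = pair.first;
    const Operation& op = pair.second;
    // Gather operands by name (Value*) instead of by stack position.
    for (Value* input : node->inputs()) {
      stack.emplace_back(workspace.at(input));
    }
    op(&stack); // ops keep their usual take-args/leave-results convention
    // Results go back into the map under the node's output Values.
    for (size_t i = 0; i < node->outputs().size(); ++i) {
      workspace[node->outputs()[i]] = stack[i];
    }
    stack.clear();
  }
}

Constant inputs (including weights) come from a separate table filled once at construction time (constant_table_ in the diff), so the hot loop never re-materializes them.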
benchmarks/static_runtime/deep_wide_pt.cc (new file, +83 lines)
@@ -0,0 +1,83 @@
#include "deep_wide_pt.h"

#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/script.h>

namespace {
// No ReplaceNaN (this removes the constant in the model)
const std::string deep_wide_pt = R"JIT(
class DeepAndWide(Module):
  __parameters__ = ["_mu", "_sigma", "_fc_w", "_fc_b", ]
  __buffers__ = []
  _mu : Tensor
  _sigma : Tensor
  _fc_w : Tensor
  _fc_b : Tensor
  training : bool
  def forward(self: __torch__.DeepAndWide,
    ad_emb_packed: Tensor,
    user_emb: Tensor,
    wide: Tensor) -> Tensor:
    _0 = self._fc_b
    _1 = self._fc_w
    _2 = self._sigma
    wide_offset = torch.add(wide, self._mu, alpha=1)
    wide_normalized = torch.mul(wide_offset, _2)
    wide_preproc = torch.clamp(wide_normalized, 0., 10.)
    user_emb_t = torch.transpose(user_emb, 1, 2)
    dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t)
    dp = torch.flatten(dp_unflatten, 1, -1)
    input = torch.cat([dp, wide_preproc], 1)
    fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1)
    return torch.sigmoid(fc1)
)JIT";

const std::string trivial_model_1 = R"JIT(
  def forward(self, a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s
)JIT";

void import_libs(
    std::shared_ptr<at::CompilationUnit> cu,
    const std::string& class_name,
    const std::shared_ptr<torch::jit::Source>& src,
    const std::vector<at::IValue>& tensor_table) {
  torch::jit::SourceImporter si(
      cu,
      &tensor_table,
      [&](const std::string& /* unused */) -> std::shared_ptr<torch::jit::Source> {
        return src;
      },
      /*version=*/2);
  si.loadType(c10::QualifiedName(class_name));
}
} // namespace

torch::jit::Module getDeepAndWideSciptModel(int num_features) {
  auto cu = std::make_shared<at::CompilationUnit>();
  std::vector<at::IValue> constantTable;
  import_libs(
      cu,
      "__torch__.DeepAndWide",
      std::make_shared<torch::jit::Source>(deep_wide_pt),
      constantTable);
  c10::QualifiedName base("__torch__");
  auto clstype = cu->get_class(c10::QualifiedName(base, "DeepAndWide"));

  torch::jit::Module mod(cu, clstype);

  mod.register_parameter("_mu", torch::randn({1, num_features}), false);
  mod.register_parameter("_sigma", torch::randn({1, num_features}), false);
  mod.register_parameter("_fc_w", torch::randn({1, num_features + 1}), false);
  mod.register_parameter("_fc_b", torch::randn({1}), false);

  // mod.dump(true, true, true);
  return mod;
}

torch::jit::Module getTrivialScriptModel() {
  torch::jit::Module module("m");
  module.define(trivial_model_1);
  return module;
}
benchmarks/static_runtime/deep_wide_pt.h
@@ -1,7 +1,5 @@
 #pragma once
 
-#include <torch/csrc/jit/serialization/import_source.h>
-#include <torch/script.h>
 #include <torch/torch.h>
 
 struct DeepAndWide : torch::nn::Module {
@@ -33,69 +31,6 @@ struct DeepAndWide : torch::nn::Module {
   torch::Tensor mu_, sigma_, fc_w_, fc_b_;
 };
 
-namespace {
-// No ReplaceNaN (this removes the constant in the model)
-const std::string deep_wide_pt = R"JIT(
-class DeepAndWide(Module):
-  __parameters__ = ["_mu", "_sigma", "_fc_w", "_fc_b", ]
-  __buffers__ = []
-  _mu : Tensor
-  _sigma : Tensor
-  _fc_w : Tensor
-  _fc_b : Tensor
-  training : bool
-  def forward(self: __torch__.DeepAndWide,
-    ad_emb_packed: Tensor,
-    user_emb: Tensor,
-    wide: Tensor) -> Tensor:
-    _0 = self._fc_b
-    _1 = self._fc_w
-    _2 = self._sigma
-    wide_offset = torch.add(wide, self._mu, alpha=1)
-    wide_normalized = torch.mul(wide_offset, _2)
-    wide_preproc = torch.clamp(wide_normalized, 0., 10.)
-    user_emb_t = torch.transpose(user_emb, 1, 2)
-    dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t)
-    dp = torch.flatten(dp_unflatten, 1, -1)
-    input = torch.cat([dp, wide_preproc], 1)
-    fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1)
-    return torch.sigmoid(fc1)
-)JIT";
-
-void import_libs(
-    std::shared_ptr<at::CompilationUnit> cu,
-    const std::string& class_name,
-    const std::shared_ptr<torch::jit::Source>& src,
-    const std::vector<at::IValue>& tensor_table) {
-  torch::jit::SourceImporter si(
-      cu,
-      &tensor_table,
-      [&](const std::string& name) -> std::shared_ptr<torch::jit::Source> {
-        return src;
-      },
-      /*version=*/2);
-  si.loadType(c10::QualifiedName(class_name));
-}
-} // namespace
-
-inline torch::jit::Module getDeepAndWideSciptModel(int num_features = 50) {
-  auto cu = std::make_shared<at::CompilationUnit>();
-  std::vector<at::IValue> constantTable;
-  import_libs(
-      cu,
-      "__torch__.DeepAndWide",
-      std::make_shared<torch::jit::Source>(deep_wide_pt),
-      constantTable);
-  c10::QualifiedName base("__torch__");
-  auto clstype = cu->get_class(c10::QualifiedName(base, "DeepAndWide"));
-
-  torch::jit::Module mod(cu, clstype);
-
-  mod.register_parameter("_mu", torch::randn({1, num_features}), false);
-  mod.register_parameter("_sigma", torch::randn({1, num_features}), false);
-  mod.register_parameter("_fc_w", torch::randn({1, num_features + 1}), false);
-  mod.register_parameter("_fc_b", torch::randn({1}), false);
-
-  // mod.dump(true, true, true);
-  return mod;
-}
+torch::jit::Module getDeepAndWideSciptModel(int num_features = 50);
+
+torch::jit::Module getTrivialScriptModel();
benchmarks/static_runtime/deep_wide_pt_bench.cc
@@ -1,6 +1,5 @@
#include <benchmark/benchmark.h>
#include <torch/csrc/jit/runtime/static/impl.h>

#include "deep_wide_pt.h"

const int embedding_size = 32;
benchmarks/static_runtime/test_static_runtime.cc (new file, +20 lines)
@@ -0,0 +1,20 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include "deep_wide_pt.h"

TEST(StaticRuntime, TrivialModel) {
  torch::jit::Module mod = getTrivialScriptModel();
  auto a = torch::randn({2, 2});
  auto b = torch::randn({2, 2});
  auto c = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({a, b, c});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<at::Tensor> input_tensors({a, b, c});
  torch::jit::StaticRuntime runtime(mod);
  at::Tensor output_2 = runtime.run(input_tensors)[0];
  EXPECT_TRUE(output_1.equal(output_2));
}
benchmarks/static_runtime/test_static_runtime.py
@@ -2,6 +2,8 @@ import torch
 from torch import nn
 import numpy as np
 
+from torch.testing._internal.common_utils import TestCase, run_tests
+
 
 class StaticRuntime:
     def __init__(self, scripted):
@@ -90,50 +92,56 @@ def trivial_graph(a, b, c):
     s = torch.tensor([[3, 3], [3, 3]])
     return a + b * c + s
 
+class TestStaticRuntime(TestCase):
+    def test_multihead_attention_layer(self):
+        HID_DIM = 256
+        QUERY_LEN = 8
+        BATCH_SIZE = 128
+        LAYERS = 3
+        HEADS = 8
+        DROPOUT = 0.1
+        device = torch.device("cpu")
+        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
+        src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
+        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
+
+        attention.eval()
+        attention = torch.jit.script(attention)
+        attention.eval()
+        o_ref = attention(src, src, src, src_mask)
+
+        attention_a = StaticRuntime(attention)
+        o_test = attention_a(src, src, src, src_mask)
+        for a, b in zip(o_ref, o_test):
+            torch.testing.assert_allclose(a, b)
+
+    def test_mlp(self):
+        # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
+        ln_bot = [512, 512, 64]
+        sigmoid_bot = -1
+        ln_top = [100, 1024, 1024, 1024, 1]
+        sigmoid_top = 3
+        bot_l = create_mlp(ln_bot, sigmoid_bot)
+        bot_l_acc = StaticRuntime(bot_l)
+        top_l = create_mlp(ln_top, sigmoid_top)
+        top_l_acc = StaticRuntime(top_l)
+        bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
+        top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
+        ref_bot = bot_l(bot_inp)
+        acc_bot = bot_l_acc(bot_inp)[0]
+        torch.testing.assert_allclose(acc_bot, ref_bot)
+        ref_top = top_l(top_inp)
+        acc_top = top_l_acc(top_inp)[0]
+        torch.testing.assert_allclose(acc_top, ref_top)
+
+    # def test_trivial_graph(self):
+    #     s = torch.full((2, 2), 2)
+    #     tg = torch.jit.script(trivial_graph)
+    #     o_ref = tg(s, s, s)
+    #     tg_a = StaticRuntime(tg)
+    #     o_test = tg_a(s, s, s)[0]
+    #     torch.testing.assert_allclose(o_ref, o_test)
+
 if __name__ == "__main__":
-    HID_DIM = 256
-    QUERY_LEN = 8
-    BATCH_SIZE = 128
-    LAYERS = 3
-    HEADS = 8
-    DROPOUT = 0.1
-    device = torch.device("cpu")
-    attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
-    src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
-    src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
-
-    attention.eval()
-    attention = torch.jit.script(attention)
-    attention.eval()
-    o_ref = attention(src, src, src, src_mask)
-
-    attention_a = StaticRuntime(attention)
-    o_test = attention_a(src, src, src, src_mask)
-    for a, b in zip(o_ref, o_test):
-        torch.testing.assert_allclose(a, b)
-
-    s = torch.full((2, 2), 2)
-    tg = torch.jit.script(trivial_graph)
-    o_ref = tg(s, s, s)
-    tg_a = StaticRuntime(tg)
-    o_test = tg_a(s, s, s)[0]
-    torch.testing.assert_allclose(o_ref, o_test)
-
-    # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
-    ln_bot = [512, 512, 64]
-    sigmoid_bot = -1
-    ln_top = [100, 1024, 1024, 1024, 1]
-    sigmoid_top = 3
-    bot_l = create_mlp(ln_bot, sigmoid_bot)
-    bot_l_acc = StaticRuntime(bot_l)
-    top_l = create_mlp(ln_top, sigmoid_top)
-    top_l_acc = StaticRuntime(top_l)
-    bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
-    top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
-    ref_bot = bot_l(bot_inp)
-    acc_bot = bot_l_acc(bot_inp)[0]
-    torch.testing.assert_allclose(acc_bot, ref_bot)
-    ref_top = top_l(top_inp)
-    acc_top = top_l_acc(top_inp)[0]
-    torch.testing.assert_allclose(acc_top, ref_top)
+    run_tests()
torch/csrc/jit/runtime/static/impl.cpp
@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/passes/freeze_module.h>
 #include <torch/csrc/jit/passes/remove_mutation.h>
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include <torch/csrc/jit/runtime/vararg_functions.h>
 
 namespace torch {
 namespace jit {
@@ -95,23 +96,122 @@ StaticRuntime::StaticRuntime(const torch::jit::Module& m)
       %r = static::mul(%y, %s)
       return (%r))IR");
   sr.runOnGraph(graph_);
   code_ = std::make_unique<Code>(graph_, "");
   interp_ = std::make_unique<InterpreterState>(*code_);
+
+  // remove unused input 0 from graph
+  if (graph_->inputs().at(0)->type()->is_module()) {
+    if (!graph_->inputs().at(0)->hasUses()) {
+      graph_->eraseInput(0);
+    }
+  }
+
+  // fill constant_table_ and operator_table_
+  for (Node* node : graph_->nodes()) {
+    switch (node->kind()) {
+      case prim::Constant:
+        CHECK(node->output()->type()->kind() != FunctionType::Kind);
+        constant_table_[node->output()] = toIValue(node->output()).value();
+        break;
+      case prim::ListConstruct:
+        nodes_.emplace_back(node, nullptr);
+        break;
+      case prim::TupleConstruct:
+        nodes_.emplace_back(node, nullptr);
+        break;
+      default: {
+        const Operator& op = node->getOperator();
+        CHECK(op.hasOperation());
+        nodes_.emplace_back(node, op.getOperation(node));
+      }
+    }
+  }
 }
 
+void StaticRuntime::getInputIValues(
+    Node* node,
+    const ConstantMap& ws,
+    std::vector<IValue>& stack) const {
+  const size_t size = node->inputs().size();
+  stack.reserve(size);
+  for (size_t i = 0; i < size; i++) {
+    Value* v = node->inputs()[i];
+    auto f = constant_table_.find(v);
+    if (f == constant_table_.end()) {
+      auto f_ws = ws.find(v);
+      TORCH_CHECK(
+          f_ws != ws.end(),
+          "Workspace does not contain Value ",
+          v->debugName());
+      stack.emplace_back(f_ws->second);
+    } else {
+      stack.emplace_back(f->second);
+    }
+  }
+}
+
+void StaticRuntime::runNodes(ConstantMap& workspace) const {
+  std::vector<IValue> stack;
+  for (const auto& p : nodes_) {
+    Node* node = p.first;
+    const Operation& op = p.second;
+    getInputIValues(node, workspace, stack);
+    VLOG(1) << node->kind().toDisplayString();
+
+    switch (node->kind()) {
+      case prim::ListConstruct: {
+        listConstruct(
+            stack,
+            node->output()->type()->expect<ListType>(),
+            node->inputs().size());
+      } break;
+      case prim::TupleConstruct: {
+        bool named =
+            node->output()->type()->expect<TupleType>()->name().has_value();
+        if (named) {
+          namedTupleConstruct(
+              stack,
+              node->output()->type()->expect<TupleType>(),
+              node->inputs().size());
+        } else {
+          tupleConstruct(stack, node->inputs().size());
+        }
+      } break;
+      default: {
+        DCHECK(op);
+        op(&stack);
+        break;
+      }
+    }
+
+    DCHECK_EQ(stack.size(), node->outputs().size());
+    for (auto i = 0; i < node->outputs().size(); i++) {
+      workspace[node->outputs()[i]] = stack[i];
+    }
+    stack.clear();
+  }
+}
+
 std::vector<at::Tensor> StaticRuntime::run(
     const std::vector<at::Tensor>& inps) const {
-  std::vector<torch::jit::IValue> stack;
-  if (graph_->inputs().at(0)->type()->is_module()) {
-    stack.emplace_back(module_._ivalue());
-  }
-  for (const auto& inp : inps) {
-    stack.emplace_back(inp);
-  }
+  // Container for inputs, outputs, and activations (excluding parameters)
+  ConstantMap workspace_;
+
+  int start = 0;
+  if (graph_->inputs().size() != inps.size()) {
+    start = 1;
+    CHECK_EQ(graph_->inputs().size(), inps.size() + 1);
+    CHECK((graph_->inputs().at(0)->type()->is_module()));
+    workspace_.emplace(graph_->inputs()[0], module_._ivalue());
+  }
 
-  interp_->run(stack);
+  for (size_t i = 0; i < inps.size(); i++) {
+    workspace_.emplace(graph_->inputs()[i + start], inps[i]);
+  }
+
+  runNodes(workspace_);
 
   std::vector<at::Tensor> out;
-  for (const auto& v : stack) {
+  for (Value* output : graph_->outputs()) {
+    const IValue& v = workspace_[output];
     if (v.isTuple()) {
       auto t = v.toTuple();
       for (const auto& el : t->elements()) {
torch/csrc/jit/runtime/static/impl.h
@@ -7,6 +7,10 @@
 #include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/jit/passes/inliner.h>
 
+#ifdef FBCODE_CAFFE2
+#include <folly/container/F14Map.h>
+#endif
+
 namespace torch {
 namespace jit {
@@ -19,13 +23,28 @@ class TORCH_API StaticRuntime {
 
   std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps) const;
 
+#ifdef FBCODE_CAFFE2
+  using ConstantMap = folly::F14FastMap<Value*, IValue>;
+#else
+  using ConstantMap = std::unordered_map<Value*, IValue>;
+#endif
+
  private:
   torch::jit::Module module_;
   std::shared_ptr<torch::jit::Graph> graph_;
 
   // Jit interpreter state
   std::unique_ptr<torch::jit::Code> code_;
   std::unique_ptr<torch::jit::InterpreterState> interp_;
+  // Static runtime states
+  // Constant table (including weights)
+  ConstantMap constant_table_;
+  // The nodes we need to run
+  std::vector<std::pair<Node*, Operation>> nodes_;
+
+  void getInputIValues(
+      Node* node,
+      const ConstantMap& ws,
+      std::vector<IValue>& stack) const;
+
+  void runNodes(ConstantMap& ws_) const;
 };
 
 } // namespace jit