[Static Runtime] Support prim::GetAttr/SetAttr (#61505)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61505

The handling of `self` in static runtime was previously incorrect. This diff fixes that issue, since `self` is essential to `prim::GetAttr`/`prim::SetAttr`: most of the time we are getting and setting attributes on `self`, the TorchScript module.
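For context, here is a minimal sketch of how these ops arise (the module `M` below is illustrative, not part of this diff): scripting and freezing a module whose `forward` mutates an attribute leaves `prim::SetAttr`/`prim::GetAttr` nodes in the graph, each taking `%self` as its first input, so `self` cannot simply be dropped from the graph inputs:

    import torch
    from torch import nn

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.b = 2

        def forward(self, x):
            self.b = 30        # survives freezing as prim::SetAttr[name="b"](%self, ...)
            return x + self.b  # read back via prim::GetAttr[name="b"](%self)

    frozen = torch.jit.freeze(torch.jit.script(M()).eval())
    print(frozen.graph)  # %self is still input 0 because attribute "b" is mutated

With this change, static runtime keeps such a graph's `self` input (tracked via `first_input_is_self_` below) and binds the frozen module's IValue to input 0 at run time.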

Reviewed By: ajyu

Differential Revision: D29350173

fbshipit-source-id: 6e62add4cda517ef8cd6c315d4cb0595e7d531fb
Hao Lu 2021-07-10 14:04:48 -07:00 committed by Facebook GitHub Bot
parent f291b1899f
commit ccd0977060
7 changed files with 262 additions and 128 deletions


@@ -1,11 +1,11 @@
+import unittest
+from typing import Dict, Optional
import numpy as np
import torch
-import unittest
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
-from typing import Dict, Optional
class StaticModule:
def __init__(self, scripted):
@@ -30,7 +30,9 @@ class StaticModule:
    )
-def linear_shim(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def linear_shim(
+    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
output = input.matmul(weight.t())
if bias is not None:
output += bias
@@ -107,7 +109,8 @@ def trivial_graph(a, b, c):
s = torch.tensor([[3, 3], [3, 3]])
return a + b * c + s
-def loop_graph(a, b, iters : int):
+def loop_graph(a, b, iters: int):
c = a + b * 2
for i in range(iters):
c = c + b
@@ -115,14 +118,50 @@ def loop_graph(a, b, iters : int):
c -= a
return c
-def output_graph(a, b, c, iters : int):
+def output_graph(a, b, c, iters: int):
s = torch.tensor([[3, 3], [3, 3]])
k = a + b * c + s
-d : Dict[int, torch.Tensor] = {}
+d: Dict[int, torch.Tensor] = {}
for i in range(iters):
d[i] = k + i
return d
+class SubModule(nn.Module):
+    def __init__(self):
+        super(SubModule, self).__init__()
+        self.a = 11
+        self.b = 2
+
+    def forward(self, x):
+        return self.a + self.b + x
+
+class SubModule2(nn.Module):
+    def __init__(self):
+        super(SubModule2, self).__init__()
+        self.a = 12
+        self.b = 2
+
+    def forward(self, x):
+        self.b = 30
+        return self.a + self.b + x
+
+class TestModule(nn.Module):
+    def __init__(self):
+        super(TestModule, self).__init__()
+        self.sub1 = SubModule()
+        self.sub2 = SubModule2()
+        self.a = 3
+        self.b = 4
+
+    def forward(self, x):
+        self.b = 20
+        return self.sub1(x) + self.a + self.b + self.sub2(x)
class TestStaticModule(TestCase):
def test_multihead_attention_layer(self):
HID_DIM = 256
@@ -220,6 +259,46 @@ class TestStaticModule(TestCase):
        o_test = tg_a(s)[0]
        torch.testing.assert_allclose(o_ref, o_test)
+
+    def test_attr(self):
+        """
+        TorchScript IR of TestModule() after freezing:
+        graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
+              %x.1 : Tensor):
+            %18 : int = prim::Constant[value=30]()
+            %30 : int = prim::Constant[value=13]()
+            %3 : int = prim::Constant[value=20]()
+            %2 : int = prim::Constant[value=1]()
+            %self.sub2.a : int = prim::Constant[value=12]()
+            %self.a : int = prim::Constant[value=3]()
+            = prim::SetAttr[name="b"](%self, %3)
+            %17 : Tensor = aten::add(%x.1, %30, %2)
+            %7 : Tensor = aten::add(%17, %self.a, %2)
+            %b.1 : int = prim::GetAttr[name="b"](%self)
+            %9 : Tensor = aten::add(%7, %b.1, %2)
+            %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
+            = prim::SetAttr[name="b"](%sub2, %18)
+            %b : int = prim::GetAttr[name="b"](%sub2)
+            %22 : int = aten::add(%self.sub2.a, %b)
+            %23 : Tensor = aten::add(%x.1, %22, %2)
+            %12 : Tensor = aten::add(%9, %23, %2)
+            return (%12)
+        """
+        # test prim::SetAttr and prim::GetAttr impl in Static Runtime
+        m = TestModule()
+        m.eval()
+        input = torch.randn(2, 2)
+        output_s = m.forward(input)
+
+        ms = torch.jit.script(m)
+        sm = StaticModule(ms)
+        output_sm = sm(input)[0]
+        torch.testing.assert_allclose(output_s, output_sm)
+        sm.benchmark([input], {}, 2, 2)
+        sm.benchmark_individual_ops([input], {}, 2, 2)
+        sm.benchmark([], {"x": input}, 2, 2)
+        sm.benchmark_individual_ops([], {"x": input}, 2, 2)
@unittest.skip("Temporarily disabled")
def test_fusion_trivial_graph(self):
s = torch.full((2, 2), 2)
@@ -281,6 +360,5 @@ class TestStaticModule(TestCase):
torch.testing.assert_allclose(o_ref[i], o_test[i])
if __name__ == "__main__":
run_tests()
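As a sanity check on the new TestModule above (hand-derived, not part of the diff, from the constants visible in the frozen IR in test_attr: 13 = sub1's 11 + 2, 20 and 30 are the SetAttr values, 42 = sub2's 12 + 30):

    x = torch.randn(2, 2)
    m = TestModule()
    # sub1(x) + a + b + sub2(x) == (13 + x) + 3 + 20 + (42 + x) == 78 + 2 * x
    torch.testing.assert_allclose(m(x), 78 + 2 * x)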


@@ -45,20 +45,29 @@ void OptimizeGraph(
ConstantPropagation(graph);
}
-void CheckGraphEligibility(const std::shared_ptr<torch::jit::Graph>& graph) {
-  for (auto n : graph->nodes()) {
-    if (n->kind() == c10::Symbol::fromQualString("prim::GetAttr")) {
-      throw std::runtime_error("Cannot accelerate unfrozen graphs");
+bool CheckGraphEligibility(const std::shared_ptr<torch::jit::Graph>& graph) {
+  // check for sub-blocks
+  bool can_support = true;
+  for (auto* node : graph->block()->nodes()) {
+    for (Block* sub_block : node->blocks()) {
+      VLOG(1) << "Found nested sub-blocks in graph at node: "
+              << PrintNode(node);
+      can_support = false;
    }
  }
+  return can_support;
}
// remove unused input 0 from graph
-void RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
+bool RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
  if (graph->inputs().at(0)->type()->is_module()) {
-    TORCH_CHECK(!graph->inputs().at(0)->hasUses());
+    if (graph->inputs().at(0)->hasUses()) {
+      return false;
+    }
    graph->eraseInput(0);
  }
+  return true;
}
// remove "self" from function schema
@@ -443,12 +452,12 @@ GenerateSameStorageValues(
void PrepareGraphForStaticModule(
    std::shared_ptr<torch::jit::Graph> graph,
    const StaticModuleOptions& opts) {
-  CheckGraphEligibility(graph);
+  // TODO: call CheckGraphEligibility before trying to enable static runtime
+  TORCH_CHECK(CheckGraphEligibility(graph));
  OptimizeGraph(graph, opts);
-  RemoveSelfFromGraphInput(graph);
}
-std::pair<std::shared_ptr<Graph>, c10::optional<c10::FunctionSchema>>
+std::pair<std::shared_ptr<Graph>, std::shared_ptr<Module>>
PrepareForStaticModule(
const torch::jit::Module& m,
const StaticModuleOptions& opts) {
@@ -461,22 +470,23 @@ PrepareForStaticModule(
  auto module = m.copy();
  module.eval();
-  module = freeze_module(module);
+  auto module_ptr = std::make_shared<Module>(freeze_module(module));
-  Method method = module.get_method("forward");
-  auto graph = module.get_method("forward").graph();
+  Method method = module_ptr->get_method("forward");
+  auto graph = module_ptr->get_method("forward").graph();
  // graph->dump();
  PrepareGraphForStaticModule(graph, opts);
-  c10::FunctionSchema s = RemoveSelfFromSchema(method.function().getSchema());
-  return std::make_pair(graph, s);
+  return std::make_pair(graph, module_ptr);
}
-std::pair<std::shared_ptr<Graph>, c10::optional<c10::FunctionSchema>>
+std::pair<std::shared_ptr<Graph>, std::shared_ptr<Module>>
PrepareForStaticModule(
    std::shared_ptr<torch::jit::Graph> graph,
    const StaticModuleOptions& opts) {
  PrepareGraphForStaticModule(graph, opts);
-  return std::make_pair(graph, c10::nullopt);
+  return std::make_pair(graph, nullptr);
}
} // namespace
@@ -492,13 +502,12 @@
    : StaticModule(PrepareForStaticModule(m, opts), opts) {}
StaticModule::StaticModule(
-    std::pair<
-        std::shared_ptr<torch::jit::Graph>,
-        c10::optional<c10::FunctionSchema>> graph_and_schema,
+    std::pair<std::shared_ptr<torch::jit::Graph>, std::shared_ptr<Module>>
+        graph_and_module,
    const StaticModuleOptions& opts)
    : opts_(opts),
-      graph_(std::move(graph_and_schema.first)),
-      schema_(std::move(graph_and_schema.second)) {
+      graph_(std::move(graph_and_module.first)),
+      module_(std::move(graph_and_module.second)) {
// check opt flags
if (opts.optimize_graph_output_memory) {
TORCH_CHECK(
@@ -511,6 +520,18 @@
        "When optimize_memory is true, enable_out_variant must be set to true");
  }
+  // handle schema
+  if (module_) {
+    Method method = module_->get_method("forward");
+    schema_ = method.function().getSchema();
+    if (RemoveSelfFromGraphInput(graph_)) {
+      schema_ = RemoveSelfFromSchema(method.function().getSchema());
+    } else {
+      first_input_is_self_ = true;
+      schema_ = method.function().getSchema();
+    }
+  }
// map Value* to IValue (from inputs or prim::Constant) or null
std::unordered_map<Value*, IValue*> value_to_ivalue;
// map Value* to its SSA definition IR
@@ -620,6 +641,7 @@ StaticRuntime::StaticRuntime(const StaticModule& sm) : static_module_(sm) {
// NB: create unchanging std::vector<IValue>s we can reference
inputs_.resize(sm.num_inputs());
nodes_.resize(sm.nodes().size());
for (const auto idx : c10::irange(sm.nodes().size())) {
const auto& n_ref = sm.nodes()[idx];
nodes_[idx] = n_ref; // copy the node
@@ -688,6 +710,43 @@ std::vector<at::Tensor> StaticRuntime::operator()(
return out;
}
+void StaticRuntime::set_inputs(
+    const std::vector<c10::IValue>& args,
+    const std::unordered_map<std::string, c10::IValue>& kwargs) {
+  if (!kwargs.empty()) {
+    // This is not ideal
+    TORCH_CHECK(
+        static_module_.schema(),
+        "Schema is not available. Consider creating the Static Runtime "
+        "with StaticModule(const torch::jit::Module& m) instead.");
+    std::vector<c10::IValue> stack;
+    stack.reserve(inputs_.size());
+    if (static_module_.first_input_is_self()) {
+      stack.emplace_back(static_module_.module()._ivalue());
+    }
+    stack.insert(stack.end(), args.begin(), args.end());
+    static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
+    DCHECK_EQ(inputs_.size(), stack.size());
+    for (const auto i : c10::irange(stack.size())) {
+      Input(i) = std::move(stack[i]);
+    }
+  } else {
+    if (static_module_.first_input_is_self()) {
+      Input(0) = static_module_.module()._ivalue();
+      DCHECK_EQ(inputs_.size(), args.size() + 1);
+      for (const auto i : c10::irange(args.size())) {
+        Input(i + 1) = args[i];
+      }
+    } else {
+      DCHECK_EQ(inputs_.size(), args.size());
+      for (const auto i : c10::irange(args.size())) {
+        Input(i) = args[i];
+      }
+    }
+  }
+}
c10::IValue StaticRuntime::operator()(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs) {
@@ -701,27 +760,13 @@ c10::IValue StaticRuntime::operator()(
planner_->allocate();
}
-  if (!kwargs.empty()) {
-    // This is not ideal
-    TORCH_CHECK(
-        static_module_.schema(),
-        "Schema is not available. Consider creating the Static Runtime "
-        "with StaticModule(const torch::jit::Module& m) instead.");
-    std::vector<c10::IValue> s = args;
-    static_module_.schema()->checkAndNormalizeInputs(s, kwargs);
-    for (const auto i : c10::irange(s.size())) {
-      Input(i) = std::move(s[i]);
-    }
-  } else {
-    for (const auto i : c10::irange(args.size())) {
-      Input(i) = args[i];
-    }
-  }
+  set_inputs(args, kwargs);
// NB: before optimizing the order of execution, ensure that the
// memory optimization pass (LivenessMap) is
// aware of the new order!
for (auto& n : nodes_) {
// LOG(INFO) << "Running node: " << PrintNode(n.node());
n.run();
}
@@ -739,9 +784,7 @@
    }
    planner_->deallocate();
    // clean up owning refs of input tensors
-    for (IValue& ival : inputs_) {
-      ival = IValue();
-    }
+    clean_up_input_ivalues();
}
// no need to keep references of outputs in static runtime anymore
@@ -829,6 +872,10 @@ void StaticRuntime::benchmark(
            << "%)" << std::endl;
  }
  check_for_memory_leak();
+#ifndef NDEBUG
+  display_nodes(args, kwargs);
+#endif
}
float StaticRuntime::benchmark_model(
@@ -906,16 +953,36 @@ void display_pnode_info(const ProcessedNode& pnode) {
}
}
-void StaticRuntime::display_nodes(const std::vector<c10::IValue>& args) {
+void StaticRuntime::display_nodes(
+    const std::vector<c10::IValue>& args,
+    const std::unordered_map<std::string, c10::IValue>& kwargs) {
  c10::InferenceMode mode;
-  std::vector<IValue> stack(args);
-  for (size_t i = 0; i < stack.size(); i++) {
-    Input(i) = stack[i];
+  if (planner_) {
+    planner_->allocate();
  }
+  set_inputs(args, kwargs);
for (auto& node : nodes_) {
node.run();
display_pnode_info(node);
}
+  if (static_module_.opts().cleanup_activations) {
+    // MemoryPlanner is created after the first invocation of `run()`. This is
+    // done intentionally because MemoryPlanner uses `Tensor` sizes of the
+    // previous `run()` for memory planning of subsequent runs
+    if (!planner_) {
+      planner_ = std::make_unique<MemoryPlanner>(
+          this,
+          static_module_.values_share_same_storage(),
+          static_module_.external_values(),
+          static_module_.opts().enable_out_variant,
+          static_module_.opts().optimize_graph_output_memory);
+    }
+    planner_->deallocate();
+    // clean up owning refs of input tensors
+    clean_up_input_ivalues();
+  }
}
StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
@@ -934,18 +1001,9 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
  // setup time
  caffe2::Timer timer;
-  std::vector<IValue> stack(args);
-  if (!kwargs.empty()) {
-    // This is not ideal
-    TORCH_CHECK(
-        static_module_.schema(),
-        "Schema is not available. Consider creating the Static Runtime "
-        "with StaticModule(const torch::jit::Module& m) instead.");
-    static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
-  }
-  for (const auto i : c10::irange(stack.size())) {
-    Input(i) = stack[i];
-  }
+  set_inputs(args, kwargs);
results.setup_time = timer.MilliSeconds();
// warmup runs
@@ -957,9 +1015,9 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
  // main runs
  for (const auto k : c10::irange(main_runs)) {
    (void)k; // Suppress unused variable warning
-    for (const auto i : c10::irange(stack.size())) {
-      Input(i) = stack[i];
-    }
+    set_inputs(args, kwargs);
timer.Start();
if (planner_) {
planner_->allocate();
@@ -985,9 +1043,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
      }
      planner_->deallocate();
      // clean up owning refs of input tensors
-      for (IValue& ival : inputs_) {
-        ival = IValue();
-      }
+      clean_up_input_ivalues();
}
millis = timer.MilliSeconds();
results.memory_dealloc_time += millis;
@@ -1283,16 +1339,11 @@ ProcessedNode::ProcessedNode(
    VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
    return;
  }
-  if (!fn_ && mayRunNatively(node)) {
-    native_fn_ = getNativeOperation(node);
-    if (native_fn_) {
-      VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
-      return;
-    }
+  if (!fn_ && (native_fn_ = getNativeOperation(node))) {
+    VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
+    return;
  }
-  if (node->kind() != prim::ListConstruct &&
-      node->kind() != prim::TupleConstruct &&
-      node->kind() != prim::DictConstruct && node->kind() != prim::ListUnpack) {
+  {
const Operator& op = node->getOperator();
TORCH_CHECK(op.hasOperation());
op_ = op.getOperation(node);
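To see both paths through the new set_inputs helper from Python, a sketch using the StaticModule wrapper from the test file above (assuming its __call__ forwards positional args and kwargs the same way the benchmark helpers take them):

    m = TestModule()
    m.eval()
    sm = StaticModule(torch.jit.script(m))
    x = torch.randn(2, 2)
    out_pos = sm(x)[0]   # empty kwargs: Input(i) = args[i] directly
    out_kw = sm(x=x)[0]  # kwargs: schema()->checkAndNormalizeInputs builds the stack
    torch.testing.assert_allclose(out_pos, out_kw)

Because TestModule mutates attributes, its frozen graph keeps `self`, so first_input_is_self() is true and both paths additionally bind the module's IValue to Input(0) ahead of the user-supplied arguments.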


@@ -92,9 +92,8 @@ class TORCH_API StaticModule {
 private:
  explicit StaticModule(
-      std::pair<
-          std::shared_ptr<torch::jit::Graph>,
-          c10::optional<c10::FunctionSchema>> graph_and_schema,
+      std::pair<std::shared_ptr<torch::jit::Graph>, std::shared_ptr<Module>>
+          graph_and_module,
const StaticModuleOptions& opts);
// for <kind, idx>
@@ -116,6 +115,10 @@ class TORCH_API StaticModule {
    return *graph_;
  }
+  const Module& module() const {
+    return *module_;
+  }
const StaticModuleOptions& opts() const;
size_t num_inputs() const;
size_t num_outputs() const;
@@ -149,11 +152,17 @@ class TORCH_API StaticModule {
    return external_values_;
  }
+  bool first_input_is_self() const {
+    return first_input_is_self_;
+  }
  StaticRuntime& runtime();
 private:
  StaticModuleOptions opts_;
+  bool first_input_is_self_{false};
  std::shared_ptr<torch::jit::Graph> graph_;
+  std::shared_ptr<torch::jit::Module> module_;
c10::optional<c10::FunctionSchema> schema_;
std::unique_ptr<StaticRuntime> cached_runtime_;
@@ -188,7 +197,9 @@ class TORCH_API StaticRuntime {
      const std::vector<c10::IValue>& args,
      const std::unordered_map<std::string, c10::IValue>& kwargs);
-  void display_nodes(const std::vector<c10::IValue>& args);
+  void display_nodes(
+      const std::vector<c10::IValue>& args,
+      const std::unordered_map<std::string, c10::IValue>& kwargs);
void benchmark(
const std::vector<c10::IValue>& args,
@@ -254,6 +265,18 @@ class TORCH_API StaticRuntime {
  void check_for_memory_leak(bool output_returned = true);
 private:
+  // helper method for copying input args/kwargs into inputs_
+  void set_inputs(
+      const std::vector<c10::IValue>& args,
+      const std::unordered_map<std::string, c10::IValue>& kwargs);
+
+  // clean up owning refs of input IValues
+  void clean_up_input_ivalues() {
+    for (IValue& ival : inputs_) {
+      ival = IValue();
+    }
+  }
// Memory planning is only enabled if sm->opts().cleanup_activations is true.
// Otherwise, the memory used by activations is cached inside the static
// runtime.


@@ -189,28 +189,6 @@ std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n) {
  return nullptr;
}
-// TODO: expand to include all view producing ops, mostly in
-// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
-bool mayRunNatively(Node* n) {
-  // In alphabetical order
-  const static std::unordered_set<std::string> native_nodes{
-      "aten::flatten",
-      "aten::reshape",
-      "aten::slice",
-      "aten::transpose",
-      "aten::to",
-      "prim::ListConstruct",
-      "prim::ListUnpack",
-      "prim::TupleConstruct",
-      "prim::DictConstruct",
-      "aten::__getitem__"};
-  auto str = std::string(n->kind().toQualString());
-  if (!native_nodes.count(str)) {
-    return false;
-  }
-  return true;
-}
// Expensive check, use sparingly.
// This is needed to make sure that we only switch to out variants for the
// supported overloads, which is checked in the `Generate` step in
@@ -1302,6 +1280,24 @@ std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
        p_node->Output(0) = in0_t.clone();
      }
    };
+  } else if (n->kind() == prim::GetAttr) {
+    return [](ProcessedNode* p_node) {
+      auto module = p_node->Input(0).toObject();
+      Node* node = p_node->node();
+      const auto type = node->input()->type()->expect<ClassType>();
+      const auto& field = node->s(attr::name);
+      const auto slot = type->getAttributeSlot(field);
+      p_node->Output(0) = module->getSlot(slot);
+    };
+  } else if (n->kind() == prim::SetAttr) {
+    return [](ProcessedNode* p_node) {
+      auto module = p_node->Input(0).toObject();
+      Node* node = p_node->node();
+      const auto type = node->inputs()[0]->type()->expect<ClassType>();
+      const auto& field = node->s(attr::name);
+      const auto slot = type->getAttributeSlot(field);
+      module->setSlot(slot, p_node->Input(1));
+    };
}
return nullptr;
}
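The GetAttr/SetAttr handlers above resolve the attribute name to a slot on the ClassType and then read or write that slot on the module object in place, so mutations persist across calls. This slot semantics is observable from Python as well (a small sketch reusing SubModule2 from the test file; not part of the diff):

    ms = torch.jit.script(SubModule2())  # its forward does `self.b = 30`
    print(ms.b)            # 2: the initial value of slot "b"
    ms(torch.randn(2, 2))  # forward executes prim::SetAttr[name="b"]
    print(ms.b)            # 30: the slot was written in place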


@@ -124,7 +124,6 @@ bool isOptimizableContainerType(Node* n);
std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n);
-bool mayRunNatively(Node* n);
std::function<void(ProcessedNode*)> getNativeOperation(Node* n);
inline std::string PrintNode(const Node* node) {


@@ -366,9 +366,6 @@ void FuseInferenceOpsForSparseNN(std::shared_ptr<torch::jit::Graph>& graph) {
}
TORCH_LIBRARY_FRAGMENT(static_runtime, m) {
-  m.def("static_runtime::pure_inputs() -> Tensor", []() -> at::Tensor {
-    return at::randn({1});
-  });
m.def("static_runtime::permute_copy(Tensor self, int[] dims) -> Tensor");
m.def(
"static_runtime::reshape_copy(Tensor(a) self, int[] shape) -> Tensor(a)");
@@ -386,24 +383,10 @@ bool HasInplaceOp(std::shared_ptr<Graph>& graph, const AliasDb& alias_db) {
  return HasInplaceOp(graph->block(), alias_db);
}
-void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph) {
-  auto* fake_input =
-      graph->insert(Symbol::fromQualString("static_runtime::pure_inputs"), {});
-  fake_input->node()->moveBefore(*graph->nodes().begin());
-  std::vector<std::pair<Value*, Use>> old_inputs;
-  for (auto* input : graph->inputs()) {
-    for (const auto& use : input->uses()) {
-      old_inputs.emplace_back(std::make_pair(input, use));
-    }
-    input->replaceAllUsesWith(fake_input);
-  }
+void ReplaceWithCopy(
+    std::shared_ptr<torch::jit::Graph>& graph,
+    bool outputs_are_immutable) {
  AliasDb db(graph);
-  for (const auto& p : old_inputs) {
-    p.second.user->replaceInput(p.second.offset, p.first);
-  }
-  fake_input->node()->destroy();
const std::map<c10::Symbol, c10::Symbol> supported = {
#ifdef FBCODE_CAFFE2
@@ -474,7 +457,7 @@ void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph) {
    }
    auto* out = n->output();
-    if (db.mayContainAlias({out}, graph->outputs())) {
+    if (!outputs_are_immutable && db.mayContainAlias({out}, graph->outputs())) {
continue;
}
auto* new_node = graph->create(new_symbol, n->outputs().size());
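For intuition on the new outputs_are_immutable flag: ReplaceWithCopy swaps view-producing ops (e.g. aten::reshape) for copying variants, but when a view's result may alias a graph output, a copy is only transparent if callers never write through the returned tensors. A hedged illustration (not part of the diff):

    @torch.jit.script
    def f(a: torch.Tensor):
        # The output may alias `a`. Replacing reshape with a copy variant would
        # change what a caller observes if it mutates the returned tensor, so
        # the pass skips this op unless outputs_are_immutable is true.
        return a.reshape(-1)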


@@ -7,7 +7,11 @@ TORCH_API void FuseInferenceOpsForSparseNN(
    std::shared_ptr<torch::jit::Graph>& graph);
TORCH_API void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph);
-TORCH_API void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph);
+// If outputs_are_immutable is set to false, don't replace the view ops that
+// produce aliases of graph outputs with the copy version.
+TORCH_API void ReplaceWithCopy(
+    std::shared_ptr<torch::jit::Graph>& graph,
+    bool outputs_are_immutable = true);
TORCH_API bool HasInplaceOp(
std::shared_ptr<Graph>& graph,