[Static Runtime] Support prim::GetAttr/SetAttr (#61505)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61505

The handling of `self` in static runtime was previously incorrect. This diff fixes that issue; `self` is essential to `prim::GetAttr`/`prim::SetAttr`, since most of the time we are getting and setting attributes on `self`, the TorchScript module.

Reviewed By: ajyu

Differential Revision: D29350173

fbshipit-source-id: 6e62add4cda517ef8cd6c315d4cb0595e7d531fb
Hao Lu authored on 2021-07-10 14:04:48 -07:00; committed by Facebook GitHub Bot
parent f291b1899f
commit ccd0977060
7 changed files with 262 additions and 128 deletions
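For context before the diffs: a minimal sketch (not part of this commit) of the attribute-mutating pattern that static runtime can now handle. `AttrModule` is a hypothetical stand-in modeled on the `TestModule` added in the test file below; the point is that an attribute write in `forward` survives freezing, so the frozen graph keeps `prim::GetAttr`/`prim::SetAttr` nodes that operate on `%self`.

import torch
from torch import nn

class AttrModule(nn.Module):
    # Hypothetical stand-in for the TestModule added in test_static_runtime below.
    def __init__(self):
        super().__init__()
        self.a = 3
        self.b = 4

    def forward(self, x):
        # The attribute write below survives freezing, so the frozen graph
        # keeps prim::SetAttr (and the prim::GetAttr reads) on %self.
        self.b = 20
        return x + self.a + self.b

scripted = torch.jit.script(AttrModule().eval())
frozen = torch.jit.freeze(scripted)
print(frozen.graph)  # expect prim::SetAttr[name="b"](%self, ...) in the IR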

File: test/test_static_runtime.py

@@ -1,11 +1,11 @@
+import unittest
+from typing import Dict, Optional
 import numpy as np
 import torch
-import unittest
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
-from typing import Dict, Optional


 class StaticModule:
     def __init__(self, scripted):
@@ -30,7 +30,9 @@ class StaticModule:
     )


-def linear_shim(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def linear_shim(
+    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
     output = input.matmul(weight.t())
     if bias is not None:
         output += bias
@@ -107,7 +109,8 @@ def trivial_graph(a, b, c):
     s = torch.tensor([[3, 3], [3, 3]])
     return a + b * c + s

-def loop_graph(a, b, iters : int):
+
+def loop_graph(a, b, iters: int):
     c = a + b * 2
     for i in range(iters):
         c = c + b
@@ -115,14 +118,50 @@ def loop_graph(a, b, iters: int):
         c -= a
     return c

-def output_graph(a, b, c, iters : int):
+
+def output_graph(a, b, c, iters: int):
     s = torch.tensor([[3, 3], [3, 3]])
     k = a + b * c + s
-    d : Dict[int, torch.Tensor] = {}
+    d: Dict[int, torch.Tensor] = {}
     for i in range(iters):
         d[i] = k + i
     return d

+
+class SubModule(nn.Module):
+    def __init__(self):
+        super(SubModule, self).__init__()
+        self.a = 11
+        self.b = 2
+
+    def forward(self, x):
+        return self.a + self.b + x
+
+
+class SubModule2(nn.Module):
+    def __init__(self):
+        super(SubModule2, self).__init__()
+        self.a = 12
+        self.b = 2
+
+    def forward(self, x):
+        self.b = 30
+        return self.a + self.b + x
+
+
+class TestModule(nn.Module):
+    def __init__(self):
+        super(TestModule, self).__init__()
+        self.sub1 = SubModule()
+        self.sub2 = SubModule2()
+        self.a = 3
+        self.b = 4
+
+    def forward(self, x):
+        self.b = 20
+        return self.sub1(x) + self.a + self.b + self.sub2(x)
+
+
 class TestStaticModule(TestCase):
     def test_multihead_attention_layer(self):
         HID_DIM = 256
@@ -220,6 +259,46 @@ class TestStaticModule(TestCase):
         o_test = tg_a(s)[0]
         torch.testing.assert_allclose(o_ref, o_test)

+    def test_attr(self):
+        """
+        TorchScript IR of TestModule() after freezing:
+        graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
+              %x.1 : Tensor):
+            %18 : int = prim::Constant[value=30]()
+            %30 : int = prim::Constant[value=13]()
+            %3 : int = prim::Constant[value=20]()
+            %2 : int = prim::Constant[value=1]()
+            %self.sub2.a : int = prim::Constant[value=12]()
+            %self.a : int = prim::Constant[value=3]()
+             = prim::SetAttr[name="b"](%self, %3)
+            %17 : Tensor = aten::add(%x.1, %30, %2)
+            %7 : Tensor = aten::add(%17, %self.a, %2)
+            %b.1 : int = prim::GetAttr[name="b"](%self)
+            %9 : Tensor = aten::add(%7, %b.1, %2)
+            %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
+             = prim::SetAttr[name="b"](%sub2, %18)
+            %b : int = prim::GetAttr[name="b"](%sub2)
+            %22 : int = aten::add(%self.sub2.a, %b)
+            %23 : Tensor = aten::add(%x.1, %22, %2)
+            %12 : Tensor = aten::add(%9, %23, %2)
+            return (%12)
+        """
+        # test prim::SetAttr and prim::GetAttr impl in Static Runtime
+        m = TestModule()
+        m.eval()
+        input = torch.randn(2, 2)
+        output_s = m.forward(input)
+
+        ms = torch.jit.script(m)
+        sm = StaticModule(ms)
+        output_sm = sm(input)[0]
+        torch.testing.assert_allclose(output_s, output_sm)
+        sm.benchmark([input], {}, 2, 2)
+        sm.benchmark_individual_ops([input], {}, 2, 2)
+        sm.benchmark([], {"x": input}, 2, 2)
+        sm.benchmark_individual_ops([], {"x": input}, 2, 2)
+
     @unittest.skip("Temporarily disabled")
     def test_fusion_trivial_graph(self):
         s = torch.full((2, 2), 2)
@ -281,6 +360,5 @@ class TestStaticModule(TestCase):
torch.testing.assert_allclose(o_ref[i], o_test[i]) torch.testing.assert_allclose(o_ref[i], o_test[i])
if __name__ == "__main__": if __name__ == "__main__":
run_tests() run_tests()

File: torch/csrc/jit/runtime/static/impl.cpp

@@ -45,20 +45,29 @@ void OptimizeGraph(
   ConstantPropagation(graph);
 }

-void CheckGraphEligibility(const std::shared_ptr<torch::jit::Graph>& graph) {
-  for (auto n : graph->nodes()) {
-    if (n->kind() == c10::Symbol::fromQualString("prim::GetAttr")) {
-      throw std::runtime_error("Cannot accelerate unfrozen graphs");
+bool CheckGraphEligibility(const std::shared_ptr<torch::jit::Graph>& graph) {
+  // check for sub-blocks
+  bool can_support = true;
+  for (auto* node : graph->block()->nodes()) {
+    for (Block* sub_block : node->blocks()) {
+      VLOG(1) << "Found nested sub-blocks in graph at node: "
+              << PrintNode(node);
+      can_support = false;
     }
   }
+  return can_support;
 }

 // remove unused input 0 from graph
-void RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
+bool RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
   if (graph->inputs().at(0)->type()->is_module()) {
-    TORCH_CHECK(!graph->inputs().at(0)->hasUses());
+    if (graph->inputs().at(0)->hasUses()) {
+      return false;
+    }
     graph->eraseInput(0);
   }
+  return true;
 }

 // remove "self" from function schema
@@ -443,12 +452,12 @@ GenerateSameStorageValues(
 void PrepareGraphForStaticModule(
     std::shared_ptr<torch::jit::Graph> graph,
     const StaticModuleOptions& opts) {
-  CheckGraphEligibility(graph);
+  // TODO: call CheckGraphEligibility before trying to enable static runtime
+  TORCH_CHECK(CheckGraphEligibility(graph));
   OptimizeGraph(graph, opts);
-  RemoveSelfFromGraphInput(graph);
 }

-std::pair<std::shared_ptr<Graph>, c10::optional<c10::FunctionSchema>>
+std::pair<std::shared_ptr<Graph>, std::shared_ptr<Module>>
 PrepareForStaticModule(
     const torch::jit::Module& m,
     const StaticModuleOptions& opts) {
@@ -461,22 +470,23 @@ PrepareForStaticModule(
   auto module = m.copy();
   module.eval();

-  module = freeze_module(module);
-  Method method = module.get_method("forward");
-  auto graph = module.get_method("forward").graph();
+  auto module_ptr = std::make_shared<Module>(freeze_module(module));
+  Method method = module_ptr->get_method("forward");
+  auto graph = module_ptr->get_method("forward").graph();
+  // graph->dump();

   PrepareGraphForStaticModule(graph, opts);

-  c10::FunctionSchema s = RemoveSelfFromSchema(method.function().getSchema());
-  return std::make_pair(graph, s);
+  return std::make_pair(graph, module_ptr);
 }

-std::pair<std::shared_ptr<Graph>, c10::optional<c10::FunctionSchema>>
+std::pair<std::shared_ptr<Graph>, std::shared_ptr<Module>>
 PrepareForStaticModule(
     std::shared_ptr<torch::jit::Graph> graph,
     const StaticModuleOptions& opts) {
   PrepareGraphForStaticModule(graph, opts);
-  return std::make_pair(graph, c10::nullopt);
+  return std::make_pair(graph, nullptr);
 }

 } // namespace
@@ -492,13 +502,12 @@ StaticModule::StaticModule(
     : StaticModule(PrepareForStaticModule(m, opts), opts) {}

 StaticModule::StaticModule(
-    std::pair<
-        std::shared_ptr<torch::jit::Graph>,
-        c10::optional<c10::FunctionSchema>> graph_and_schema,
+    std::pair<std::shared_ptr<torch::jit::Graph>, std::shared_ptr<Module>>
+        graph_and_module,
     const StaticModuleOptions& opts)
     : opts_(opts),
-      graph_(std::move(graph_and_schema.first)),
-      schema_(std::move(graph_and_schema.second)) {
+      graph_(std::move(graph_and_module.first)),
+      module_(std::move(graph_and_module.second)) {
   // check opt flags
   if (opts.optimize_graph_output_memory) {
     TORCH_CHECK(
@@ -511,6 +520,18 @@ StaticModule::StaticModule(
         "When optimize_memory is true, enable_out_variant must be set to true");
   }

+  // handle schema
+  if (module_) {
+    Method method = module_->get_method("forward");
+    if (RemoveSelfFromGraphInput(graph_)) {
+      schema_ = RemoveSelfFromSchema(method.function().getSchema());
+    } else {
+      first_input_is_self_ = true;
+      schema_ = method.function().getSchema();
+    }
+  }
+
   // map Value* to IValue (from inputs or prim::Constant) or null
   std::unordered_map<Value*, IValue*> value_to_ivalue;
   // map Value* to its SSA definition IR
@@ -620,6 +641,7 @@ StaticRuntime::StaticRuntime(const StaticModule& sm) : static_module_(sm) {
   // NB: create unchanging std::vector<IValue>s we can reference
   inputs_.resize(sm.num_inputs());
   nodes_.resize(sm.nodes().size());
+
   for (const auto idx : c10::irange(sm.nodes().size())) {
     const auto& n_ref = sm.nodes()[idx];
     nodes_[idx] = n_ref; // copy the node
@@ -688,6 +710,43 @@ std::vector<at::Tensor> StaticRuntime::operator()(
   return out;
 }

+void StaticRuntime::set_inputs(
+    const std::vector<c10::IValue>& args,
+    const std::unordered_map<std::string, c10::IValue>& kwargs) {
+  if (!kwargs.empty()) {
+    // This is not ideal
+    TORCH_CHECK(
+        static_module_.schema(),
+        "Schema is not available. Consider creating the Static Runtime "
+        "with StaticModule(const torch::jit::Module& m) instead.");
+    std::vector<c10::IValue> stack;
+    stack.reserve(inputs_.size());
+    if (static_module_.first_input_is_self()) {
+      stack.emplace_back(static_module_.module()._ivalue());
+    }
+    stack.insert(stack.end(), args.begin(), args.end());
+
+    static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
+    DCHECK_EQ(inputs_.size(), stack.size());
+    for (const auto i : c10::irange(stack.size())) {
+      Input(i) = std::move(stack[i]);
+    }
+  } else {
+    if (static_module_.first_input_is_self()) {
+      Input(0) = static_module_.module()._ivalue();
+      DCHECK_EQ(inputs_.size(), args.size() + 1);
+      for (const auto i : c10::irange(args.size())) {
+        Input(i + 1) = args[i];
+      }
+    } else {
+      DCHECK_EQ(inputs_.size(), args.size());
+      for (const auto i : c10::irange(args.size())) {
+        Input(i) = args[i];
+      }
+    }
+  }
+}
+
 c10::IValue StaticRuntime::operator()(
     const std::vector<c10::IValue>& args,
     const std::unordered_map<std::string, c10::IValue>& kwargs) {
@@ -701,27 +760,13 @@ c10::IValue StaticRuntime::operator()(
     planner_->allocate();
   }

-  if (!kwargs.empty()) {
-    // This is not ideal
-    TORCH_CHECK(
-        static_module_.schema(),
-        "Schema is not available. Consider creating the Static Runtime "
-        "with StaticModule(const torch::jit::Module& m) instead.");
-    std::vector<c10::IValue> s = args;
-    static_module_.schema()->checkAndNormalizeInputs(s, kwargs);
-    for (const auto i : c10::irange(s.size())) {
-      Input(i) = std::move(s[i]);
-    }
-  } else {
-    for (const auto i : c10::irange(args.size())) {
-      Input(i) = args[i];
-    }
-  }
+  set_inputs(args, kwargs);

   // NB: before optimizing the order of execution, ensure that the
   // memory optimization pass (LivenessMap) is
   // aware of the new order!
   for (auto& n : nodes_) {
+    // LOG(INFO) << "Running node: " << PrintNode(n.node());
     n.run();
   }
@@ -739,9 +784,7 @@ c10::IValue StaticRuntime::operator()(
     }
     planner_->deallocate();
     // clean up owning refs of input tensors
-    for (IValue& ival : inputs_) {
-      ival = IValue();
-    }
+    clean_up_input_ivalues();
   }

   // no need to keep references of outputs in static runtime anymore
@@ -829,6 +872,10 @@ void StaticRuntime::benchmark(
             << "%)" << std::endl;
   }
   check_for_memory_leak();
+
+#ifndef NDEBUG
+  display_nodes(args, kwargs);
+#endif
 }

 float StaticRuntime::benchmark_model(
@@ -906,16 +953,36 @@ void display_pnode_info(const ProcessedNode& pnode) {
   }
 }

-void StaticRuntime::display_nodes(const std::vector<c10::IValue>& args) {
+void StaticRuntime::display_nodes(
+    const std::vector<c10::IValue>& args,
+    const std::unordered_map<std::string, c10::IValue>& kwargs) {
   c10::InferenceMode mode;
-  std::vector<IValue> stack(args);
-  for (size_t i = 0; i < stack.size(); i++) {
-    Input(i) = stack[i];
+  if (planner_) {
+    planner_->allocate();
   }
+  set_inputs(args, kwargs);
+
   for (auto& node : nodes_) {
     node.run();
     display_pnode_info(node);
   }
+
+  if (static_module_.opts().cleanup_activations) {
+    // MemoryPlanner is created after the first invocation of `run()`. This is
+    // done intentionally because MemoryPlanner uses `Tensor` sizes of the
+    // previous `run()` for memory planning of subsequent runs
+    if (!planner_) {
+      planner_ = std::make_unique<MemoryPlanner>(
+          this,
+          static_module_.values_share_same_storage(),
+          static_module_.external_values(),
+          static_module_.opts().enable_out_variant,
+          static_module_.opts().optimize_graph_output_memory);
+    }
+    planner_->deallocate();
+    // clean up owning refs of input tensors
+    clean_up_input_ivalues();
+  }
 }

 StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
@@ -934,18 +1001,9 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
   // setup time
   caffe2::Timer timer;
-  std::vector<IValue> stack(args);
-  if (!kwargs.empty()) {
-    // This is not ideal
-    TORCH_CHECK(
-        static_module_.schema(),
-        "Schema is not available. Consider creating the Static Runtime "
-        "with StaticModule(const torch::jit::Module& m) instead.");
-    static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
-  }
-  for (const auto i : c10::irange(stack.size())) {
-    Input(i) = stack[i];
-  }
+
+  set_inputs(args, kwargs);
+
   results.setup_time = timer.MilliSeconds();

   // warmup runs
@@ -957,9 +1015,9 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
   // main runs
   for (const auto k : c10::irange(main_runs)) {
     (void)k; // Suppress unused variable warning
-    for (const auto i : c10::irange(stack.size())) {
-      Input(i) = stack[i];
-    }
+
+    set_inputs(args, kwargs);
+
     timer.Start();
     if (planner_) {
       planner_->allocate();
@@ -985,9 +1043,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
     }
     planner_->deallocate();
     // clean up owning refs of input tensors
-    for (IValue& ival : inputs_) {
-      ival = IValue();
-    }
+    clean_up_input_ivalues();
   }
   millis = timer.MilliSeconds();
   results.memory_dealloc_time += millis;
@@ -1283,16 +1339,11 @@ ProcessedNode::ProcessedNode(
       VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
       return;
     }
-    if (!fn_ && mayRunNatively(node)) {
-      native_fn_ = getNativeOperation(node);
-      if (native_fn_) {
-        VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
-        return;
-      }
+    if (!fn_ && (native_fn_ = getNativeOperation(node))) {
+      VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
+      return;
     }
-    if (node->kind() != prim::ListConstruct &&
-        node->kind() != prim::TupleConstruct &&
-        node->kind() != prim::DictConstruct && node->kind() != prim::ListUnpack) {
+    {
       const Operator& op = node->getOperator();
       TORCH_CHECK(op.hasOperation());
       op_ = op.getOperation(node);

File: torch/csrc/jit/runtime/static/impl.h

@@ -92,9 +92,8 @@ class TORCH_API StaticModule {
  private:
   explicit StaticModule(
-      std::pair<
-          std::shared_ptr<torch::jit::Graph>,
-          c10::optional<c10::FunctionSchema>> graph_and_schema,
+      std::pair<std::shared_ptr<torch::jit::Graph>, std::shared_ptr<Module>>
+          graph_and_module,
       const StaticModuleOptions& opts);

   // for <kind, idx>
@@ -116,6 +115,10 @@ class TORCH_API StaticModule {
     return *graph_;
   }

+  const Module& module() const {
+    return *module_;
+  }
+
   const StaticModuleOptions& opts() const;
   size_t num_inputs() const;
   size_t num_outputs() const;
@@ -149,11 +152,17 @@ class TORCH_API StaticModule {
     return external_values_;
   }

+  bool first_input_is_self() const {
+    return first_input_is_self_;
+  }
+
   StaticRuntime& runtime();

  private:
   StaticModuleOptions opts_;
+  bool first_input_is_self_{false};
   std::shared_ptr<torch::jit::Graph> graph_;
+  std::shared_ptr<torch::jit::Module> module_;
   c10::optional<c10::FunctionSchema> schema_;
   std::unique_ptr<StaticRuntime> cached_runtime_;
@@ -188,7 +197,9 @@ class TORCH_API StaticRuntime {
       const std::vector<c10::IValue>& args,
       const std::unordered_map<std::string, c10::IValue>& kwargs);

-  void display_nodes(const std::vector<c10::IValue>& args);
+  void display_nodes(
+      const std::vector<c10::IValue>& args,
+      const std::unordered_map<std::string, c10::IValue>& kwargs);

   void benchmark(
       const std::vector<c10::IValue>& args,
@@ -254,6 +265,18 @@ class TORCH_API StaticRuntime {
   void check_for_memory_leak(bool output_returned = true);

  private:
+  // helper method for copying input args/kwargs into inputs_
+  void set_inputs(
+      const std::vector<c10::IValue>& args,
+      const std::unordered_map<std::string, c10::IValue>& kwargs);
+
+  // clean up owning refs of input IValues
+  void clean_up_input_ivalues() {
+    for (IValue& ival : inputs_) {
+      ival = IValue();
+    }
+  }
+
   // Memory planning is only enabled if sm->opts().cleanup_activations is true.
   // Otherwise, the memory used by activations is cached inside the static
   // runtime.

File: torch/csrc/jit/runtime/static/ops.cpp

@@ -189,28 +189,6 @@ std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n) {
   return nullptr;
 }

-// TODO: expand to include all view producing ops, mostly in
-// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
-bool mayRunNatively(Node* n) {
-  // In alphabetical order
-  const static std::unordered_set<std::string> native_nodes{
-      "aten::flatten",
-      "aten::reshape",
-      "aten::slice",
-      "aten::transpose",
-      "aten::to",
-      "prim::ListConstruct",
-      "prim::ListUnpack",
-      "prim::TupleConstruct",
-      "prim::DictConstruct",
-      "aten::__getitem__"};
-  auto str = std::string(n->kind().toQualString());
-  if (!native_nodes.count(str)) {
-    return false;
-  }
-  return true;
-}
-
 // Expensive check, use sparingly.
 // This is needed to make sure that we only switch to out variants for the
 // supported overloads, which is checked in the `Generate` step in
@@ -1302,6 +1280,24 @@ std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
         p_node->Output(0) = in0_t.clone();
       }
     };
+  } else if (n->kind() == prim::GetAttr) {
+    return [](ProcessedNode* p_node) {
+      auto module = p_node->Input(0).toObject();
+      Node* node = p_node->node();
+      const auto type = node->input()->type()->expect<ClassType>();
+      const auto& field = node->s(attr::name);
+      const auto slot = type->getAttributeSlot(field);
+      p_node->Output(0) = module->getSlot(slot);
+    };
+  } else if (n->kind() == prim::SetAttr) {
+    return [](ProcessedNode* p_node) {
+      auto module = p_node->Input(0).toObject();
+      Node* node = p_node->node();
+      const auto type = node->inputs()[0]->type()->expect<ClassType>();
+      const auto& field = node->s(attr::name);
+      const auto slot = type->getAttributeSlot(field);
+      module->setSlot(slot, p_node->Input(1));
+    };
   }
   return nullptr;
 }

File: torch/csrc/jit/runtime/static/ops.h

@@ -124,7 +124,6 @@ bool isOptimizableContainerType(Node* n);

 std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n);

-bool mayRunNatively(Node* n);
 std::function<void(ProcessedNode*)> getNativeOperation(Node* n);

 inline std::string PrintNode(const Node* node) {

File: torch/csrc/jit/runtime/static/passes.cpp

@@ -366,9 +366,6 @@ void FuseInferenceOpsForSparseNN(std::shared_ptr<torch::jit::Graph>& graph) {
 }

 TORCH_LIBRARY_FRAGMENT(static_runtime, m) {
-  m.def("static_runtime::pure_inputs() -> Tensor", []() -> at::Tensor {
-    return at::randn({1});
-  });
   m.def("static_runtime::permute_copy(Tensor self, int[] dims) -> Tensor");
   m.def(
       "static_runtime::reshape_copy(Tensor(a) self, int[] shape) -> Tensor(a)");
@@ -386,24 +383,10 @@ bool HasInplaceOp(std::shared_ptr<Graph>& graph, const AliasDb& alias_db) {
   return HasInplaceOp(graph->block(), alias_db);
 }

-void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph) {
-  auto* fake_input =
-      graph->insert(Symbol::fromQualString("static_runtime::pure_inputs"), {});
-  fake_input->node()->moveBefore(*graph->nodes().begin());
-  std::vector<std::pair<Value*, Use>> old_inputs;
-  for (auto* input : graph->inputs()) {
-    for (const auto& use : input->uses()) {
-      old_inputs.emplace_back(std::make_pair(input, use));
-    }
-    input->replaceAllUsesWith(fake_input);
-  }
+void ReplaceWithCopy(
+    std::shared_ptr<torch::jit::Graph>& graph,
+    bool outputs_are_immutable) {
   AliasDb db(graph);
-  for (const auto& p : old_inputs) {
-    p.second.user->replaceInput(p.second.offset, p.first);
-  }
-  fake_input->node()->destroy();

   const std::map<c10::Symbol, c10::Symbol> supported = {
 #ifdef FBCODE_CAFFE2
@@ -474,7 +457,7 @@ void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph) {
     }

     auto* out = n->output();
-    if (db.mayContainAlias({out}, graph->outputs())) {
+    if (!outputs_are_immutable && db.mayContainAlias({out}, graph->outputs())) {
       continue;
     }
     auto* new_node = graph->create(new_symbol, n->outputs().size());

File: torch/csrc/jit/runtime/static/passes.h

@@ -7,7 +7,11 @@ TORCH_API void FuseInferenceOpsForSparseNN(
     std::shared_ptr<torch::jit::Graph>& graph);

 TORCH_API void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph);

-TORCH_API void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph);
+// If outputs_are_immutable is set to false, don't replace the view ops that
+// produce aliases of graph outputs with the copy version.
+TORCH_API void ReplaceWithCopy(
+    std::shared_ptr<torch::jit::Graph>& graph,
+    bool outputs_are_immutable = true);

 TORCH_API bool HasInplaceOp(
     std::shared_ptr<Graph>& graph,