mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Static Runtime] Support prim::GetAttr/SetAttr (#61505)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61505 The handling of `self` in static runtime was previously incorrect. This diff fixed that issue, since self is essential to prim::GetAttr/SetAttr. After all, most of the time we're getting and setting attributes from self, the torch script module. Reviewed By: ajyu Differential Revision: D29350173 fbshipit-source-id: 6e62add4cda517ef8cd6c315d4cb0595e7d531fb
This commit is contained in:
parent
f291b1899f
commit
ccd0977060
|
|
@ -1,11 +1,11 @@
|
|||
import unittest
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
from torch import nn
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
class StaticModule:
|
||||
def __init__(self, scripted):
|
||||
|
|
@ -30,7 +30,9 @@ class StaticModule:
|
|||
)
|
||||
|
||||
|
||||
def linear_shim(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def linear_shim(
|
||||
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
output = input.matmul(weight.t())
|
||||
if bias is not None:
|
||||
output += bias
|
||||
|
|
@ -107,6 +109,7 @@ def trivial_graph(a, b, c):
|
|||
s = torch.tensor([[3, 3], [3, 3]])
|
||||
return a + b * c + s
|
||||
|
||||
|
||||
def loop_graph(a, b, iters: int):
|
||||
c = a + b * 2
|
||||
for i in range(iters):
|
||||
|
|
@ -115,6 +118,7 @@ def loop_graph(a, b, iters : int):
|
|||
c -= a
|
||||
return c
|
||||
|
||||
|
||||
def output_graph(a, b, c, iters: int):
|
||||
s = torch.tensor([[3, 3], [3, 3]])
|
||||
k = a + b * c + s
|
||||
|
|
@ -123,6 +127,41 @@ def output_graph(a, b, c, iters : int):
|
|||
d[i] = k + i
|
||||
return d
|
||||
|
||||
|
||||
class SubModule(nn.Module):
    """Leaf module whose integer attributes are only ever read.

    ``forward`` never writes ``a`` or ``b``, so the result is always
    ``x + 13``. The frozen IR quoted in ``test_attr`` shows the two
    attributes folded into a single constant 13.
    """

    def __init__(self):
        # Zero-argument super() is the Python 3 idiom and is equivalent to
        # the explicit super(SubModule, self) form.
        super().__init__()
        self.a = 11
        self.b = 2

    def forward(self, x):
        # Pure attribute reads; no module state is modified here.
        return self.a + self.b + x
|
||||
|
||||
|
||||
class SubModule2(nn.Module):
    """Leaf module whose ``forward`` overwrites ``self.b`` before reading it.

    The attribute write means ``b`` cannot be treated as a constant: the
    frozen IR quoted in ``test_attr`` keeps a prim::SetAttr / prim::GetAttr
    pair for ``b`` on ``sub2``, while ``a`` is folded to the constant 12.
    """

    def __init__(self):
        # Zero-argument super() is the Python 3 idiom and is equivalent to
        # the explicit super(SubModule2, self) form.
        super().__init__()
        self.a = 12
        self.b = 2

    def forward(self, x):
        # Mutates module state on every call; the initial value 2 is never
        # observed by forward, so the result is always x + 42.
        self.b = 30
        return self.a + self.b + x
|
||||
|
||||
|
||||
class TestModule(nn.Module):
    """Top-level module mixing read-only and mutated attributes.

    It owns one submodule that only reads its attributes (``sub1``) and one
    that writes an attribute in ``forward`` (``sub2``), and it also writes its
    own ``b``. With the values below, ``forward`` evaluates to
    ``(x + 13) + 3 + 20 + (x + 42)``.
    """

    def __init__(self):
        # Zero-argument super() is the Python 3 idiom and is equivalent to
        # the explicit super(TestModule, self) form.
        super().__init__()
        self.sub1 = SubModule()
        self.sub2 = SubModule2()
        self.a = 3
        self.b = 4

    def forward(self, x):
        # Attribute write on self: the initial value 4 is replaced before
        # it is read in the expression below.
        self.b = 20
        return self.sub1(x) + self.a + self.b + self.sub2(x)
|
||||
|
||||
|
||||
class TestStaticModule(TestCase):
|
||||
def test_multihead_attention_layer(self):
|
||||
HID_DIM = 256
|
||||
|
|
@ -220,6 +259,46 @@ class TestStaticModule(TestCase):
|
|||
o_test = tg_a(s)[0]
|
||||
torch.testing.assert_allclose(o_ref, o_test)
|
||||
|
||||
def test_attr(self):
    """Check prim::SetAttr / prim::GetAttr handling in Static Runtime.

    TorchScript IR of TestModule() after freezing:
    graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
          %x.1 : Tensor):
        %18 : int = prim::Constant[value=30]()
        %30 : int = prim::Constant[value=13]()
        %3 : int = prim::Constant[value=20]()
        %2 : int = prim::Constant[value=1]()
        %self.sub2.a : int = prim::Constant[value=12]()
        %self.a : int = prim::Constant[value=3]()
        = prim::SetAttr[name="b"](%self, %3)
        %17 : Tensor = aten::add(%x.1, %30, %2)
        %7 : Tensor = aten::add(%17, %self.a, %2)
        %b.1 : int = prim::GetAttr[name="b"](%self)
        %9 : Tensor = aten::add(%7, %b.1, %2)
        %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
        = prim::SetAttr[name="b"](%sub2, %18)
        %b : int = prim::GetAttr[name="b"](%sub2)
        %22 : int = aten::add(%self.sub2.a, %b)
        %23 : Tensor = aten::add(%x.1, %22, %2)
        %12 : Tensor = aten::add(%9, %23, %2)
        return (%12)
    """
    # test prim::SetAttr and prim::GetAttr impl in Static Runtime
    eager_module = TestModule()
    eager_module.eval()

    # Named `x` rather than `input` so the builtin is not shadowed.
    x = torch.randn(2, 2)
    # Reference result from eager mode; the module is called directly
    # instead of going through .forward().
    expected = eager_module(x)

    scripted = torch.jit.script(eager_module)
    static_module = StaticModule(scripted)
    actual = static_module(x)[0]
    torch.testing.assert_allclose(expected, actual)

    # Run both benchmark entry points with positional inputs and again with
    # the same tensor passed as the keyword argument "x".
    static_module.benchmark([x], {}, 2, 2)
    static_module.benchmark_individual_ops([x], {}, 2, 2)
    static_module.benchmark([], {"x": x}, 2, 2)
    static_module.benchmark_individual_ops([], {"x": x}, 2, 2)
|
||||
|
||||
@unittest.skip("Temporarily disabled")
|
||||
def test_fusion_trivial_graph(self):
|
||||
s = torch.full((2, 2), 2)
|
||||
|
|
@ -281,6 +360,5 @@ class TestStaticModule(TestCase):
|
|||
torch.testing.assert_allclose(o_ref[i], o_test[i])
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -45,20 +45,29 @@ void OptimizeGraph(
|
|||
ConstantPropagation(graph);
|
||||
}
|
||||
|
||||
void CheckGraphEligibility(const std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
for (auto n : graph->nodes()) {
|
||||
if (n->kind() == c10::Symbol::fromQualString("prim::GetAttr")) {
|
||||
throw std::runtime_error("Cannot accelerate unfrozen graphs");
|
||||
}
|
||||
bool CheckGraphEligibility(const std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
// check for sub-blocks
|
||||
bool can_support = true;
|
||||
for (auto* node : graph->block()->nodes()) {
|
||||
for (Block* sub_block : node->blocks()) {
|
||||
VLOG(1) << "Found nested sub-blocks in graph at node: "
|
||||
<< PrintNode(node);
|
||||
can_support = false;
|
||||
}
|
||||
}
|
||||
|
||||
return can_support;
|
||||
}
|
||||
|
||||
// remove unused input 0 from graph
|
||||
void RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
bool RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
if (graph->inputs().at(0)->type()->is_module()) {
|
||||
TORCH_CHECK(!graph->inputs().at(0)->hasUses());
|
||||
if (graph->inputs().at(0)->hasUses()) {
|
||||
return false;
|
||||
}
|
||||
graph->eraseInput(0);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// remove "self" from function schema
|
||||
|
|
@ -443,12 +452,12 @@ GenerateSameStorageValues(
|
|||
void PrepareGraphForStaticModule(
|
||||
std::shared_ptr<torch::jit::Graph> graph,
|
||||
const StaticModuleOptions& opts) {
|
||||
CheckGraphEligibility(graph);
|
||||
// TODO: call CheckGraphEligibility before trying to enable static runtime
|
||||
TORCH_CHECK(CheckGraphEligibility(graph));
|
||||
OptimizeGraph(graph, opts);
|
||||
RemoveSelfFromGraphInput(graph);
|
||||
}
|
||||
|
||||
std::pair<std::shared_ptr<Graph>, c10::optional<c10::FunctionSchema>>
|
||||
std::pair<std::shared_ptr<Graph>, std::shared_ptr<Module>>
|
||||
PrepareForStaticModule(
|
||||
const torch::jit::Module& m,
|
||||
const StaticModuleOptions& opts) {
|
||||
|
|
@ -461,22 +470,23 @@ PrepareForStaticModule(
|
|||
auto module = m.copy();
|
||||
module.eval();
|
||||
|
||||
module = freeze_module(module);
|
||||
auto module_ptr = std::make_shared<Module>(freeze_module(module));
|
||||
|
||||
Method method = module.get_method("forward");
|
||||
auto graph = module.get_method("forward").graph();
|
||||
Method method = module_ptr->get_method("forward");
|
||||
auto graph = module_ptr->get_method("forward").graph();
|
||||
|
||||
// graph->dump();
|
||||
PrepareGraphForStaticModule(graph, opts);
|
||||
|
||||
c10::FunctionSchema s = RemoveSelfFromSchema(method.function().getSchema());
|
||||
return std::make_pair(graph, s);
|
||||
return std::make_pair(graph, module_ptr);
|
||||
}
|
||||
|
||||
std::pair<std::shared_ptr<Graph>, c10::optional<c10::FunctionSchema>>
|
||||
std::pair<std::shared_ptr<Graph>, std::shared_ptr<Module>>
|
||||
PrepareForStaticModule(
|
||||
std::shared_ptr<torch::jit::Graph> graph,
|
||||
const StaticModuleOptions& opts) {
|
||||
PrepareGraphForStaticModule(graph, opts);
|
||||
return std::make_pair(graph, c10::nullopt);
|
||||
return std::make_pair(graph, nullptr);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
@ -492,13 +502,12 @@ StaticModule::StaticModule(
|
|||
: StaticModule(PrepareForStaticModule(m, opts), opts) {}
|
||||
|
||||
StaticModule::StaticModule(
|
||||
std::pair<
|
||||
std::shared_ptr<torch::jit::Graph>,
|
||||
c10::optional<c10::FunctionSchema>> graph_and_schema,
|
||||
std::pair<std::shared_ptr<torch::jit::Graph>, std::shared_ptr<Module>>
|
||||
graph_and_module,
|
||||
const StaticModuleOptions& opts)
|
||||
: opts_(opts),
|
||||
graph_(std::move(graph_and_schema.first)),
|
||||
schema_(std::move(graph_and_schema.second)) {
|
||||
graph_(std::move(graph_and_module.first)),
|
||||
module_(std::move(graph_and_module.second)) {
|
||||
// check opt flags
|
||||
if (opts.optimize_graph_output_memory) {
|
||||
TORCH_CHECK(
|
||||
|
|
@ -511,6 +520,18 @@ StaticModule::StaticModule(
|
|||
"When optimize_memory is true, enable_out_variant must be set to true");
|
||||
}
|
||||
|
||||
// handle schema
|
||||
if (module_) {
|
||||
Method method = module_->get_method("forward");
|
||||
schema_ = method.function().getSchema();
|
||||
if (RemoveSelfFromGraphInput(graph_)) {
|
||||
schema_ = RemoveSelfFromSchema(method.function().getSchema());
|
||||
} else {
|
||||
first_input_is_self_ = true;
|
||||
schema_ = method.function().getSchema();
|
||||
}
|
||||
}
|
||||
|
||||
// map Value* to IValue (from inputs or prim::Constant) or null
|
||||
std::unordered_map<Value*, IValue*> value_to_ivalue;
|
||||
// map Value* to its SSA definition IR
|
||||
|
|
@ -620,6 +641,7 @@ StaticRuntime::StaticRuntime(const StaticModule& sm) : static_module_(sm) {
|
|||
// NB: create unchanging std::vector<IValue>s we can reference
|
||||
inputs_.resize(sm.num_inputs());
|
||||
nodes_.resize(sm.nodes().size());
|
||||
|
||||
for (const auto idx : c10::irange(sm.nodes().size())) {
|
||||
const auto& n_ref = sm.nodes()[idx];
|
||||
nodes_[idx] = n_ref; // copy the node
|
||||
|
|
@ -688,6 +710,43 @@ std::vector<at::Tensor> StaticRuntime::operator()(
|
|||
return out;
|
||||
}
|
||||
|
||||
void StaticRuntime::set_inputs(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs) {
|
||||
if (!kwargs.empty()) {
|
||||
// This is not ideal
|
||||
TORCH_CHECK(
|
||||
static_module_.schema(),
|
||||
"Schema is not available. Consider creating the Static Runtime "
|
||||
"with StaticModule(const torch::jit::Module& m) instead.");
|
||||
std::vector<c10::IValue> stack;
|
||||
stack.reserve(inputs_.size());
|
||||
if (static_module_.first_input_is_self()) {
|
||||
stack.emplace_back(static_module_.module()._ivalue());
|
||||
}
|
||||
stack.insert(stack.end(), args.begin(), args.end());
|
||||
|
||||
static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
|
||||
DCHECK_EQ(inputs_.size(), stack.size());
|
||||
for (const auto i : c10::irange(stack.size())) {
|
||||
Input(i) = std::move(stack[i]);
|
||||
}
|
||||
} else {
|
||||
if (static_module_.first_input_is_self()) {
|
||||
Input(0) = static_module_.module()._ivalue();
|
||||
DCHECK_EQ(inputs_.size(), args.size() + 1);
|
||||
for (const auto i : c10::irange(args.size())) {
|
||||
Input(i + 1) = args[i];
|
||||
}
|
||||
} else {
|
||||
DCHECK_EQ(inputs_.size(), args.size());
|
||||
for (const auto i : c10::irange(args.size())) {
|
||||
Input(i) = args[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c10::IValue StaticRuntime::operator()(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs) {
|
||||
|
|
@ -701,27 +760,13 @@ c10::IValue StaticRuntime::operator()(
|
|||
planner_->allocate();
|
||||
}
|
||||
|
||||
if (!kwargs.empty()) {
|
||||
// This is not ideal
|
||||
TORCH_CHECK(
|
||||
static_module_.schema(),
|
||||
"Schema is not available. Consider creating the Static Runtime "
|
||||
"with StaticModule(const torch::jit::Module& m) instead.");
|
||||
std::vector<c10::IValue> s = args;
|
||||
static_module_.schema()->checkAndNormalizeInputs(s, kwargs);
|
||||
for (const auto i : c10::irange(s.size())) {
|
||||
Input(i) = std::move(s[i]);
|
||||
}
|
||||
} else {
|
||||
for (const auto i : c10::irange(args.size())) {
|
||||
Input(i) = args[i];
|
||||
}
|
||||
}
|
||||
set_inputs(args, kwargs);
|
||||
|
||||
// NB: before optimizing the order of execution, ensure that the
|
||||
// memory optimization pass (LivenessMap) is
|
||||
// aware of the new order!
|
||||
for (auto& n : nodes_) {
|
||||
// LOG(INFO) << "Running node: " << PrintNode(n.node());
|
||||
n.run();
|
||||
}
|
||||
|
||||
|
|
@ -739,9 +784,7 @@ c10::IValue StaticRuntime::operator()(
|
|||
}
|
||||
planner_->deallocate();
|
||||
// clean up owning refs of input tensors
|
||||
for (IValue& ival : inputs_) {
|
||||
ival = IValue();
|
||||
}
|
||||
clean_up_input_ivalues();
|
||||
}
|
||||
|
||||
// no need to keep references of outputs in static runtime anymore
|
||||
|
|
@ -829,6 +872,10 @@ void StaticRuntime::benchmark(
|
|||
<< "%)" << std::endl;
|
||||
}
|
||||
check_for_memory_leak();
|
||||
|
||||
#ifndef NDEBUG
|
||||
display_nodes(args, kwargs);
|
||||
#endif
|
||||
}
|
||||
|
||||
float StaticRuntime::benchmark_model(
|
||||
|
|
@ -906,16 +953,36 @@ void display_pnode_info(const ProcessedNode& pnode) {
|
|||
}
|
||||
}
|
||||
|
||||
void StaticRuntime::display_nodes(const std::vector<c10::IValue>& args) {
|
||||
void StaticRuntime::display_nodes(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs) {
|
||||
c10::InferenceMode mode;
|
||||
std::vector<IValue> stack(args);
|
||||
for (size_t i = 0; i < stack.size(); i++) {
|
||||
Input(i) = stack[i];
|
||||
if (planner_) {
|
||||
planner_->allocate();
|
||||
}
|
||||
set_inputs(args, kwargs);
|
||||
|
||||
for (auto& node : nodes_) {
|
||||
node.run();
|
||||
display_pnode_info(node);
|
||||
}
|
||||
|
||||
if (static_module_.opts().cleanup_activations) {
|
||||
// MemoryPlanner is created after the first invocation of `run()`. This is
|
||||
// done intentionally because MemoryPlanner uses `Tensor` sizes of the
|
||||
// previous `run()` for memory planning of subsequent runs
|
||||
if (!planner_) {
|
||||
planner_ = std::make_unique<MemoryPlanner>(
|
||||
this,
|
||||
static_module_.values_share_same_storage(),
|
||||
static_module_.external_values(),
|
||||
static_module_.opts().enable_out_variant,
|
||||
static_module_.opts().optimize_graph_output_memory);
|
||||
}
|
||||
planner_->deallocate();
|
||||
// clean up owning refs of input tensors
|
||||
clean_up_input_ivalues();
|
||||
}
|
||||
}
|
||||
|
||||
StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
|
||||
|
|
@ -934,18 +1001,9 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
|
|||
|
||||
// setup time
|
||||
caffe2::Timer timer;
|
||||
std::vector<IValue> stack(args);
|
||||
if (!kwargs.empty()) {
|
||||
// This is not ideal
|
||||
TORCH_CHECK(
|
||||
static_module_.schema(),
|
||||
"Schema is not available. Consider creating the Static Runtime "
|
||||
"with StaticModule(const torch::jit::Module& m) instead.");
|
||||
static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
|
||||
}
|
||||
for (const auto i : c10::irange(stack.size())) {
|
||||
Input(i) = stack[i];
|
||||
}
|
||||
|
||||
set_inputs(args, kwargs);
|
||||
|
||||
results.setup_time = timer.MilliSeconds();
|
||||
|
||||
// warmup runs
|
||||
|
|
@ -957,9 +1015,9 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
|
|||
// main runs
|
||||
for (const auto k : c10::irange(main_runs)) {
|
||||
(void)k; // Suppress unused variable warning
|
||||
for (const auto i : c10::irange(stack.size())) {
|
||||
Input(i) = stack[i];
|
||||
}
|
||||
|
||||
set_inputs(args, kwargs);
|
||||
|
||||
timer.Start();
|
||||
if (planner_) {
|
||||
planner_->allocate();
|
||||
|
|
@ -985,9 +1043,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
|
|||
}
|
||||
planner_->deallocate();
|
||||
// clean up owning refs of input tensors
|
||||
for (IValue& ival : inputs_) {
|
||||
ival = IValue();
|
||||
}
|
||||
clean_up_input_ivalues();
|
||||
}
|
||||
millis = timer.MilliSeconds();
|
||||
results.memory_dealloc_time += millis;
|
||||
|
|
@ -1283,16 +1339,11 @@ ProcessedNode::ProcessedNode(
|
|||
VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
|
||||
return;
|
||||
}
|
||||
if (!fn_ && mayRunNatively(node)) {
|
||||
native_fn_ = getNativeOperation(node);
|
||||
if (native_fn_) {
|
||||
if (!fn_ && (native_fn_ = getNativeOperation(node))) {
|
||||
VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (node->kind() != prim::ListConstruct &&
|
||||
node->kind() != prim::TupleConstruct &&
|
||||
node->kind() != prim::DictConstruct && node->kind() != prim::ListUnpack) {
|
||||
{
|
||||
const Operator& op = node->getOperator();
|
||||
TORCH_CHECK(op.hasOperation());
|
||||
op_ = op.getOperation(node);
|
||||
|
|
|
|||
|
|
@ -92,9 +92,8 @@ class TORCH_API StaticModule {
|
|||
|
||||
private:
|
||||
explicit StaticModule(
|
||||
std::pair<
|
||||
std::shared_ptr<torch::jit::Graph>,
|
||||
c10::optional<c10::FunctionSchema>> graph_and_schema,
|
||||
std::pair<std::shared_ptr<torch::jit::Graph>, std::shared_ptr<Module>>
|
||||
graph_and_module,
|
||||
const StaticModuleOptions& opts);
|
||||
|
||||
// for <kind, idx>
|
||||
|
|
@ -116,6 +115,10 @@ class TORCH_API StaticModule {
|
|||
return *graph_;
|
||||
}
|
||||
|
||||
// Returns the torch::jit::Module held in module_ by reference.
// NOTE(review): module_ is initialised from the second element of the
// graph/module pair, and the graph-only PrepareForStaticModule overload
// returns nullptr there, so this dereference looks valid only for
// module-backed instances -- confirm that callers guard on
// first_input_is_self() (as set_inputs does) before calling.
const Module& module() const {
|
||||
return *module_;
|
||||
}
|
||||
|
||||
const StaticModuleOptions& opts() const;
|
||||
size_t num_inputs() const;
|
||||
size_t num_outputs() const;
|
||||
|
|
@ -149,11 +152,17 @@ class TORCH_API StaticModule {
|
|||
return external_values_;
|
||||
}
|
||||
|
||||
// Reports whether graph input 0 is still the module object (`self`).
// The StaticModule constructor sets the flag when RemoveSelfFromGraphInput
// returns false, i.e. `self` still has uses in the graph and was not erased;
// StaticRuntime::set_inputs then places the module's IValue in input slot 0
// ahead of the caller's arguments. Defaults to false.
bool first_input_is_self() const {
|
||||
return first_input_is_self_;
|
||||
}
|
||||
|
||||
StaticRuntime& runtime();
|
||||
|
||||
private:
|
||||
StaticModuleOptions opts_;
|
||||
bool first_input_is_self_{false};
|
||||
std::shared_ptr<torch::jit::Graph> graph_;
|
||||
std::shared_ptr<torch::jit::Module> module_;
|
||||
c10::optional<c10::FunctionSchema> schema_;
|
||||
std::unique_ptr<StaticRuntime> cached_runtime_;
|
||||
|
||||
|
|
@ -188,7 +197,9 @@ class TORCH_API StaticRuntime {
|
|||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs);
|
||||
|
||||
void display_nodes(const std::vector<c10::IValue>& args);
|
||||
void display_nodes(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs);
|
||||
|
||||
void benchmark(
|
||||
const std::vector<c10::IValue>& args,
|
||||
|
|
@ -254,6 +265,18 @@ class TORCH_API StaticRuntime {
|
|||
void check_for_memory_leak(bool output_returned = true);
|
||||
|
||||
private:
|
||||
// helper method for copying input args/kwargs into inputs_
|
||||
void set_inputs(
|
||||
const std::vector<c10::IValue>& args,
|
||||
const std::unordered_map<std::string, c10::IValue>& kwargs);
|
||||
|
||||
// clean up owning refs of input IValues
|
||||
void clean_up_input_ivalues() {
|
||||
for (IValue& ival : inputs_) {
|
||||
ival = IValue();
|
||||
}
|
||||
}
|
||||
|
||||
// Memory planning is only enabled if sm->opts().cleanup_activations is true.
|
||||
// Otherwise, the memory used by activations is cached inside the static
|
||||
// runtime.
|
||||
|
|
|
|||
|
|
@ -189,28 +189,6 @@ std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// TODO: expand to include all view producing ops, mostly in
|
||||
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
|
||||
bool mayRunNatively(Node* n) {
|
||||
// In alphabetical order
|
||||
const static std::unordered_set<std::string> native_nodes{
|
||||
"aten::flatten",
|
||||
"aten::reshape",
|
||||
"aten::slice",
|
||||
"aten::transpose",
|
||||
"aten::to",
|
||||
"prim::ListConstruct",
|
||||
"prim::ListUnpack",
|
||||
"prim::TupleConstruct",
|
||||
"prim::DictConstruct",
|
||||
"aten::__getitem__"};
|
||||
auto str = std::string(n->kind().toQualString());
|
||||
if (!native_nodes.count(str)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Expensive check, use sparingly.
|
||||
// This is needed to make sure that we only switch to out variants for the
|
||||
// supported overloads, which is checked in the `Generate` step in
|
||||
|
|
@ -1302,6 +1280,24 @@ std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
|
|||
p_node->Output(0) = in0_t.clone();
|
||||
}
|
||||
};
|
||||
} else if (n->kind() == prim::GetAttr) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
auto module = p_node->Input(0).toObject();
|
||||
Node* node = p_node->node();
|
||||
const auto type = node->input()->type()->expect<ClassType>();
|
||||
const auto& field = node->s(attr::name);
|
||||
const auto slot = type->getAttributeSlot(field);
|
||||
p_node->Output(0) = module->getSlot(slot);
|
||||
};
|
||||
} else if (n->kind() == prim::SetAttr) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
auto module = p_node->Input(0).toObject();
|
||||
Node* node = p_node->node();
|
||||
const auto type = node->inputs()[0]->type()->expect<ClassType>();
|
||||
const auto& field = node->s(attr::name);
|
||||
const auto slot = type->getAttributeSlot(field);
|
||||
module->setSlot(slot, p_node->Input(1));
|
||||
};
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -124,7 +124,6 @@ bool isOptimizableContainerType(Node* n);
|
|||
|
||||
std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n);
|
||||
|
||||
bool mayRunNatively(Node* n);
|
||||
std::function<void(ProcessedNode*)> getNativeOperation(Node* n);
|
||||
|
||||
inline std::string PrintNode(const Node* node) {
|
||||
|
|
|
|||
|
|
@ -366,9 +366,6 @@ void FuseInferenceOpsForSparseNN(std::shared_ptr<torch::jit::Graph>& graph) {
|
|||
}
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(static_runtime, m) {
|
||||
m.def("static_runtime::pure_inputs() -> Tensor", []() -> at::Tensor {
|
||||
return at::randn({1});
|
||||
});
|
||||
m.def("static_runtime::permute_copy(Tensor self, int[] dims) -> Tensor");
|
||||
m.def(
|
||||
"static_runtime::reshape_copy(Tensor(a) self, int[] shape) -> Tensor(a)");
|
||||
|
|
@ -386,24 +383,10 @@ bool HasInplaceOp(std::shared_ptr<Graph>& graph, const AliasDb& alias_db) {
|
|||
return HasInplaceOp(graph->block(), alias_db);
|
||||
}
|
||||
|
||||
void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
auto* fake_input =
|
||||
graph->insert(Symbol::fromQualString("static_runtime::pure_inputs"), {});
|
||||
fake_input->node()->moveBefore(*graph->nodes().begin());
|
||||
|
||||
std::vector<std::pair<Value*, Use>> old_inputs;
|
||||
for (auto* input : graph->inputs()) {
|
||||
for (const auto& use : input->uses()) {
|
||||
old_inputs.emplace_back(std::make_pair(input, use));
|
||||
}
|
||||
input->replaceAllUsesWith(fake_input);
|
||||
}
|
||||
|
||||
void ReplaceWithCopy(
|
||||
std::shared_ptr<torch::jit::Graph>& graph,
|
||||
bool outputs_are_immutable) {
|
||||
AliasDb db(graph);
|
||||
for (const auto& p : old_inputs) {
|
||||
p.second.user->replaceInput(p.second.offset, p.first);
|
||||
}
|
||||
fake_input->node()->destroy();
|
||||
|
||||
const std::map<c10::Symbol, c10::Symbol> supported = {
|
||||
#ifdef FBCODE_CAFFE2
|
||||
|
|
@ -474,7 +457,7 @@ void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph) {
|
|||
}
|
||||
|
||||
auto* out = n->output();
|
||||
if (db.mayContainAlias({out}, graph->outputs())) {
|
||||
if (!outputs_are_immutable && db.mayContainAlias({out}, graph->outputs())) {
|
||||
continue;
|
||||
}
|
||||
auto* new_node = graph->create(new_symbol, n->outputs().size());
|
||||
|
|
|
|||
|
|
@ -7,7 +7,11 @@ TORCH_API void FuseInferenceOpsForSparseNN(
|
|||
std::shared_ptr<torch::jit::Graph>& graph);
|
||||
TORCH_API void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph);
|
||||
|
||||
TORCH_API void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph);
|
||||
// If outputs_are_immutable is set to false, don't replace the view ops that
|
||||
// produce aliases of graph outputs with the copy version.
|
||||
TORCH_API void ReplaceWithCopy(
|
||||
std::shared_ptr<torch::jit::Graph>& graph,
|
||||
bool outputs_are_immutable = true);
|
||||
|
||||
TORCH_API bool HasInplaceOp(
|
||||
std::shared_ptr<Graph>& graph,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user