From ccd097706005b0636699b0b6bbd016c0d7cd72c5 Mon Sep 17 00:00:00 2001
From: Hao Lu
Date: Sat, 10 Jul 2021 14:04:48 -0700
Subject: [PATCH] [Static Runtime] Support prim::GetAttr/SetAttr (#61505)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61505

The handling of `self` in Static Runtime was previously incorrect: the first
graph input was assumed to be unused and was simply dropped. This diff fixes
that, which matters because `self` is essential to prim::GetAttr/SetAttr --
most of the time we are getting and setting attributes on `self`, the
TorchScript module.

Reviewed By: ajyu

Differential Revision: D29350173

fbshipit-source-id: 6e62add4cda517ef8cd6c315d4cb0595e7d531fb
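As a minimal illustration (a hypothetical sketch, not code from this diff),
a scripted module that mutates one of its own attributes in forward() keeps
prim::GetAttr/SetAttr nodes in its frozen graph, and Static Runtime can now
execute it. This mirrors the TestModule/SubModule2 cases added to
test_static_runtime.py below:

    import torch
    from torch import nn
    # StaticModule here is the Python test helper defined in
    # test_static_runtime.py, which wraps the C++ StaticModule.
    from test_static_runtime import StaticModule

    class Stateful(nn.Module):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.b = 4  # mutated below, so freezing keeps it as an attribute

        def forward(self, x):
            self.b = 20        # lowered to prim::SetAttr on %self
            return x + self.b  # lowered to prim::GetAttr on %self

    sm = StaticModule(torch.jit.script(Stateful()))
    out = sm(torch.randn(2, 2))[0]  # indexing [0] mirrors the updated tests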
---
 test/test_static_runtime.py              |  94 ++++++++++-
 torch/csrc/jit/runtime/static/impl.cpp   | 193 ++++++++++++++---------
 torch/csrc/jit/runtime/static/impl.h     |  31 +++-
 torch/csrc/jit/runtime/static/ops.cpp    |  40 +++--
 torch/csrc/jit/runtime/static/ops.h      |   1 -
 torch/csrc/jit/runtime/static/passes.cpp |  25 +--
 torch/csrc/jit/runtime/static/passes.h   |   6 +-
 7 files changed, 262 insertions(+), 128 deletions(-)

diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py
index 89cf184f76a..9b38a5a7e36 100644
--- a/test/test_static_runtime.py
+++ b/test/test_static_runtime.py
@@ -1,11 +1,11 @@
+import unittest
+from typing import Dict, Optional
+
 import numpy as np
 import torch
-import unittest
-
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
-from typing import Dict, Optional
 
 
 class StaticModule:
     def __init__(self, scripted):
@@ -30,7 +30,9 @@ class StaticModule:
         )
 
 
-def linear_shim(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def linear_shim(
+    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
     output = input.matmul(weight.t())
     if bias is not None:
         output += bias
@@ -107,7 +109,8 @@ def trivial_graph(a, b, c):
     s = torch.tensor([[3, 3], [3, 3]])
     return a + b * c + s
 
-def loop_graph(a, b, iters : int):
+
+def loop_graph(a, b, iters: int):
     c = a + b * 2
     for i in range(iters):
         c = c + b
@@ -115,14 +118,50 @@ def loop_graph(a, b, iters : int):
         c -= a
     return c
 
-def output_graph(a, b, c, iters : int):
+
+def output_graph(a, b, c, iters: int):
     s = torch.tensor([[3, 3], [3, 3]])
     k = a + b * c + s
-    d : Dict[int, torch.Tensor] = {}
+    d: Dict[int, torch.Tensor] = {}
     for i in range(iters):
         d[i] = k + i
     return d
+
+
+class SubModule(nn.Module):
+    def __init__(self):
+        super(SubModule, self).__init__()
+        self.a = 11
+        self.b = 2
+
+    def forward(self, x):
+        return self.a + self.b + x
+
+
+class SubModule2(nn.Module):
+    def __init__(self):
+        super(SubModule2, self).__init__()
+        self.a = 12
+        self.b = 2
+
+    def forward(self, x):
+        self.b = 30
+        return self.a + self.b + x
+
+
+class TestModule(nn.Module):
+    def __init__(self):
+        super(TestModule, self).__init__()
+        self.sub1 = SubModule()
+        self.sub2 = SubModule2()
+        self.a = 3
+        self.b = 4
+
+    def forward(self, x):
+        self.b = 20
+        return self.sub1(x) + self.a + self.b + self.sub2(x)
+
+
 class TestStaticModule(TestCase):
     def test_multihead_attention_layer(self):
         HID_DIM = 256
@@ -220,6 +259,46 @@ class TestStaticModule(TestCase):
         o_test = tg_a(s)[0]
         torch.testing.assert_allclose(o_ref, o_test)
 
+    def test_attr(self):
+        """
+        TorchScript IR of TestModule() after freezing:
+        graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
+              %x.1 : Tensor):
+          %18 : int = prim::Constant[value=30]()
+          %30 : int = prim::Constant[value=13]()
+          %3 : int = prim::Constant[value=20]()
+          %2 : int = prim::Constant[value=1]()
+          %self.sub2.a : int = prim::Constant[value=12]()
+          %self.a : int = prim::Constant[value=3]()
+          = prim::SetAttr[name="b"](%self, %3)
+          %17 : Tensor = aten::add(%x.1, %30, %2)
+          %7 : Tensor = aten::add(%17, %self.a, %2)
+          %b.1 : int = prim::GetAttr[name="b"](%self)
+          %9 : Tensor = aten::add(%7, %b.1, %2)
+          %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
+          = prim::SetAttr[name="b"](%sub2, %18)
+          %b : int = prim::GetAttr[name="b"](%sub2)
+          %22 : int = aten::add(%self.sub2.a, %b)
+          %23 : Tensor = aten::add(%x.1, %22, %2)
+          %12 : Tensor = aten::add(%9, %23, %2)
+          return (%12)
+        """
+        # test prim::SetAttr and prim::GetAttr impl in Static Runtime
+        m = TestModule()
+
+        m.eval()
+        input = torch.randn(2, 2)
+        output_s = m.forward(input)
+
+        ms = torch.jit.script(m)
+        sm = StaticModule(ms)
+        output_sm = sm(input)[0]
+        torch.testing.assert_allclose(output_s, output_sm)
+        sm.benchmark([input], {}, 2, 2)
+        sm.benchmark_individual_ops([input], {}, 2, 2)
+        sm.benchmark([], {"x": input}, 2, 2)
+        sm.benchmark_individual_ops([], {"x": input}, 2, 2)
+
     @unittest.skip("Temporarily disabled")
     def test_fusion_trivial_graph(self):
         s = torch.full((2, 2), 2)
@@ -281,6 +360,5 @@ class TestStaticModule(TestCase):
             torch.testing.assert_allclose(o_ref[i], o_test[i])
 
 
-
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index fd915233826..7924e1b957a 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -45,20 +45,29 @@ void OptimizeGraph(
   ConstantPropagation(graph);
 }
 
-void CheckGraphEligibility(const std::shared_ptr<torch::jit::Graph>& graph) {
-  for (auto n : graph->nodes()) {
-    if (n->kind() == c10::Symbol::fromQualString("prim::GetAttr")) {
-      throw std::runtime_error("Cannot accelerate unfrozen graphs");
+bool CheckGraphEligibility(const std::shared_ptr<torch::jit::Graph>& graph) {
+  // check for sub-blocks
+  bool can_support = true;
+  for (auto* node : graph->block()->nodes()) {
+    for (Block* sub_block : node->blocks()) {
+      VLOG(1) << "Found nested sub-blocks in graph at node: "
+              << PrintNode(node);
+      can_support = false;
     }
   }
+
+  return can_support;
 }
 
 // remove unused input 0 from graph
-void RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
+bool RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
   if (graph->inputs().at(0)->type()->is_module()) {
-    TORCH_CHECK(!graph->inputs().at(0)->hasUses());
+    if (graph->inputs().at(0)->hasUses()) {
+      return false;
+    }
     graph->eraseInput(0);
   }
+  return true;
 }
 
 // remove "self" from function schema
@@ -443,12 +452,12 @@ GenerateSameStorageValues(
 void PrepareGraphForStaticModule(
     std::shared_ptr<torch::jit::Graph> graph,
     const StaticModuleOptions& opts) {
-  CheckGraphEligibility(graph);
+  // TODO: call CheckGraphEligibility before trying to enable static runtime
+  TORCH_CHECK(CheckGraphEligibility(graph));
   OptimizeGraph(graph, opts);
-  RemoveSelfFromGraphInput(graph);
 }
 
-std::pair<std::shared_ptr<torch::jit::Graph>, c10::optional<c10::FunctionSchema>>
+std::pair<std::shared_ptr<torch::jit::Graph>, std::shared_ptr<Module>>
 PrepareForStaticModule(
     const torch::jit::Module& m,
     const StaticModuleOptions& opts) {
@@ -461,22 +470,23 @@ PrepareForStaticModule(
   auto module = m.copy();
   module.eval();
 
-  module = freeze_module(module);
+  auto module_ptr = std::make_shared<Module>(freeze_module(module));
 
-  Method method = module.get_method("forward");
-  auto graph = module.get_method("forward").graph();
+  Method method = module_ptr->get_method("forward");
+  auto graph = module_ptr->get_method("forward").graph();
+
+  // graph->dump();
 
   PrepareGraphForStaticModule(graph, opts);
-  c10::FunctionSchema s = RemoveSelfFromSchema(method.function().getSchema());
-  return std::make_pair(graph, s);
+  return std::make_pair(graph, module_ptr);
 }
 
-std::pair<std::shared_ptr<torch::jit::Graph>, c10::optional<c10::FunctionSchema>>
+std::pair<std::shared_ptr<torch::jit::Graph>, std::shared_ptr<Module>>
 PrepareForStaticModule(
     std::shared_ptr<torch::jit::Graph> graph,
     const StaticModuleOptions& opts) {
   PrepareGraphForStaticModule(graph, opts);
-  return std::make_pair(graph, c10::nullopt);
+  return std::make_pair(graph, nullptr);
 }
 
 } // namespace
@@ -492,13 +502,12 @@ StaticModule::StaticModule(
     : StaticModule(PrepareForStaticModule(m, opts), opts) {}
 
 StaticModule::StaticModule(
-    std::pair<
-        std::shared_ptr<torch::jit::Graph>,
-        c10::optional<c10::FunctionSchema>> graph_and_schema,
+    std::pair<std::shared_ptr<torch::jit::Graph>, std::shared_ptr<Module>>
+        graph_and_module,
     const StaticModuleOptions& opts)
     : opts_(opts),
-      graph_(std::move(graph_and_schema.first)),
-      schema_(std::move(graph_and_schema.second)) {
+      graph_(std::move(graph_and_module.first)),
+      module_(std::move(graph_and_module.second)) {
   // check opt flags
   if (opts.optimize_graph_output_memory) {
     TORCH_CHECK(
@@ -511,6 +520,18 @@ StaticModule::StaticModule(
         "When optimize_memory is true, enable_out_variant must be set to true");
   }
 
+  // handle schema
+  if (module_) {
+    Method method = module_->get_method("forward");
+    schema_ = method.function().getSchema();
+    if (RemoveSelfFromGraphInput(graph_)) {
+      schema_ = RemoveSelfFromSchema(method.function().getSchema());
+    } else {
+      first_input_is_self_ = true;
+      schema_ = method.function().getSchema();
+    }
+  }
+
   // map Value* to IValue (from inputs or prim::Constant) or null
   std::unordered_map<Value*, IValue*> value_to_ivalue;
   // map Value* to its SSA definition IR
@@ -620,6 +641,7 @@ StaticRuntime::StaticRuntime(const StaticModule& sm) : static_module_(sm) {
   // NB: create unchanging std::vectors we can reference
   inputs_.resize(sm.num_inputs());
   nodes_.resize(sm.nodes().size());
+
   for (const auto idx : c10::irange(sm.nodes().size())) {
     const auto& n_ref = sm.nodes()[idx];
     nodes_[idx] = n_ref; // copy the node
@@ -688,6 +710,43 @@ std::vector<at::Tensor> StaticRuntime::operator()(
   return out;
 }
 
+void StaticRuntime::set_inputs(
+    const std::vector<c10::IValue>& args,
+    const std::unordered_map<std::string, c10::IValue>& kwargs) {
+  if (!kwargs.empty()) {
+    // This is not ideal
+    TORCH_CHECK(
+        static_module_.schema(),
+        "Schema is not available. Consider creating the Static Runtime "
+        "with StaticModule(const torch::jit::Module& m) instead.");
+    std::vector<c10::IValue> stack;
+    stack.reserve(inputs_.size());
+    if (static_module_.first_input_is_self()) {
+      stack.emplace_back(static_module_.module()._ivalue());
+    }
+    stack.insert(stack.end(), args.begin(), args.end());
+
+    static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
+    DCHECK_EQ(inputs_.size(), stack.size());
+    for (const auto i : c10::irange(stack.size())) {
+      Input(i) = std::move(stack[i]);
+    }
+  } else {
+    if (static_module_.first_input_is_self()) {
+      Input(0) = static_module_.module()._ivalue();
+      DCHECK_EQ(inputs_.size(), args.size() + 1);
+      for (const auto i : c10::irange(args.size())) {
+        Input(i + 1) = args[i];
+      }
+    } else {
+      DCHECK_EQ(inputs_.size(), args.size());
+      for (const auto i : c10::irange(args.size())) {
+        Input(i) = args[i];
+      }
+    }
+  }
+}
+
 c10::IValue StaticRuntime::operator()(
     const std::vector<c10::IValue>& args,
     const std::unordered_map<std::string, c10::IValue>& kwargs) {
@@ -701,27 +760,13 @@ c10::IValue StaticRuntime::operator()(
     planner_->allocate();
   }
 
-  if (!kwargs.empty()) {
-    // This is not ideal
-    TORCH_CHECK(
-        static_module_.schema(),
-        "Schema is not available. Consider creating the Static Runtime "
-        "with StaticModule(const torch::jit::Module& m) instead.");
-    std::vector<c10::IValue> s = args;
-    static_module_.schema()->checkAndNormalizeInputs(s, kwargs);
-    for (const auto i : c10::irange(s.size())) {
-      Input(i) = std::move(s[i]);
-    }
-  } else {
-    for (const auto i : c10::irange(args.size())) {
-      Input(i) = args[i];
-    }
-  }
+  set_inputs(args, kwargs);
 
   // NB: before optimizing the order of execution, ensure that the
   // memory optimization pass (LivenessMap) is
   // aware of the new order!
   for (auto& n : nodes_) {
+    // LOG(INFO) << "Running node: " << PrintNode(n.node());
     n.run();
   }
@@ -739,9 +784,7 @@ c10::IValue StaticRuntime::operator()(
     }
     planner_->deallocate();
     // clean up owning refs of input tensors
-    for (IValue& ival : inputs_) {
-      ival = IValue();
-    }
+    clean_up_input_ivalues();
   }
 
   // no need to keep references of outputs in static runtime anymore
@@ -829,6 +872,10 @@ void StaticRuntime::benchmark(
               << "%)" << std::endl;
   }
   check_for_memory_leak();
+
+#ifndef NDEBUG
+  display_nodes(args, kwargs);
+#endif
 }
 
 float StaticRuntime::benchmark_model(
@@ -906,16 +953,36 @@ void display_pnode_info(const ProcessedNode& pnode) {
   }
 }
 
-void StaticRuntime::display_nodes(const std::vector<c10::IValue>& args) {
+void StaticRuntime::display_nodes(
+    const std::vector<c10::IValue>& args,
+    const std::unordered_map<std::string, c10::IValue>& kwargs) {
   c10::InferenceMode mode;
-  std::vector<c10::IValue> stack(args);
-  for (size_t i = 0; i < stack.size(); i++) {
-    Input(i) = stack[i];
+
+  if (planner_) {
+    planner_->allocate();
   }
+  set_inputs(args, kwargs);
+
   for (auto& node : nodes_) {
     node.run();
     display_pnode_info(node);
   }
+
+  if (static_module_.opts().cleanup_activations) {
+    // MemoryPlanner is created after the first invocation of `run()`. This is
+    // done intentionally because MemoryPlanner uses `Tensor` sizes of the
+    // previous `run()` for memory planning of subsequent runs
+    if (!planner_) {
+      planner_ = std::make_unique<MemoryPlanner>(
+          this,
+          static_module_.values_share_same_storage(),
+          static_module_.external_values(),
+          static_module_.opts().enable_out_variant,
+          static_module_.opts().optimize_graph_output_memory);
+    }
+    planner_->deallocate();
+    // clean up owning refs of input tensors
+    clean_up_input_ivalues();
+  }
 }
 
 StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
@@ -934,18 +1001,9 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
   // setup time
   caffe2::Timer timer;
-  std::vector<c10::IValue> stack(args);
-  if (!kwargs.empty()) {
-    // This is not ideal
-    TORCH_CHECK(
-        static_module_.schema(),
-        "Schema is not available. Consider creating the Static Runtime "
-        "with StaticModule(const torch::jit::Module& m) instead.");
-    static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
-  }
-  for (const auto i : c10::irange(stack.size())) {
-    Input(i) = stack[i];
-  }
+
+  set_inputs(args, kwargs);
+
   results.setup_time = timer.MilliSeconds();
 
   // warmup runs
@@ -957,9 +1015,9 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
   // main runs
   for (const auto k : c10::irange(main_runs)) {
     (void)k; // Suppress unused variable warning
-    for (const auto i : c10::irange(stack.size())) {
-      Input(i) = stack[i];
-    }
+
+    set_inputs(args, kwargs);
+
     timer.Start();
     if (planner_) {
       planner_->allocate();
@@ -985,9 +1043,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
     }
     planner_->deallocate();
     // clean up owning refs of input tensors
-    for (IValue& ival : inputs_) {
-      ival = IValue();
-    }
+    clean_up_input_ivalues();
   }
   millis = timer.MilliSeconds();
   results.memory_dealloc_time += millis;
@@ -1283,16 +1339,11 @@ ProcessedNode::ProcessedNode(
     VLOG(1) << "Switch to out variant for node: " << PrintNode(node);
     return;
   }
-  if (!fn_ && mayRunNatively(node)) {
-    native_fn_ = getNativeOperation(node);
-    if (native_fn_) {
-      VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
-      return;
-    }
+  if (!fn_ && (native_fn_ = getNativeOperation(node))) {
+    VLOG(1) << "Switch to native impl for node: " << PrintNode(node);
+    return;
   }
-  if (node->kind() != prim::ListConstruct &&
-      node->kind() != prim::TupleConstruct &&
-      node->kind() != prim::DictConstruct && node->kind() != prim::ListUnpack) {
+  {
     const Operator& op = node->getOperator();
     TORCH_CHECK(op.hasOperation());
     op_ = op.getOperation(node);
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index 01cbcb7a83b..5d352724686 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -92,9 +92,8 @@ class TORCH_API StaticModule {
  private:
   explicit StaticModule(
-      std::pair<
-          std::shared_ptr<torch::jit::Graph>,
-          c10::optional<c10::FunctionSchema>> graph_and_schema,
+      std::pair<std::shared_ptr<torch::jit::Graph>, std::shared_ptr<Module>>
+          graph_and_module,
       const StaticModuleOptions& opts);
 
   // for <kind, idx>
@@ -116,6 +115,10 @@ class TORCH_API StaticModule {
     return *graph_;
   }
 
+  const Module& module() const {
+    return *module_;
+  }
+
   const StaticModuleOptions& opts() const;
   size_t num_inputs() const;
   size_t num_outputs() const;
@@ -149,11 +152,17 @@ class TORCH_API StaticModule {
     return external_values_;
   }
 
+  bool first_input_is_self() const {
+    return first_input_is_self_;
+  }
+
   StaticRuntime& runtime();
 
  private:
   StaticModuleOptions opts_;
+  bool first_input_is_self_{false};
   std::shared_ptr<torch::jit::Graph> graph_;
+  std::shared_ptr<Module> module_;
   c10::optional<c10::FunctionSchema> schema_;
   std::unique_ptr<StaticRuntime> cached_runtime_;
@@ -188,7 +197,9 @@ class TORCH_API StaticRuntime {
       const std::vector<c10::IValue>& args,
       const std::unordered_map<std::string, c10::IValue>& kwargs);
 
-  void display_nodes(const std::vector<c10::IValue>& args);
+  void display_nodes(
+      const std::vector<c10::IValue>& args,
+      const std::unordered_map<std::string, c10::IValue>& kwargs);
 
   void benchmark(
       const std::vector<c10::IValue>& args,
@@ -254,6 +265,18 @@ class TORCH_API StaticRuntime {
   void check_for_memory_leak(bool output_returned = true);
 
  private:
+  // helper method for copying input args/kwargs into inputs_
+  void set_inputs(
+      const std::vector<c10::IValue>& args,
+      const std::unordered_map<std::string, c10::IValue>& kwargs);
+
+  // clean up owning refs of input IValues
+  void clean_up_input_ivalues() {
+    for (IValue& ival : inputs_) {
+      ival = IValue();
+    }
+  }
+
   // Memory planning is only enabled if sm->opts().cleanup_activations is true.
   // Otherwise, the memory used by activations is cached inside the static
   // runtime.
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index aa045fd6cbf..8cc6cae9ba0 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -189,28 +189,6 @@ std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n) {
   return nullptr;
 }
 
-// TODO: expand to include all view producing ops, mostly in
-// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
-bool mayRunNatively(Node* n) {
-  // In alphabetical order
-  const static std::unordered_set<std::string> native_nodes{
-      "aten::flatten",
-      "aten::reshape",
-      "aten::slice",
-      "aten::transpose",
-      "aten::to",
-      "prim::ListConstruct",
-      "prim::ListUnpack",
-      "prim::TupleConstruct",
-      "prim::DictConstruct",
-      "aten::__getitem__"};
-  auto str = std::string(n->kind().toQualString());
-  if (!native_nodes.count(str)) {
-    return false;
-  }
-  return true;
-}
-
 // Expensive check, use sparingly.
 // This is needed to make sure that we only switch to out variants for the
 // supported overloads, which is checked in the `Generate` step in
@@ -1302,6 +1280,24 @@ std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
         p_node->Output(0) = in0_t.clone();
       }
     };
+  } else if (n->kind() == prim::GetAttr) {
+    return [](ProcessedNode* p_node) {
+      auto module = p_node->Input(0).toObject();
+      Node* node = p_node->node();
+      const auto type = node->input()->type()->expect<ClassType>();
+      const auto& field = node->s(attr::name);
+      const auto slot = type->getAttributeSlot(field);
+      p_node->Output(0) = module->getSlot(slot);
+    };
+  } else if (n->kind() == prim::SetAttr) {
+    return [](ProcessedNode* p_node) {
+      auto module = p_node->Input(0).toObject();
+      Node* node = p_node->node();
+      const auto type = node->inputs()[0]->type()->expect<ClassType>();
+      const auto& field = node->s(attr::name);
+      const auto slot = type->getAttributeSlot(field);
+      module->setSlot(slot, p_node->Input(1));
+    };
   }
   return nullptr;
 }
diff --git a/torch/csrc/jit/runtime/static/ops.h b/torch/csrc/jit/runtime/static/ops.h
index d35df5f8065..021cd21fd72 100644
--- a/torch/csrc/jit/runtime/static/ops.h
+++ b/torch/csrc/jit/runtime/static/ops.h
@@ -124,7 +124,6 @@ bool isOptimizableContainerType(Node* n);
 
 std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n);
 
-bool mayRunNatively(Node* n);
 std::function<void(ProcessedNode*)> getNativeOperation(Node* n);
 
 inline std::string PrintNode(const Node* node) {
diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp
index c4cdea0e945..940a2aa2e71 100644
--- a/torch/csrc/jit/runtime/static/passes.cpp
+++ b/torch/csrc/jit/runtime/static/passes.cpp
@@ -366,9 +366,6 @@ void FuseInferenceOpsForSparseNN(std::shared_ptr<torch::jit::Graph>& graph) {
 }
 
 TORCH_LIBRARY_FRAGMENT(static_runtime, m) {
-  m.def("static_runtime::pure_inputs() -> Tensor", []() -> at::Tensor {
-    return at::randn({1});
-  });
   m.def("static_runtime::permute_copy(Tensor self, int[] dims) -> Tensor");
   m.def(
       "static_runtime::reshape_copy(Tensor(a) self, int[] shape) -> Tensor(a)");
@@ -386,24 +383,10 @@ bool HasInplaceOp(std::shared_ptr<torch::jit::Graph>& graph, const AliasDb& alias_db) {
   return HasInplaceOp(graph->block(), alias_db);
 }
 
-void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph) {
-  auto* fake_input =
-      graph->insert(Symbol::fromQualString("static_runtime::pure_inputs"), {});
-  fake_input->node()->moveBefore(*graph->nodes().begin());
-
-  std::vector<std::pair<Value*, Use>> old_inputs;
-  for (auto* input : graph->inputs()) {
-    for (const auto& use : input->uses()) {
-      old_inputs.emplace_back(std::make_pair(input, use));
-    }
-    input->replaceAllUsesWith(fake_input);
-  }
-
+void ReplaceWithCopy(
+    std::shared_ptr<torch::jit::Graph>& graph,
+    bool outputs_are_immutable) {
   AliasDb db(graph);
-  for (const auto& p : old_inputs) {
-    p.second.user->replaceInput(p.second.offset, p.first);
-  }
-  fake_input->node()->destroy();
 
   const std::map<c10::Symbol, c10::Symbol> supported = {
 #ifdef FBCODE_CAFFE2
@@ -474,7 +457,7 @@ void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph) {
     }
 
     auto* out = n->output();
-    if (db.mayContainAlias({out}, graph->outputs())) {
+    if (!outputs_are_immutable && db.mayContainAlias({out}, graph->outputs())) {
       continue;
     }
     auto* new_node = graph->create(new_symbol, n->outputs().size());
diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h
index 2becd861d47..11ab4bdc7c4 100644
--- a/torch/csrc/jit/runtime/static/passes.h
+++ b/torch/csrc/jit/runtime/static/passes.h
@@ -7,7 +7,11 @@
 TORCH_API void FuseInferenceOpsForSparseNN(
     std::shared_ptr<torch::jit::Graph>& graph);
 
 TORCH_API void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph);
-TORCH_API void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph);
+// If outputs_are_immutable is set to false, don't replace the view ops that
+// produce aliases of graph outputs with the copy version.
+TORCH_API void ReplaceWithCopy(
+    std::shared_ptr<torch::jit::Graph>& graph,
+    bool outputs_are_immutable = true);
 
 TORCH_API bool HasInplaceOp(
     std::shared_ptr<torch::jit::Graph>& graph,