From 66f1485424bfc503df7ca5e81cbe4844d0fee4e1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 Aug 2017 10:29:30 -0700 Subject: [PATCH] Improve performance of compilation by ~8% by speeding up the hlo rematerialization pass. Changes: . Wrap each HloInstruction* inside an Item structure that keeps associated data. This allows us to get rid of a bunch of hash tables indexed by HloInstruction*. * Switch to an intrusive linked list (instead of std::list) so that we can avoid a hash table that maps to std::list::iterator. * Use inlined vector in a few places. PiperOrigin-RevId: 163848365 --- .../xla/service/hlo_rematerialization.cc | 563 ++++++++++-------- 1 file changed, 307 insertions(+), 256 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index fd2f700029e..9f65f1b8512 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -47,6 +47,12 @@ namespace xla { namespace { +// Potential optimizations: +// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue +// of candidates. +// . Cache IsRematerializable in Item? Only correct if control +// predecessors and successors don't change. + // Returns true if the given instruction is rematerializable. bool IsRematerializable(const HloInstruction* instruction) { // Conservatively, don't rematerialize instruction with control @@ -79,126 +85,202 @@ bool IsRematerializable(const HloInstruction* instruction) { } } +// Type holding a unique identifier for each Buffer object. +using BufferId = int64; +using BufferIdList = tensorflow::gtl::InlinedVector; + +// We wrap HloInstruction* with an Item that holds auxiliary +// per-instruction state. +struct Item { + HloInstruction* instruction; + + // True once the instruction is marked as placed (when BeginInstruction + // has been called for this instruction). + bool placed = false; + + // To avoid an infinite loop rematerializing the same set of + // instructions ad infinitum, keep a blacklist of instructions + // which should not be rematerialized. + bool blacklisted = false; + + // The buffers defined by this instruction. + BufferIdList buffers_defined; + + // The buffers used by this instruction. + BufferIdList buffers_used; + + private: + friend class InstructionList; + + // Items are arranged in a doubly linked list. + Item* next; + Item* prev; + + // List is ordered by position, which can however be duplicated as + // new instructions are inserted. See InsertBeforeInstructions + // comment for details. + int64 position; +}; + +using ItemList = tensorflow::gtl::InlinedVector; + // Class which maintains an ordered list of instructions with fast insertion // before arbitrary elements. class InstructionList { public: explicit InstructionList(const std::vector& order) { int64 position = 0; + Item* last = nullptr; for (const HloInstruction* inst : order) { - instructions_.push_back(const_cast(inst)); - instruction_iterators_.insert({const_cast(inst), - std::next(instructions_.end(), -1)}); + // Add a new item to the linked list. + Item* item = new Item; + item->next = nullptr; + item->prev = last; + if (last == nullptr) { + first_ = item; + } else { + last->next = item; + } + last = item; + // Initially position numbers are uniquely assigned in order. Later as // instructions are added with InsertBefore* methods, some instructions // may have duplicate position numbers, but the values will be guaranteed // to be monotonically increasing through the list, and so is still useful // for quickly(-ish) determining the order of arbitrary instructions in // the list. - position_number_[inst] = position; - first_at_position_[position] = inst; + item->instruction = const_cast(inst); + item->position = position; position++; + + item_map_[inst] = item; } } - // Returns the list of instructions. - const std::list& instructions() const { - return instructions_; + ~InstructionList() { + for (Item* item = first_; item != nullptr;) { + Item* next = item->next; + delete item; + item = next; + } } - // Insert instruction 'to_insert' immediately before instruction 'before' in - // the list. - void InsertBefore(HloInstruction* to_insert, HloInstruction* before) { - VLOG(3) << "InsertBefore: " << to_insert->name() << " before " - << before->name(); - auto it = instruction_iterators_.find(before); - CHECK(it != instruction_iterators_.end()); - instruction_iterators_.insert( - {to_insert, instructions_.insert(it->second, to_insert)}); - // Assign the same position number to the newly added instruction as - // 'before'. This guarantees monotonicity of the position numbers, but not - // uniqueness. - int64 pos = position_number_.at(before); - position_number_[to_insert] = pos; - if (first_at_position_.at(pos) == before) { - first_at_position_[pos] = to_insert; - } + size_t size() const { return item_map_.size(); } + + // For ordered iteration over items. + // for (auto item = q.first(); item != nullptr; item = q.next(item)) {...} + Item* first() const { return first_; } + Item* next(Item* item) const { return item->next; } + + // Creates an Item for the given instruction, but doesn't add it to the list. + // (Use InsertBeforeInstructions to add the Item to the list.) + Item* CreateItem(HloInstruction* inst) { + Item* item = new Item; + item->instruction = inst; + CHECK(item_map_.insert({inst, item}).second) << "inserting inst twice"; + return item; + } + + // Return the Item corresponding to inst. + Item* GetItem(const HloInstruction* inst) const { + auto iter = item_map_.find(inst); + CHECK(iter != item_map_.end()) << "Did not find " << inst->name(); + return iter->second; } // Insert instruction 'to_insert' immediately before the earliest instruction // in 'before_instructions'. + // + // Each instruction gets a non-decreasing ordinal number. We use this to let + // InsertBeforeInstructions quickly insert an instruction before the earliest + // instruction in a set of instructions. If position_number_[a] < + // position_number_[b] then 'a' comes before 'b' in the list. If the position + // numbers are the same then nothing can be said about their order without + // examining the list. + // + // On object construction this ordinal is precisely the instruction's index + // in the list. Later, instructions inserted via InsertBefore receive + // duplicate values. However, monotonicity is preserved. void InsertBeforeInstructions( - HloInstruction* to_insert, - tensorflow::gtl::ArraySlice before_instructions) { - VLOG(3) << "InsertBeforeInstructions: " << to_insert->name() << " before {" - << tensorflow::str_util::Join( - before_instructions, ", ", - [](string* out, HloInstruction* inst) { - tensorflow::strings::StrAppend(out, inst->name()); - }) + Item* to_insert, tensorflow::gtl::ArraySlice before_instructions) { + VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name() + << " before {" + << tensorflow::str_util::Join(before_instructions, ", ", + [](string* out, Item* item) { + tensorflow::strings::StrAppend( + out, item->instruction->name()); + }) << "}"; // Find the minimal position number of any instruction in // 'before_instructions'. CHECK(!before_instructions.empty()); - int64 min_position_number = std::numeric_limits::max(); - for (const HloInstruction* instruction : before_instructions) { - min_position_number = - std::min(min_position_number, position_number_.at(instruction)); + Item* min_position_item = nullptr; + for (Item* item : before_instructions) { + if (min_position_item == nullptr || + item->position < min_position_item->position) { + min_position_item = item; + } } // Because more than one instruction in 'before_instructions' may have a // position number of 'min_position_number', find the first such instruction // with position number 'min_position_number'. - for (auto it = instruction_iterators_.at( - first_at_position_.at(min_position_number)); - it != instructions_.end() && - position_number_.at(*it) == min_position_number; - ++it) { - if (std::find(before_instructions.begin(), before_instructions.end(), - *it) != before_instructions.end()) { - return InsertBefore(to_insert, *it); - } + + // First find first instruction with the min position. + while (min_position_item->prev != nullptr && + min_position_item->position == min_position_item->prev->position) { + min_position_item = min_position_item->prev; } - LOG(FATAL) << "Expected to find instruction in before_instructions with " - "position number " - << min_position_number; + + // Now scan forwards until we find one of the before_instructions. + while (std::find(before_instructions.begin(), before_instructions.end(), + min_position_item) == before_instructions.end()) { + min_position_item = min_position_item->next; + } + return InsertBefore(to_insert, min_position_item); + } + + void Blacklist(const HloInstruction* inst) { + GetItem(inst)->blacklisted = true; } private: - // List of instructions. - std::list instructions_; + // Insert instruction 'item' immediately before 'before' in the list. + void InsertBefore(Item* item, Item* before) { + VLOG(3) << "InsertBefore: " << item->instruction->name() << " before " + << before->instruction->name(); + // Insert new item into linked list. + item->prev = before->prev; + item->next = before; + before->prev = item; + if (item->prev != nullptr) { + item->prev->next = item; + } else { + first_ = item; + } - // Iterators for each instruction in the list. - tensorflow::gtl::FlatMap::iterator> - instruction_iterators_; + // Assign the same position number to the newly added instruction as + // 'before'. This guarantees monotonicity of the position numbers, but not + // uniqueness. + item->position = before->position; + } - // A number assigned to each instruction which increases monotonically through - // 'instructions_'. Used to facilitate fast insertion of an instruction before - // the earliest instruction in a set of instructions - // (InsertBeforeInstructions) by enabling fast-ish ordering queries between - // instructions. If position_number_[a] < position_number_[b] then 'a' comes - // before 'b' in the list. If the position numbers are the same then nothing - // can be said about their order without examining the list. - // - // On object construction this value is precisely the instruction's ordinal - // position in the list. Instructions inserted via InsertBefore receive - // duplicate values. However, monotonicity is preserved. - tensorflow::gtl::FlatMap position_number_; + Item* first_; - // The first instruction in the list assigned a particular position number. - tensorflow::gtl::FlatMap first_at_position_; + // Item for each instruction. + tensorflow::gtl::FlatMap item_map_; }; -// Return the HloInstructions which use the given LogicalBuffer. Sets +// Return the items which use the given LogicalBuffer. Sets // has_indirect_users to whether any of the uses is indirect. A use is indirect // if the instruction defining logical_buffer is not an operand of the use. This // can happen via buffer aliasing (eg, tuples). -std::vector GetUsers( - const LogicalBuffer* logical_buffer, - const TuplePointsToAnalysis& points_to_analysis, bool* has_indirect_users) { - std::vector users; +ItemList GetUsers(const InstructionList& instruction_list, + const LogicalBuffer* logical_buffer, + const TuplePointsToAnalysis& points_to_analysis, + bool* has_indirect_users) { + ItemList users; // To identify uses iterate through all HloInstruction users of the // BufferAliases of the logical buffer. *has_indirect_users = false; @@ -219,8 +301,9 @@ std::vector GetUsers( } // A buffer may be used by the instruction via more than one alias. For // example, a buffer which appears in more than one element of a tuple. - if (std::find(users.begin(), users.end(), user) == users.end()) { - users.push_back(user); + Item* user_item = instruction_list.GetItem(user); + if (std::find(users.begin(), users.end(), user_item) == users.end()) { + users.push_back(user_item); } } } @@ -244,7 +327,7 @@ class MemoryUsageTracker { // EndInstruction) to accurately model memory usage. At BeginInstruction the // memory for the output value(s) of the current instruction is allocated. At // EndInstruction memory for dead operand(s) is freed. - Status BeginInstruction(const HloInstruction* instruction); + Status BeginInstruction(Item* item); // Finishes the placement of the current instruction. This frees any dead // operands or dead result of the instruction. This must be called after @@ -253,40 +336,31 @@ class MemoryUsageTracker { // Returns the number of bytes that the current memory usage will be reduced // if the given instruction is rematerialized. - int64 MemoryReducedIfRematerialized(const HloInstruction* instruction) const; + int64 MemoryReducedIfRematerialized(Item* item) const; // Adjusts memory usage to account for the rematerialization of - // original_instruction for all remaining unplaced uses. The rematerialization - // is remat_instruction. This method should be called after the HLO graph has + // original_item for all remaining unplaced uses. The rematerialization + // is remat_item. This method should be called after the HLO graph has // been transformed (rematerialization instruction created and connected to // uses). - Status AddRematerializedInstruction(HloInstruction* original_instruction, - HloInstruction* remat_instruction); + Status AddRematerializedInstruction(Item* original_item, Item* remat_item); // Returns whether the given instruction has been placed (BeginInstruction // has been called with 'instruction' as the argument). bool IsPlaced(const HloInstruction* instruction) const { - return ContainsKey(placed_instructions_, instruction); + return instruction_list_.GetItem(instruction)->placed; } // Returns the current memory usage. This is the sum of sizes of all live // values. int64 memory_usage() const { return memory_usage_; } - // Returns the current instruction being placed. - const HloInstruction* in_progress_instruction() const { - return in_progress_instruction_; - } - // Check invariants of the data structure. This is expensive to call. bool Check() const; string ToString() const; private: - // Type holding a unique identifier for each Buffer object. - using BufferId = int64; - // A Buffer represents a single LogicalBuffer in the computation including // various metadata useful for tracking liveness of the value. A LogicalBuffer // is not used directly because the HLO graph is transformed and @@ -298,7 +372,7 @@ class MemoryUsageTracker { const BufferId id; // The instruction which defines this buffer. - const HloInstruction* defining_instruction; + Item* defining_instruction; // The materialized size of the buffer in bytes. const int64 size; @@ -312,16 +386,17 @@ class MemoryUsageTracker { bool has_indirect_uses; // The instructions which use this buffer. - std::vector users; + ItemList users; // The number of users (HloInstructions) of this buffer which have not yet // been placed in the sequence. int64 unfinished_user_count; string ToString() const { - return tensorflow::strings::StrCat("Buffer ", id, " (defined by ", - defining_instruction->name(), - ", size ", size, " bytes)"); + return tensorflow::strings::StrCat( + "Buffer ", id, " (defined by ", + defining_instruction->instruction->name(), ", size ", size, + " bytes)"); } }; @@ -333,25 +408,24 @@ class MemoryUsageTracker { const HloRematerialization::ShapeSizeFunction& size_function, bool live_out) { bool has_indirect_uses = false; - std::vector users = - GetUsers(logical_buffer, points_to_analysis, &has_indirect_uses); - return NewBuffer(logical_buffer->instruction(), + ItemList users = GetUsers(instruction_list_, logical_buffer, + points_to_analysis, &has_indirect_uses); + return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()), size_function(logical_buffer->shape()), std::move(users), live_out, has_indirect_uses); } // Create a new buffer representing a rematerialization of given buffer for // the given uses. - Buffer& RematerializeBuffer( - const Buffer& original_buffer, const HloInstruction* remat_instruction, - std::vector&& rematerialized_uses) { - CHECK(IsPlaced(original_buffer.defining_instruction)); + Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item, + ItemList&& rematerialized_uses) { + CHECK(original_buffer.defining_instruction->placed); CHECK(!original_buffer.has_indirect_uses); CHECK(!original_buffer.live_out); - for (const HloInstruction* use : rematerialized_uses) { - CHECK(!IsPlaced(use)); + for (Item* use : rematerialized_uses) { + CHECK(!use->placed); } - return NewBuffer(remat_instruction, original_buffer.size, + return NewBuffer(remat_item, original_buffer.size, std::move(rematerialized_uses), /*live_out=*/false, /*has_indirect_uses=*/false); } @@ -362,7 +436,7 @@ class MemoryUsageTracker { // different computation. int64 AllocatedSize(BufferId buffer_id) const { const Buffer& buffer = buffers_.at(buffer_id); - HloOpcode def_opcode = buffer.defining_instruction->opcode(); + HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode(); if (buffer.live_out || def_opcode == HloOpcode::kParameter) { return 0; } else { @@ -372,18 +446,17 @@ class MemoryUsageTracker { // Returns true if BeginInstruction and EndInstruction has been called for the // given instruction. - bool IsFinished(const HloInstruction* instruction) const { - return IsPlaced(instruction) && instruction != in_progress_instruction_; + bool IsFinished(Item* item) const { + return item->placed && item != in_progress_item_; } // Returns whether the given buffer is being used by the in-progress // instruction. bool IsInUse(BufferId buffer_id) const { - if (in_progress_instruction_ == nullptr) { + if (in_progress_item_ == nullptr) { return false; } - const std::vector& in_progress_uses = - buffers_used_by_instruction_.at(in_progress_instruction_); + const BufferIdList& in_progress_uses = in_progress_item_->buffers_used; return std::find(in_progress_uses.begin(), in_progress_uses.end(), buffer_id) != in_progress_uses.end(); } @@ -392,14 +465,13 @@ class MemoryUsageTracker { // point. bool IsCurrentlyLive(BufferId buffer_id) const { const Buffer& buffer = buffers_[buffer_id]; - return (IsPlaced(buffer.defining_instruction) && + return (buffer.defining_instruction->placed && buffer.unfinished_user_count > 0); } // Create a new buffer, add it to buffers_, and return a reference. - Buffer& NewBuffer(const HloInstruction* defining_instruction, int64 size, - std::vector&& users, bool live_out, - bool has_indirect_uses) { + Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users, + bool live_out, bool has_indirect_uses) { int buffer_id = buffers_.size(); buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out, has_indirect_uses, users, @@ -419,19 +491,7 @@ class MemoryUsageTracker { // The instruction currently being placed. This value is non-null only // between the calling of BeginInstruction and EndInstruction. - const HloInstruction* in_progress_instruction_ = nullptr; - - // The buffers defined by each instruction. - std::unordered_map> - buffers_defined_by_instruction_; - - // The buffers used by each instruction. - std::unordered_map> - buffers_used_by_instruction_; - - // The set of instructions which have been placed. That is, BeginInstruction - // has been called with the instruction as an argument. - tensorflow::gtl::FlatSet placed_instructions_; + Item* in_progress_item_ = nullptr; // All buffers in the computation. std::vector buffers_; @@ -443,22 +503,15 @@ MemoryUsageTracker::MemoryUsageTracker( const TuplePointsToAnalysis& points_to_analysis, const InstructionList& instruction_list) : computation_(computation), instruction_list_(instruction_list) { - // Iterate through all LogicalBuffers in the computation and gather the - // instructions which define them in buffers_defined_by_instruction_ and the - // instructions which use them in buffers_used_by_instruction_. - for (auto& instruction : computation_->instructions()) { - // Initialize empty vectors for defs and uses of each instruction. - buffers_used_by_instruction_[instruction.get()]; - buffers_defined_by_instruction_[instruction.get()]; - } - tensorflow::gtl::FlatSet live_out_set = points_to_analysis.GetPointsToSet(computation_->root_instruction()) .CreateFlattenedSet(); tensorflow::gtl::FlatMap logical_buffer_to_buffer_id; - for (const HloInstruction* instruction : instruction_list_.instructions()) { + for (auto* item = instruction_list_.first(); item != nullptr; + item = instruction_list_.next(item)) { + const HloInstruction* const instruction = item->instruction; for (const LogicalBuffer* logical_buffer : points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { Buffer* buffer; @@ -481,22 +534,22 @@ MemoryUsageTracker::MemoryUsageTracker( // Add users of while to Buffer users. bool unused; - for (const HloInstruction* user : - GetUsers(logical_buffer, points_to_analysis, &unused)) { - if (std::find(buffer->users.begin(), buffer->users.end(), user) == - buffer->users.end()) { - buffer->users.push_back(user); + for (Item* user_item : GetUsers(instruction_list_, logical_buffer, + points_to_analysis, &unused)) { + if (std::find(buffer->users.begin(), buffer->users.end(), + user_item) == buffer->users.end()) { + buffer->users.push_back(user_item); buffer->unfinished_user_count++; - buffers_used_by_instruction_.at(user).push_back(buffer->id); + user_item->buffers_used.push_back(buffer->id); } } } else { buffer = &CreateBufferFromLogicalBuffer( logical_buffer, points_to_analysis, size_function, ContainsKey(live_out_set, logical_buffer)); - buffers_defined_by_instruction_.at(instruction).push_back(buffer->id); - for (const HloInstruction* user : buffer->users) { - buffers_used_by_instruction_.at(user).push_back(buffer->id); + item->buffers_defined.push_back(buffer->id); + for (Item* user : buffer->users) { + user->buffers_used.push_back(buffer->id); } } @@ -507,15 +560,16 @@ MemoryUsageTracker::MemoryUsageTracker( DCHECK(Check()); } -Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) { +Status MemoryUsageTracker::BeginInstruction(Item* item) { + const HloInstruction* instruction = item->instruction; VLOG(3) << "BeginInstruction " << instruction->name(); - TF_RET_CHECK(in_progress_instruction_ == nullptr); - in_progress_instruction_ = instruction; + TF_RET_CHECK(in_progress_item_ == nullptr); + in_progress_item_ = item; - placed_instructions_.insert(in_progress_instruction_); + item->placed = true; // All buffers defined by this instruction need memory. - for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + for (BufferId buffer_id : item->buffers_defined) { VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString() << " is now live."; memory_usage_ += AllocatedSize(buffer_id); @@ -532,11 +586,10 @@ Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) { } Status MemoryUsageTracker::EndInstruction() { - TF_RET_CHECK(in_progress_instruction_ != nullptr); - VLOG(3) << "EndInstruction " << in_progress_instruction_->name(); + TF_RET_CHECK(in_progress_item_ != nullptr); + VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name(); - for (BufferId buffer_id : - buffers_used_by_instruction_.at(in_progress_instruction_)) { + for (BufferId buffer_id : in_progress_item_->buffers_used) { Buffer& buffer = buffers_.at(buffer_id); buffer.unfinished_user_count--; CHECK_GE(buffer.unfinished_user_count, 0) @@ -551,8 +604,7 @@ Status MemoryUsageTracker::EndInstruction() { // If any buffer defined by this instruction has no uses, then memory can be // reclaimed immediately. - for (BufferId buffer_id : - buffers_defined_by_instruction_.at(in_progress_instruction_)) { + for (BufferId buffer_id : in_progress_item_->buffers_defined) { const Buffer& buffer = buffers_.at(buffer_id); if (buffer.unfinished_user_count == 0) { VLOG(3) << " " << buffer.ToString() << " is immediately dead."; @@ -561,7 +613,7 @@ Status MemoryUsageTracker::EndInstruction() { } } - in_progress_instruction_ = nullptr; + in_progress_item_ = nullptr; VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); @@ -571,10 +623,9 @@ Status MemoryUsageTracker::EndInstruction() { return Status::OK(); } -int64 MemoryUsageTracker::MemoryReducedIfRematerialized( - const HloInstruction* instruction) const { - CHECK_NE(in_progress_instruction_, nullptr); - if (!IsPlaced(instruction) || instruction == in_progress_instruction_) { +int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const { + CHECK_NE(in_progress_item_, nullptr); + if (!item->placed || item == in_progress_item_) { return 0; } @@ -589,7 +640,7 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized( // be live at this program point, so initially set memory_reduced to the // size of its defined values. int64 memory_reduced = 0; - for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + for (BufferId buffer_id : item->buffers_defined) { // Avoid rematerializing instructions with indirect uses as it is difficult // to reason about liveness after rematerializing the instruction. // TODO(b/37714814): Consider rematerialzing instructions with indirect @@ -605,7 +656,7 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized( // Account for any logical buffers whose live range must be extended across // this program point. - for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) { + for (BufferId buffer_id : item->buffers_used) { if (!IsCurrentlyLive(buffer_id)) { // This logical buffer is used by 'instruction' but is not live at this // program point. Rematerializing 'instruction' will extend the buffer's @@ -617,28 +668,23 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized( return memory_reduced; } -Status MemoryUsageTracker::AddRematerializedInstruction( - HloInstruction* original_instruction, HloInstruction* remat_instruction) { +Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, + Item* remat_item) { VLOG(3) << "AddRematerializedInstruction: original_instruction = " - << original_instruction->name() - << ", remat_instruction = " << remat_instruction->name(); + << original_item->instruction->name() + << ", remat_instruction = " << remat_item->instruction->name(); - TF_RET_CHECK(in_progress_instruction_ != nullptr); - TF_RET_CHECK(IsPlaced(original_instruction)); - TF_RET_CHECK(!IsPlaced(remat_instruction)); - CHECK(!ContainsKey(buffers_defined_by_instruction_, remat_instruction)); - CHECK(!ContainsKey(buffers_used_by_instruction_, remat_instruction)); + TF_RET_CHECK(in_progress_item_ != nullptr); + TF_RET_CHECK(original_item->placed); + TF_RET_CHECK(!remat_item->placed); // Construct the list of buffers used and defined by the rematerialization. - buffers_defined_by_instruction_[remat_instruction]; - buffers_used_by_instruction_[remat_instruction] = - buffers_used_by_instruction_.at(original_instruction); + remat_item->buffers_used = original_item->buffers_used; // Account for the additional buffer uses created by the new rematerialization // instruction. Update memory usage if the rematerialization makes a dead // buffer live again. - for (BufferId buffer_id : - buffers_used_by_instruction_.at(original_instruction)) { + for (BufferId buffer_id : original_item->buffers_used) { Buffer& buffer = buffers_.at(buffer_id); if (buffer.unfinished_user_count == 0) { // Buffer used by this instruction was dead, now is alive. @@ -646,20 +692,19 @@ Status MemoryUsageTracker::AddRematerializedInstruction( } buffer.unfinished_user_count++; - buffer.users.push_back(remat_instruction); + buffer.users.push_back(remat_item); } // Create a new set of Buffers defined by the new rematerialization // instruction. Update the internal data structures and memory use to account // for them. - for (BufferId old_buffer_id : - buffers_defined_by_instruction_.at(original_instruction)) { + for (BufferId old_buffer_id : original_item->buffers_defined) { Buffer& old_buffer = buffers_.at(old_buffer_id); - std::vector placed_users; - std::vector unplaced_users; - for (const HloInstruction* user : old_buffer.users) { - if (IsPlaced(user)) { + ItemList placed_users; + ItemList unplaced_users; + for (Item* user : old_buffer.users) { + if (user->placed) { CHECK(IsFinished(user)); placed_users.push_back(user); } else { @@ -672,14 +717,12 @@ Status MemoryUsageTracker::AddRematerializedInstruction( // Buffer is now dead. memory_usage_ -= AllocatedSize(old_buffer.id); - Buffer& new_buffer = RematerializeBuffer(old_buffer, remat_instruction, - std::move(unplaced_users)); + Buffer& new_buffer = + RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users)); - buffers_defined_by_instruction_.at(remat_instruction) - .push_back(new_buffer.id); - for (const HloInstruction* user : new_buffer.users) { - std::vector& buffers_used = - buffers_used_by_instruction_.at(user); + remat_item->buffers_defined.push_back(new_buffer.id); + for (Item* user : new_buffer.users) { + BufferIdList& buffers_used = user->buffers_used; std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id, new_buffer.id); } @@ -699,13 +742,14 @@ string MemoryUsageTracker::ToString() const { tensorflow::strings::StrAppend( &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", memory_usage(), " bytes)"); - for (const HloInstruction* instruction : instruction_list_.instructions()) { - string inprogress = - instruction == in_progress_instruction_ ? " in-progress" : ""; - string placed = IsPlaced(instruction) ? " placed" : ""; + for (auto* item = instruction_list_.first(); item != nullptr; + item = instruction_list_.next(item)) { + const HloInstruction* instruction = item->instruction; + string inprogress = item == in_progress_item_ ? " in-progress" : ""; + string placed = item->placed ? " placed" : ""; tensorflow::strings::StrAppend(&output, " ", instruction->name(), inprogress, placed, "\n Defines:\n"); - for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + for (BufferId buffer_id : item->buffers_defined) { const Buffer& buffer = buffers_[buffer_id]; string live = IsCurrentlyLive(buffer_id) ? " live" : ""; tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live, @@ -713,7 +757,7 @@ string MemoryUsageTracker::ToString() const { " unfinished uses\n"); } tensorflow::strings::StrAppend(&output, " Uses:\n"); - for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) { + for (BufferId buffer_id : item->buffers_used) { tensorflow::strings::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n"); } @@ -722,14 +766,14 @@ string MemoryUsageTracker::ToString() const { } bool MemoryUsageTracker::Check() const { - auto elements_are_unique = [](const std::vector& vec) { + auto elements_are_unique = [](const BufferIdList& vec) { return vec.size() == std::set(vec.begin(), vec.end()).size(); }; - // Verify buffers_defined_by_instruction_. + // Verify buffers_defined per instruction. for (auto& instruction : computation_->instructions()) { - const std::vector& defined_buffers = - buffers_defined_by_instruction_.at(instruction.get()); + const BufferIdList& defined_buffers = + instruction_list_.GetItem(instruction.get())->buffers_defined; CHECK(elements_are_unique(defined_buffers)) << "Instruction " << instruction->name() << " does not have unique defined buffers: " @@ -740,7 +784,7 @@ bool MemoryUsageTracker::Check() const { }); for (const Buffer& buffer : buffers_) { - if (buffer.defining_instruction == instruction.get()) { + if (buffer.defining_instruction->instruction == instruction.get()) { CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), buffer.id) != defined_buffers.end()) << "Instruction " << instruction->name() @@ -749,10 +793,10 @@ bool MemoryUsageTracker::Check() const { } } - // Verify buffers_used_by_instruction_. + // Verify buffers_used per instruction. for (auto& instruction : computation_->instructions()) { - const std::vector& used_buffers = - buffers_used_by_instruction_.at(instruction.get()); + const BufferIdList& used_buffers = + instruction_list_.GetItem(instruction.get())->buffers_used; CHECK(elements_are_unique(used_buffers)) << "Instruction " << instruction->name() << " does not have unique used buffers: " @@ -764,13 +808,12 @@ bool MemoryUsageTracker::Check() const { } for (const Buffer& buffer : buffers_) { int64 unfinished_uses = 0; - for (const HloInstruction* user : buffer.users) { - const std::vector& used_buffers = - buffers_used_by_instruction_.at(user); + for (Item* user : buffer.users) { + const BufferIdList& used_buffers = user->buffers_used; CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) != used_buffers.end()) - << "Instruction " << user->name() << " used buffers is missing " - << buffer.ToString(); + << "Instruction " << user->instruction->name() + << " used buffers is missing " << buffer.ToString(); if (!IsFinished(user)) { unfinished_uses++; } @@ -785,8 +828,8 @@ bool MemoryUsageTracker::Check() const { // The while instruction reuses its input buffers as output buffers so // don't double count its buffers if it is currently executing. if (IsCurrentlyLive(buffer.id) && - !(buffer.defining_instruction == in_progress_instruction_ && - in_progress_instruction_->opcode() == HloOpcode::kWhile)) { + !(buffer.defining_instruction == in_progress_item_ && + in_progress_item_->instruction->opcode() == HloOpcode::kWhile)) { live_size += AllocatedSize(buffer.id); } } @@ -830,26 +873,26 @@ int64 RematerializationCost(const HloInstruction* instruction, // candidate which reduce memory use at the program point of the current // instruction as indicated by memory_tracker. nullptr is returned if no // candidate can be found. -HloInstruction* PickRematerializationCandidate( - const MemoryUsageTracker& memory_tracker, - const InstructionList& instruction_list, - const tensorflow::gtl::FlatSet& blacklist, - int64 memory_limit_bytes) { - HloInstruction* best = nullptr; +Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker, + const InstructionList& instruction_list, + int64 memory_limit_bytes) { + Item* best_item = nullptr; int64 best_cost = 0; // TODO(b/35244891): This is currently quadratic in the number of HLO // instructions. - for (HloInstruction* candidate : instruction_list.instructions()) { - if (!memory_tracker.IsPlaced(candidate)) { - // Only iterate up to the currently placed instruction as indicated by - // memory_tracker. We are trying to reduce memory usage at the placed + for (auto* item = instruction_list.first(); item != nullptr; + item = instruction_list.next(item)) { + if (!item->placed) { + // Only iterate up to the currently placed instruction. + // We are trying to reduce memory usage at the placed // instruction so rematerializing later values is of no benefit. break; } + HloInstruction* candidate = item->instruction; VLOG(5) << "considering rematerialization candidate " << candidate->name(); - if (ContainsKey(blacklist, candidate)) { + if (item->blacklisted) { // Skip instructions on the blacklist to avoid infinite loops of // rematerializing the same instruction(s) repeatedly. VLOG(5) << "candidate " << candidate->name() @@ -864,7 +907,7 @@ HloInstruction* PickRematerializationCandidate( } const int64 memory_reduced = - memory_tracker.MemoryReducedIfRematerialized(candidate); + memory_tracker.MemoryReducedIfRematerialized(item); if (memory_reduced <= 0) { VLOG(5) << "candidate " << candidate->name() @@ -878,13 +921,13 @@ HloInstruction* PickRematerializationCandidate( VLOG(5) << "candidate " << candidate->name() << ", memory reduced " << memory_reduced << ", cost per byte " << cost; - if (best == nullptr || cost < best_cost) { + if (best_item == nullptr || cost < best_cost) { VLOG(5) << "candidate " << candidate->name() << " now best"; - best = candidate; + best_item = item; best_cost = cost; } } - return best; + return best_item; } } // namespace @@ -896,8 +939,10 @@ StatusOr HloRematerialization::ComputePeakMemory( MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, instruction_list); int64 peak_memory = tracker.memory_usage(); - for (const HloInstruction* instruction : order) { - TF_RETURN_IF_ERROR(tracker.BeginInstruction(instruction)); + for (auto* item = instruction_list.first(); item != nullptr; + item = instruction_list.next(item)) { + const HloInstruction* instruction = item->instruction; + TF_RETURN_IF_ERROR(tracker.BeginInstruction(item)); TF_ASSIGN_OR_RETURN(int64 callee_usage, CalledComputationsMemoryUsage(instruction)); peak_memory = @@ -939,11 +984,6 @@ StatusOr HloRematerialization::RematerializeComputation( *points_to_analysis_, instruction_list); bool changed = false; - // To avoid an infinite loop rematerializing the same set of instructions ad - // infinitum, keep a blacklist of instructions which should not be - // rematerialized. - tensorflow::gtl::FlatSet blacklist; - // If the rematerialization makes the source instruction dead, then the // rematerialization is added to 'remat_move_instructions' (the // rematerialization is essentially a move). If the next rematerialization of @@ -967,17 +1007,17 @@ StatusOr HloRematerialization::RematerializeComputation( // (program point) if memory_usage exceeds the specified limit then // rematerialize HLO instructions until memory_usage is reduced. int64 instruction_index = 0; - for (auto list_it = instruction_list.instructions().begin(); - list_it != instruction_list.instructions().end(); ++list_it) { - HloInstruction* instruction = *list_it; + for (auto* item = instruction_list.first(); item != nullptr; + item = instruction_list.next(item)) { + const HloInstruction* instruction = item->instruction; TF_ASSIGN_OR_RETURN(int64 callee_usage, CalledComputationsMemoryUsage(instruction)); - TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(instruction)); + TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item)); VLOG(2) << "Program point at " << instruction->name() << ", memory usage = " << memory_tracker.memory_usage() << ", callee usage = " << callee_usage << ", [" << instruction_index - << "/" << instruction_list.instructions().size() << "]"; + << "/" << instruction_list.size() << "]"; instruction_index++; while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { @@ -987,10 +1027,10 @@ StatusOr HloRematerialization::RematerializeComputation( callee_usage) << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); - HloInstruction* best = PickRematerializationCandidate( - memory_tracker, instruction_list, blacklist, memory_limit_bytes); + Item* best_item = PickRematerializationCandidate( + memory_tracker, instruction_list, memory_limit_bytes); - if (best == nullptr) { + if (best_item == nullptr) { VLOG(3) << "Unable to find rematerialization candidate at program " "point " << instruction->name() << ". Memory usage = " @@ -999,13 +1039,15 @@ StatusOr HloRematerialization::RematerializeComputation( break; } + HloInstruction* best = best_item->instruction; VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " - << memory_tracker.MemoryReducedIfRematerialized(best) << ")"; + << memory_tracker.MemoryReducedIfRematerialized(best_item) << ")"; changed = true; remat_count++; HloInstruction* remat = computation->AddInstruction(best->Clone(/*suffix=*/"remat")); + Item* remat_item = instruction_list.CreateItem(remat); // Replace each remaining use of 'best' with the rematerialization. std::vector best_users_copy = best->users(); @@ -1019,22 +1061,28 @@ StatusOr HloRematerialization::RematerializeComputation( // Account for the rematerialization in the memory tracker. TF_RETURN_IF_ERROR( - memory_tracker.AddRematerializedInstruction(best, remat)); + memory_tracker.AddRematerializedInstruction(best_item, remat_item)); // Insert rematerialized instruction right before the earliest unplaced // use of the instruction *and* the earliest unplaced last use of any // operands of remat. Unplaced uses of the remat's operands are included // because we don't want to extend the live range of remat's operands as // this could increase memory usage. - std::vector place_before = remat->users(); + ItemList place_before; + for (auto user : remat->users()) { + place_before.push_back(instruction_list.GetItem(user)); + } for (auto* operand : remat->operands()) { for (auto* operand_user : operand->users()) { - if (!memory_tracker.IsPlaced(operand_user) && operand_user != remat) { - place_before.push_back(operand_user); + if (operand_user != remat) { + Item* operand_user_item = instruction_list.GetItem(operand_user); + if (!operand_user_item->placed) { + place_before.push_back(operand_user_item); + } } } } - instruction_list.InsertBeforeInstructions(remat, place_before); + instruction_list.InsertBeforeInstructions(remat_item, place_before); // If the rematerialized instruction is dead then rematerialization is // essentially a move. Don't delete the instruction now because we don't @@ -1048,7 +1096,7 @@ StatusOr HloRematerialization::RematerializeComputation( // instruction it was a copying of. Now 'remat' is a rematerialization // of 'best' and kills 'best'. Stop rematerializing this instruction // to avoid an infinite loop. - blacklist.insert(remat); + instruction_list.Blacklist(remat); } remat_move_instructions.insert(remat); } else { @@ -1116,10 +1164,13 @@ StatusOr HloRematerialization::RematerializeComputation( computation_peak_memory_.at(computation) = peak_memory; // Update order to include rematerialized instructions. - sequence->at(computation) - .assign(instruction_list.instructions().begin(), - instruction_list.instructions().end()); - + auto& dst = sequence->at(computation); + dst.clear(); + for (auto* item = instruction_list.first(); item != nullptr; + item = instruction_list.next(item)) { + const HloInstruction* instruction = item->instruction; + dst.push_back(instruction); + } rematerialized_computations_.insert(computation); instructions_rematerialized_ += remat_count;