Roll back a broken CL

Reverts 5bddc28fd5

PiperOrigin-RevId: 800789392
A. Unique TensorFlower 2025-08-29 01:07:08 -07:00 committed by TensorFlower Gardener
parent d991543c54
commit 22032a9edb
3 changed files with 97 additions and 76 deletions

View File

@@ -296,11 +296,12 @@ cc_library(
         "//xla:util",
         "//xla/hlo/analysis:alias_info",
         "//xla/hlo/analysis:hlo_alias_analysis",
-        "//xla/hlo/analysis:hlo_dataflow_analysis",
+        "//xla/hlo/analysis:tuple_points_to_analysis",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/pass:hlo_pass",
         "//xla/service:buffer_value",
         "//xla/service:hlo_value",
+        "//xla/service:logical_buffer",
         "//xla/service/heap_simulator",
         "//xla/tsl/platform:errors",
         "//xla/tsl/platform:logging",

View File

@@ -36,7 +36,7 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "xla/hlo/analysis/alias_info.h"
 #include "xla/hlo/analysis/hlo_alias_analysis.h"
-#include "xla/hlo/analysis/hlo_dataflow_analysis.h"
+#include "xla/hlo/analysis/tuple_points_to_analysis.h"
 #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -45,6 +45,7 @@ limitations under the License.
 #include "xla/service/buffer_value.h"
 #include "xla/service/heap_simulator/heap_simulator.h"
 #include "xla/service/hlo_value.h"
+#include "xla/service/logical_buffer.h"
 #include "xla/shape_util.h"
 #include "xla/tsl/platform/errors.h"
 #include "xla/tsl/platform/logging.h"
@@ -92,9 +93,10 @@ class ListScheduler {
   // Construct and return a memory-minimizing sequence of HLO instructions
   // containing the given HLO computation.
   static absl::StatusOr<HloInstructionSequence> Run(
-      HloComputation* computation, const HloAliasAnalysis& alias_analysis,
+      HloComputation* computation,
+      const TuplePointsToAnalysis& points_to_analysis,
       const BufferValue::SizeFunction& size_function) {
-    ListScheduler scheduler(computation, alias_analysis, size_function);
+    ListScheduler scheduler(computation, points_to_analysis, size_function);
     return scheduler.CreateSchedule();
   }
@@ -114,57 +116,56 @@ class ListScheduler {
   using Priority = std::pair<int64_t, int64_t>;

   ListScheduler(HloComputation* computation,
-                const HloAliasAnalysis& alias_analysis,
+                const TuplePointsToAnalysis& points_to_analysis,
                 const BufferValue::SizeFunction& size_function)
       : computation_(computation),
-        alias_analysis_(alias_analysis),
+        points_to_analysis_(points_to_analysis),
         size_function_(size_function) {
-    // Create a map containing the HloValue uses for each HLO instruction. An
-    // HLO instruction "uses" a HloValue if the HloValue is in an operand of the
-    // instruction as indicated by HloDataflowAnalysis.
-    const HloDataflowAnalysis& dataflow_analysis =
-        alias_analysis.dataflow_analysis();
-    absl::flat_hash_set<const HloValue*> computation_values;
+    // Create a map containing the LogicalBuffer uses for each HLO
+    // instruction. An HLO instruction "uses" a LogicalBuffer if the
+    // LogicalBuffer is in an operand of the instruction as indicated by
+    // points-to analysis.
     for (auto* instruction : computation->instructions()) {
-      dataflow_analysis.GetInstructionValueSet(instruction)
-          .ForEachElement(
-              [&](const ShapeIndex& /*index*/, const HloValueSet& value_set) {
-                computation_values.insert(value_set.values().begin(),
-                                          value_set.values().end());
-              });
-      absl::flat_hash_set<const HloValue*> instr_uses;
+      absl::flat_hash_set<const LogicalBuffer*> instr_uses;
       for (auto* operand : instruction->operands()) {
-        dataflow_analysis.GetInstructionValueSet(operand).ForEachElement(
-            [&](const ShapeIndex& /*index*/, const HloValueSet& value_set) {
-              instr_uses.insert(value_set.values().begin(),
-                                value_set.values().end());
+        points_to_analysis.GetPointsToSet(operand).ForEachElement(
+            [&](const ShapeIndex& /*index*/,
+                const PointsToSet::BufferList& buffers) {
+              instr_uses.insert(buffers.begin(), buffers.end());
             });
       }
-      value_uses_[instruction] =
-          std::vector<const HloValue*>(instr_uses.begin(), instr_uses.end());
+      buffer_uses_[instruction] = std::vector<const LogicalBuffer*>(
+          instr_uses.begin(), instr_uses.end());
     }

     // Create map containing the number of unscheduled uses (hlo instructions)
-    // of each HloValue.
-    unscheduled_use_count_.reserve(computation_values.size());
-    for (const HloValue* value : computation_values) {
-      // HloValues live out of the computation have an implicit use at the end
-      // of the computation. Therefore we initialize `unscheduled_use_count_` to
-      // 1 in such cases.
-      unscheduled_use_count_[value] =
-          alias_analysis.ValueLivesOut(*value) ? 1 : 0;
+    // of each logical buffer.
+    unscheduled_use_count_.reserve(points_to_analysis.num_logical_buffers());
+    for (auto* instruction : computation->instructions()) {
+      for (auto* buffer :
+           points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
+        unscheduled_use_count_[buffer] = 0;
+      }
     }
     for (auto* instruction : computation->instructions()) {
-      for (const HloValue* value : value_uses_.at(instruction)) {
-        ++unscheduled_use_count_[value];
+      for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
+        ++unscheduled_use_count_[buffer];
       }
     }
+
+    // Buffers live out of the computation have an implicit use at the end of
+    // the computation.
+    for (const LogicalBuffer* live_out_buffer :
+         points_to_analysis.GetPointsToSet(computation->root_instruction())
+             .CreateFlattenedSet()) {
+      ++unscheduled_use_count_[live_out_buffer];
+    }
   }

-  // Returns whether the memory used by the given HloValue should be ignored by
+  // Returns whether the memory used by the given buffer should be ignored by
   // the scheduling heuristic.
-  static bool IgnoreValue(const HloValue& value) {
-    return IgnoreInstruction(*value.instruction());
+  static bool IgnoreBuffer(const LogicalBuffer& buffer) {
+    return IgnoreInstruction(*buffer.instruction());
   }

   // An entry in the worklist used by CreateSchedule. Corresponds to one
@@ -180,7 +181,7 @@ class ListScheduler {
     // U is the number of uses of B that have not yet been scheduled. This pair
     // is a pointer into the unscheduled_use_count_ map, so it gets updated for
     // free when we update counts in the map.
-    std::vector<const std::pair<const HloValue* const, int64_t>*>
+    std::vector<const std::pair<const LogicalBuffer* const, int64_t>*>
        used_buffer_unscheduled_use_counts;
   };
@@ -190,20 +191,18 @@ class ListScheduler {
     entry.instruction = instruction;

     entry.bytes_defined = 0;
-    HloValueSet value_set =
-        alias_analysis_.dataflow_analysis().GetFlattenedValueSet(instruction);
-    for (const HloValue* value : value_set.values()) {
-      if (!IgnoreInstruction(*instruction) &&
-          value->instruction() == instruction) {
-        entry.bytes_defined += size_function_(*value);
+    for (auto* buffer :
+         points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
+      if (!IgnoreBuffer(*buffer)) {
+        entry.bytes_defined += size_function_(*buffer);
       }
     }

-    for (auto* value : value_uses_.at(instruction)) {
-      if (IgnoreValue(*value)) {
+    for (auto* buffer : buffer_uses_.at(instruction)) {
+      if (IgnoreBuffer(*buffer)) {
        continue;
       }
-      auto unscheduled_use_count_it = unscheduled_use_count_.find(value);
+      auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
       CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
       entry.used_buffer_unscheduled_use_counts.push_back(
           &*unscheduled_use_count_it);
@@ -309,9 +308,9 @@ class ListScheduler {
     scheduled_instructions_.insert(best);

     bool adjust_ready_queue = false;
-    // Update the unscheduled uses of the HloValues.
-    for (const HloValue* value : value_uses_.at(best)) {
-      int64_t& count = unscheduled_use_count_[value];
+    // Update the unscheduled uses of the logical buffers.
+    for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
+      int64_t& count = unscheduled_use_count_[buffer];
       CHECK_GT(count, 0);
       --count;
       if (count == 1) {
@@ -335,7 +334,7 @@ class ListScheduler {
     for (HloInstruction* succ : best->control_successors()) {
       update_pred_count(succ);
     }
-    // The unscheduled use count for a HloValue has changed to 1, so the
+    // The unscheduled use count for a buffer has changed to 1, so the
     // priorities of some ready instructions may go up. We update them in the
     // ready queue, so that they can appear earlier.
     if (adjust_ready_queue) {
@@ -361,24 +360,23 @@ class ListScheduler {
         }
       }
     }
-    CHECK_EQ(schedule.size(), computation_->instruction_count())
-        << "There could be a cycle in the HLO graph";
+    CHECK_EQ(schedule.size(), computation_->instruction_count());
     CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());
     return schedule;
   }

   HloComputation* computation_;
-  const HloAliasAnalysis& alias_analysis_;
+  const TuplePointsToAnalysis& points_to_analysis_;
   const BufferValue::SizeFunction& size_function_;

-  // A map containing the HloValue that each instruction uses.
-  absl::flat_hash_map<const HloInstruction*, std::vector<const HloValue*>>
-      value_uses_;
+  // A map containing the LogicalBuffers that each instruction uses.
+  absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
+      buffer_uses_;

   // A map containing the count of unscheduled HLOs which using a particular
-  // HloValue.
-  absl::flat_hash_map<const HloValue*, int64_t> unscheduled_use_count_;
+  // LogicalBuffer.
+  absl::flat_hash_map<const LogicalBuffer*, int64_t> unscheduled_use_count_;

   // Set of instructions which have been scheduled.
   absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
@@ -398,7 +396,8 @@ int64_t SumBufferSizes(const HloInstruction* hlo, const HloValueSet& value_set,
 }  // namespace

 absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(
-    const HloModule* module, const HloAliasAnalysis& alias_analysis,
+    const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+    const HloAliasAnalysis& alias_analysis,
     const absl::flat_hash_set<absl::string_view>& execution_threads,
     int64_t* peak_memory) const {
   HloSchedule schedule(module);
@@ -406,7 +405,7 @@ absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(
        module->MakeComputationPostOrder(execution_threads)) {
     if (!computation->IsFusionComputation()) {
       TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
-                          Run(computation, alias_analysis));
+                          Run(computation, points_to_analysis, alias_analysis));
       if (postprocessor_) {
         computation_sequence = postprocessor_(computation_sequence);
       }
@@ -423,6 +422,7 @@ absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(

 absl::StatusOr<HloInstructionSequence> DFSMemoryScheduler::Run(
     HloComputation* computation,
+    const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis) const {
   // These variables are a hack to prevent overflows.
   int64_t cumulative_total_size = 0;
@@ -501,6 +501,7 @@ absl::StatusOr<HloInstructionSequence> DFSMemoryScheduler::Run(

 absl::StatusOr<HloInstructionSequence> BFScheduler::Run(
     HloComputation* computation,
+    const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis) const {
   // Index of HloInstruction in the `computation`.
   absl::flat_hash_map<const HloInstruction*, int64_t> inst_index;
@@ -554,18 +555,21 @@ absl::StatusOr<HloInstructionSequence> BFScheduler::Run(

 absl::StatusOr<HloInstructionSequence> ListMemoryScheduler::Run(
     HloComputation* computation,
+    const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis) const {
-  return ListScheduler::Run(computation, alias_analysis, size_function_);
+  return ListScheduler::Run(computation, points_to_analysis, size_function_);
 }

 absl::StatusOr<HloInstructionSequence> PostOrderScheduler::Run(
     HloComputation* computation,
+    const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis) const {
   return HloInstructionSequence(computation->MakeInstructionPostOrder());
 }

 absl::StatusOr<HloSchedule> DefaultMemoryScheduler::Run(
-    const HloModule* module, const HloAliasAnalysis& alias_analysis,
+    const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+    const HloAliasAnalysis& alias_analysis,
     const absl::flat_hash_set<absl::string_view>& execution_threads,
     int64_t* peak_memory) const {
   // We try a few schedulers and choose whichever returns a lower min-memory,
@@ -577,22 +581,24 @@ absl::StatusOr<HloSchedule> DefaultMemoryScheduler::Run(
   // List wins for most of our benchmarks; postorder-based schedulers win for
   // some RNNs.
   int64_t list_memory;
-  TF_ASSIGN_OR_RETURN(HloSchedule list_sequence,
-                      list_scheduler_.Run(module, alias_analysis,
-                                          execution_threads, &list_memory));
+  TF_ASSIGN_OR_RETURN(
+      HloSchedule list_sequence,
+      list_scheduler_.Run(module, points_to_analysis, alias_analysis,
+                          execution_threads, &list_memory));
   VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);

   int64_t dfs_memory;
-  TF_ASSIGN_OR_RETURN(HloSchedule dfs_sequence,
-                      dfs_scheduler_.Run(module, alias_analysis,
-                                         execution_threads, &dfs_memory));
+  TF_ASSIGN_OR_RETURN(
+      HloSchedule dfs_sequence,
+      dfs_scheduler_.Run(module, points_to_analysis, alias_analysis,
+                         execution_threads, &dfs_memory));
   VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);

   int64_t post_order_memory;
   TF_ASSIGN_OR_RETURN(
       HloSchedule post_order_sequence,
-      post_order_scheduler_.Run(module, alias_analysis, execution_threads,
-                                &post_order_memory));
+      post_order_scheduler_.Run(module, points_to_analysis, alias_analysis,
+                                execution_threads, &post_order_memory));
   VLOG(2) << "Min-memory post order sequence: "
           << HumanReadableNumBytes(post_order_memory);
@@ -624,12 +630,15 @@ absl::StatusOr<HloSchedule> ScheduleModule(
     return absl::StrFormat("XlaMemoryScheduler:#module=%s,program_id=%d#",
                            module->name(), module->unique_id());
   });

+  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
+                      TuplePointsToAnalysis::Run(module));
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                       HloAliasAnalysis::Run(module, algorithm.alias_info()));
   TF_ASSIGN_OR_RETURN(
       HloSchedule schedule,
-      algorithm.Run(module, *alias_analysis, execution_threads, peak_memory));
+      algorithm.Run(module, *points_to_analysis, *alias_analysis,
+                    execution_threads, peak_memory));
   TF_RETURN_IF_ERROR(schedule.Verify());

View File

@@ -26,6 +26,7 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "xla/hlo/analysis/alias_info.h"
 #include "xla/hlo/analysis/hlo_alias_analysis.h"
+#include "xla/hlo/analysis/tuple_points_to_analysis.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_schedule.h"
@@ -38,13 +39,16 @@ namespace xla {
 // 'module' given a points-to analysis result that describes buffer aliasing.
 // peak_memory (may be nullptr) is set to the peak memory of the resulting
 // schedule according to the HeapSimulator.
+//
+// TODO(yunxing): Cleanup usage of TuplePointsToAnalysis.
 class ModuleSchedulerAlgorithm {
  public:
   explicit ModuleSchedulerAlgorithm(const AliasInfo* alias_info)
       : alias_info_(alias_info) {}
   virtual ~ModuleSchedulerAlgorithm() = default;

   virtual absl::StatusOr<HloSchedule> Run(
-      const HloModule* module, const HloAliasAnalysis& alias_analysis,
+      const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+      const HloAliasAnalysis& alias_analysis,
       const absl::flat_hash_set<absl::string_view>& execution_threads,
       int64_t* peak_memory) const = 0;
@@ -68,9 +72,11 @@ class ComputationSchedulerAlgorithm : public ModuleSchedulerAlgorithm {
  public:
   virtual absl::StatusOr<HloInstructionSequence> Run(
       HloComputation* computation,
+      const TuplePointsToAnalysis& points_to_analysis,
       const HloAliasAnalysis& alias_analysis) const = 0;

   absl::StatusOr<HloSchedule> Run(
-      const HloModule* module, const HloAliasAnalysis& alias_analysis,
+      const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+      const HloAliasAnalysis& alias_analysis,
       const absl::flat_hash_set<absl::string_view>& execution_threads,
       int64_t* peak_memory) const override;
@@ -99,6 +105,7 @@ class ListMemoryScheduler : public ComputationSchedulerAlgorithm {
   using ModuleSchedulerAlgorithm::Run;
   absl::StatusOr<HloInstructionSequence> Run(
       HloComputation* computation,
+      const TuplePointsToAnalysis& points_to_analysis,
       const HloAliasAnalysis& alias_analysis) const override;
 };
@@ -113,6 +120,7 @@ class DFSMemoryScheduler : public ComputationSchedulerAlgorithm {
   using ModuleSchedulerAlgorithm::Run;
   absl::StatusOr<HloInstructionSequence> Run(
       HloComputation* computation,
+      const TuplePointsToAnalysis& points_to_analysis,
       const HloAliasAnalysis& alias_analysis) const override;
 };
@@ -135,6 +143,7 @@ class BFScheduler : public ComputationSchedulerAlgorithm {
                                       std::move(postprocessor)) {}
   absl::StatusOr<HloInstructionSequence> Run(
       HloComputation* computation,
+      const TuplePointsToAnalysis& points_to_analysis,
       const HloAliasAnalysis& alias_analysis) const override;
 };
@@ -149,6 +158,7 @@ class PostOrderScheduler : public ComputationSchedulerAlgorithm {
   using ModuleSchedulerAlgorithm::Run;
   absl::StatusOr<HloInstructionSequence> Run(
       HloComputation* computation,
+      const TuplePointsToAnalysis& points_to_analysis,
       const HloAliasAnalysis& alias_analysis) const override;
 };
@@ -166,7 +176,8 @@ class DefaultMemoryScheduler : public ModuleSchedulerAlgorithm {
         dfs_scheduler_(alias_info, size_function, postprocessor),
         post_order_scheduler_(alias_info, size_function, postprocessor) {}

   absl::StatusOr<HloSchedule> Run(
-      const HloModule* module, const HloAliasAnalysis& alias_analysis,
+      const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+      const HloAliasAnalysis& alias_analysis,
       const absl::flat_hash_set<absl::string_view>& execution_threads,
       int64_t* peak_memory) const override;