mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
parent
d991543c54
commit
22032a9edb
|
|
@ -296,11 +296,12 @@ cc_library(
|
||||||
"//xla:util",
|
"//xla:util",
|
||||||
"//xla/hlo/analysis:alias_info",
|
"//xla/hlo/analysis:alias_info",
|
||||||
"//xla/hlo/analysis:hlo_alias_analysis",
|
"//xla/hlo/analysis:hlo_alias_analysis",
|
||||||
"//xla/hlo/analysis:hlo_dataflow_analysis",
|
"//xla/hlo/analysis:tuple_points_to_analysis",
|
||||||
"//xla/hlo/ir:hlo",
|
"//xla/hlo/ir:hlo",
|
||||||
"//xla/hlo/pass:hlo_pass",
|
"//xla/hlo/pass:hlo_pass",
|
||||||
"//xla/service:buffer_value",
|
"//xla/service:buffer_value",
|
||||||
"//xla/service:hlo_value",
|
"//xla/service:hlo_value",
|
||||||
|
"//xla/service:logical_buffer",
|
||||||
"//xla/service/heap_simulator",
|
"//xla/service/heap_simulator",
|
||||||
"//xla/tsl/platform:errors",
|
"//xla/tsl/platform:errors",
|
||||||
"//xla/tsl/platform:logging",
|
"//xla/tsl/platform:logging",
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ limitations under the License.
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "xla/hlo/analysis/alias_info.h"
|
#include "xla/hlo/analysis/alias_info.h"
|
||||||
#include "xla/hlo/analysis/hlo_alias_analysis.h"
|
#include "xla/hlo/analysis/hlo_alias_analysis.h"
|
||||||
#include "xla/hlo/analysis/hlo_dataflow_analysis.h"
|
#include "xla/hlo/analysis/tuple_points_to_analysis.h"
|
||||||
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
|
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
|
||||||
#include "xla/hlo/ir/hlo_computation.h"
|
#include "xla/hlo/ir/hlo_computation.h"
|
||||||
#include "xla/hlo/ir/hlo_instruction.h"
|
#include "xla/hlo/ir/hlo_instruction.h"
|
||||||
|
|
@ -45,6 +45,7 @@ limitations under the License.
|
||||||
#include "xla/service/buffer_value.h"
|
#include "xla/service/buffer_value.h"
|
||||||
#include "xla/service/heap_simulator/heap_simulator.h"
|
#include "xla/service/heap_simulator/heap_simulator.h"
|
||||||
#include "xla/service/hlo_value.h"
|
#include "xla/service/hlo_value.h"
|
||||||
|
#include "xla/service/logical_buffer.h"
|
||||||
#include "xla/shape_util.h"
|
#include "xla/shape_util.h"
|
||||||
#include "xla/tsl/platform/errors.h"
|
#include "xla/tsl/platform/errors.h"
|
||||||
#include "xla/tsl/platform/logging.h"
|
#include "xla/tsl/platform/logging.h"
|
||||||
|
|
@ -92,9 +93,10 @@ class ListScheduler {
|
||||||
// Construct and return a memory-minimizing sequence of HLO instructions
|
// Construct and return a memory-minimizing sequence of HLO instructions
|
||||||
// containing the given HLO computation.
|
// containing the given HLO computation.
|
||||||
static absl::StatusOr<HloInstructionSequence> Run(
|
static absl::StatusOr<HloInstructionSequence> Run(
|
||||||
HloComputation* computation, const HloAliasAnalysis& alias_analysis,
|
HloComputation* computation,
|
||||||
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const BufferValue::SizeFunction& size_function) {
|
const BufferValue::SizeFunction& size_function) {
|
||||||
ListScheduler scheduler(computation, alias_analysis, size_function);
|
ListScheduler scheduler(computation, points_to_analysis, size_function);
|
||||||
return scheduler.CreateSchedule();
|
return scheduler.CreateSchedule();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -114,57 +116,56 @@ class ListScheduler {
|
||||||
using Priority = std::pair<int64_t, int64_t>;
|
using Priority = std::pair<int64_t, int64_t>;
|
||||||
|
|
||||||
ListScheduler(HloComputation* computation,
|
ListScheduler(HloComputation* computation,
|
||||||
const HloAliasAnalysis& alias_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const BufferValue::SizeFunction& size_function)
|
const BufferValue::SizeFunction& size_function)
|
||||||
: computation_(computation),
|
: computation_(computation),
|
||||||
alias_analysis_(alias_analysis),
|
points_to_analysis_(points_to_analysis),
|
||||||
size_function_(size_function) {
|
size_function_(size_function) {
|
||||||
// Create a map containing the HloValue uses for each HLO instruction. An
|
// Create a map containing the LogicalBuffer uses for each HLO
|
||||||
// HLO instruction "uses" a HloValue if the HloValue is in an operand of the
|
// instruction. An HLO instruction "uses" a LogicalBuffer if the
|
||||||
// instruction as indicated by HloDataflowAnalysis.
|
// LogicalBuffer is in an operand of the instruction as indicated by
|
||||||
const HloDataflowAnalysis& dataflow_analysis =
|
// points-to analysis.
|
||||||
alias_analysis.dataflow_analysis();
|
|
||||||
absl::flat_hash_set<const HloValue*> computation_values;
|
|
||||||
for (auto* instruction : computation->instructions()) {
|
for (auto* instruction : computation->instructions()) {
|
||||||
dataflow_analysis.GetInstructionValueSet(instruction)
|
absl::flat_hash_set<const LogicalBuffer*> instr_uses;
|
||||||
.ForEachElement(
|
|
||||||
[&](const ShapeIndex& /*index*/, const HloValueSet& value_set) {
|
|
||||||
computation_values.insert(value_set.values().begin(),
|
|
||||||
value_set.values().end());
|
|
||||||
});
|
|
||||||
absl::flat_hash_set<const HloValue*> instr_uses;
|
|
||||||
for (auto* operand : instruction->operands()) {
|
for (auto* operand : instruction->operands()) {
|
||||||
dataflow_analysis.GetInstructionValueSet(operand).ForEachElement(
|
points_to_analysis.GetPointsToSet(operand).ForEachElement(
|
||||||
[&](const ShapeIndex& /*index*/, const HloValueSet& value_set) {
|
[&](const ShapeIndex& /*index*/,
|
||||||
instr_uses.insert(value_set.values().begin(),
|
const PointsToSet::BufferList& buffers) {
|
||||||
value_set.values().end());
|
instr_uses.insert(buffers.begin(), buffers.end());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
value_uses_[instruction] =
|
buffer_uses_[instruction] = std::vector<const LogicalBuffer*>(
|
||||||
std::vector<const HloValue*>(instr_uses.begin(), instr_uses.end());
|
instr_uses.begin(), instr_uses.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create map containing the number of unscheduled uses (hlo instructions)
|
// Create map containing the number of unscheduled uses (hlo instructions)
|
||||||
// of each HloValue.
|
// of each logical buffer.
|
||||||
unscheduled_use_count_.reserve(computation_values.size());
|
unscheduled_use_count_.reserve(points_to_analysis.num_logical_buffers());
|
||||||
for (const HloValue* value : computation_values) {
|
for (auto* instruction : computation->instructions()) {
|
||||||
// HloValues live out of the computation have an implicit use at the end
|
for (auto* buffer :
|
||||||
// of the computation. Therefore we initialize `unscheduled_use_count_` to
|
points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
|
||||||
// 1 in such cases.
|
unscheduled_use_count_[buffer] = 0;
|
||||||
unscheduled_use_count_[value] =
|
}
|
||||||
alias_analysis.ValueLivesOut(*value) ? 1 : 0;
|
|
||||||
}
|
}
|
||||||
for (auto* instruction : computation->instructions()) {
|
for (auto* instruction : computation->instructions()) {
|
||||||
for (const HloValue* value : value_uses_.at(instruction)) {
|
for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
|
||||||
++unscheduled_use_count_[value];
|
++unscheduled_use_count_[buffer];
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns whether the memory used by the given HloValue should be ignored by
|
// Buffers live out of the computation have an implicit use at the end of
|
||||||
|
// the computation.
|
||||||
|
for (const LogicalBuffer* live_out_buffer :
|
||||||
|
points_to_analysis.GetPointsToSet(computation->root_instruction())
|
||||||
|
.CreateFlattenedSet()) {
|
||||||
|
++unscheduled_use_count_[live_out_buffer];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns whether the memory used by the given buffer should be ignored by
|
||||||
// the scheduling heuristic.
|
// the scheduling heuristic.
|
||||||
static bool IgnoreValue(const HloValue& value) {
|
static bool IgnoreBuffer(const LogicalBuffer& buffer) {
|
||||||
return IgnoreInstruction(*value.instruction());
|
return IgnoreInstruction(*buffer.instruction());
|
||||||
}
|
}
|
||||||
|
|
||||||
// An entry in the worklist used by CreateSchedule. Corresponds to one
|
// An entry in the worklist used by CreateSchedule. Corresponds to one
|
||||||
|
|
@ -180,7 +181,7 @@ class ListScheduler {
|
||||||
// U is the number of uses of B that have not yet been scheduled. This pair
|
// U is the number of uses of B that have not yet been scheduled. This pair
|
||||||
// is a pointer into the unscheduled_use_count_ map, so it gets updated for
|
// is a pointer into the unscheduled_use_count_ map, so it gets updated for
|
||||||
// free when we update counts in the map.
|
// free when we update counts in the map.
|
||||||
std::vector<const std::pair<const HloValue* const, int64_t>*>
|
std::vector<const std::pair<const LogicalBuffer* const, int64_t>*>
|
||||||
used_buffer_unscheduled_use_counts;
|
used_buffer_unscheduled_use_counts;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -190,20 +191,18 @@ class ListScheduler {
|
||||||
entry.instruction = instruction;
|
entry.instruction = instruction;
|
||||||
|
|
||||||
entry.bytes_defined = 0;
|
entry.bytes_defined = 0;
|
||||||
HloValueSet value_set =
|
for (auto* buffer :
|
||||||
alias_analysis_.dataflow_analysis().GetFlattenedValueSet(instruction);
|
points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
|
||||||
for (const HloValue* value : value_set.values()) {
|
if (!IgnoreBuffer(*buffer)) {
|
||||||
if (!IgnoreInstruction(*instruction) &&
|
entry.bytes_defined += size_function_(*buffer);
|
||||||
value->instruction() == instruction) {
|
|
||||||
entry.bytes_defined += size_function_(*value);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto* value : value_uses_.at(instruction)) {
|
for (auto* buffer : buffer_uses_.at(instruction)) {
|
||||||
if (IgnoreValue(*value)) {
|
if (IgnoreBuffer(*buffer)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto unscheduled_use_count_it = unscheduled_use_count_.find(value);
|
auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
|
||||||
CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
|
CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
|
||||||
entry.used_buffer_unscheduled_use_counts.push_back(
|
entry.used_buffer_unscheduled_use_counts.push_back(
|
||||||
&*unscheduled_use_count_it);
|
&*unscheduled_use_count_it);
|
||||||
|
|
@ -309,9 +308,9 @@ class ListScheduler {
|
||||||
scheduled_instructions_.insert(best);
|
scheduled_instructions_.insert(best);
|
||||||
|
|
||||||
bool adjust_ready_queue = false;
|
bool adjust_ready_queue = false;
|
||||||
// Update the unscheduled uses of the HloValues.
|
// Update the unscheduled uses of the logical buffers.
|
||||||
for (const HloValue* value : value_uses_.at(best)) {
|
for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
|
||||||
int64_t& count = unscheduled_use_count_[value];
|
int64_t& count = unscheduled_use_count_[buffer];
|
||||||
CHECK_GT(count, 0);
|
CHECK_GT(count, 0);
|
||||||
--count;
|
--count;
|
||||||
if (count == 1) {
|
if (count == 1) {
|
||||||
|
|
@ -335,7 +334,7 @@ class ListScheduler {
|
||||||
for (HloInstruction* succ : best->control_successors()) {
|
for (HloInstruction* succ : best->control_successors()) {
|
||||||
update_pred_count(succ);
|
update_pred_count(succ);
|
||||||
}
|
}
|
||||||
// The unscheduled use count for a HloValue has changed to 1, so the
|
// The unscheduled use count for a buffer has changed to 1, so the
|
||||||
// priorities of some ready instructions may go up. We update them in the
|
// priorities of some ready instructions may go up. We update them in the
|
||||||
// ready queue, so that they can appear earlier.
|
// ready queue, so that they can appear earlier.
|
||||||
if (adjust_ready_queue) {
|
if (adjust_ready_queue) {
|
||||||
|
|
@ -361,24 +360,23 @@ class ListScheduler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CHECK_EQ(schedule.size(), computation_->instruction_count())
|
CHECK_EQ(schedule.size(), computation_->instruction_count());
|
||||||
<< "There could be a cycle in the HLO graph";
|
|
||||||
CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());
|
CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());
|
||||||
|
|
||||||
return schedule;
|
return schedule;
|
||||||
}
|
}
|
||||||
|
|
||||||
HloComputation* computation_;
|
HloComputation* computation_;
|
||||||
const HloAliasAnalysis& alias_analysis_;
|
const TuplePointsToAnalysis& points_to_analysis_;
|
||||||
const BufferValue::SizeFunction& size_function_;
|
const BufferValue::SizeFunction& size_function_;
|
||||||
|
|
||||||
// A map containing the HloValue that each instruction uses.
|
// A map containing the LogicalBuffers that each instruction uses.
|
||||||
absl::flat_hash_map<const HloInstruction*, std::vector<const HloValue*>>
|
absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
|
||||||
value_uses_;
|
buffer_uses_;
|
||||||
|
|
||||||
// A map containing the count of unscheduled HLOs which using a particular
|
// A map containing the count of unscheduled HLOs which using a particular
|
||||||
// HloValue.
|
// LogicalBuffer.
|
||||||
absl::flat_hash_map<const HloValue*, int64_t> unscheduled_use_count_;
|
absl::flat_hash_map<const LogicalBuffer*, int64_t> unscheduled_use_count_;
|
||||||
|
|
||||||
// Set of instructions which have been scheduled.
|
// Set of instructions which have been scheduled.
|
||||||
absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
|
absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
|
||||||
|
|
@ -398,7 +396,8 @@ int64_t SumBufferSizes(const HloInstruction* hlo, const HloValueSet& value_set,
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(
|
absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(
|
||||||
const HloModule* module, const HloAliasAnalysis& alias_analysis,
|
const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
|
||||||
|
const HloAliasAnalysis& alias_analysis,
|
||||||
const absl::flat_hash_set<absl::string_view>& execution_threads,
|
const absl::flat_hash_set<absl::string_view>& execution_threads,
|
||||||
int64_t* peak_memory) const {
|
int64_t* peak_memory) const {
|
||||||
HloSchedule schedule(module);
|
HloSchedule schedule(module);
|
||||||
|
|
@ -406,7 +405,7 @@ absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(
|
||||||
module->MakeComputationPostOrder(execution_threads)) {
|
module->MakeComputationPostOrder(execution_threads)) {
|
||||||
if (!computation->IsFusionComputation()) {
|
if (!computation->IsFusionComputation()) {
|
||||||
TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
|
TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
|
||||||
Run(computation, alias_analysis));
|
Run(computation, points_to_analysis, alias_analysis));
|
||||||
if (postprocessor_) {
|
if (postprocessor_) {
|
||||||
computation_sequence = postprocessor_(computation_sequence);
|
computation_sequence = postprocessor_(computation_sequence);
|
||||||
}
|
}
|
||||||
|
|
@ -423,6 +422,7 @@ absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(
|
||||||
|
|
||||||
absl::StatusOr<HloInstructionSequence> DFSMemoryScheduler::Run(
|
absl::StatusOr<HloInstructionSequence> DFSMemoryScheduler::Run(
|
||||||
HloComputation* computation,
|
HloComputation* computation,
|
||||||
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const HloAliasAnalysis& alias_analysis) const {
|
const HloAliasAnalysis& alias_analysis) const {
|
||||||
// These variables are a hack to prevent overflows.
|
// These variables are a hack to prevent overflows.
|
||||||
int64_t cumulative_total_size = 0;
|
int64_t cumulative_total_size = 0;
|
||||||
|
|
@ -501,6 +501,7 @@ absl::StatusOr<HloInstructionSequence> DFSMemoryScheduler::Run(
|
||||||
|
|
||||||
absl::StatusOr<HloInstructionSequence> BFScheduler::Run(
|
absl::StatusOr<HloInstructionSequence> BFScheduler::Run(
|
||||||
HloComputation* computation,
|
HloComputation* computation,
|
||||||
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const HloAliasAnalysis& alias_analysis) const {
|
const HloAliasAnalysis& alias_analysis) const {
|
||||||
// Index of HloInstruction in the `computation`.
|
// Index of HloInstruction in the `computation`.
|
||||||
absl::flat_hash_map<const HloInstruction*, int64_t> inst_index;
|
absl::flat_hash_map<const HloInstruction*, int64_t> inst_index;
|
||||||
|
|
@ -554,18 +555,21 @@ absl::StatusOr<HloInstructionSequence> BFScheduler::Run(
|
||||||
|
|
||||||
absl::StatusOr<HloInstructionSequence> ListMemoryScheduler::Run(
|
absl::StatusOr<HloInstructionSequence> ListMemoryScheduler::Run(
|
||||||
HloComputation* computation,
|
HloComputation* computation,
|
||||||
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const HloAliasAnalysis& alias_analysis) const {
|
const HloAliasAnalysis& alias_analysis) const {
|
||||||
return ListScheduler::Run(computation, alias_analysis, size_function_);
|
return ListScheduler::Run(computation, points_to_analysis, size_function_);
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<HloInstructionSequence> PostOrderScheduler::Run(
|
absl::StatusOr<HloInstructionSequence> PostOrderScheduler::Run(
|
||||||
HloComputation* computation,
|
HloComputation* computation,
|
||||||
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const HloAliasAnalysis& alias_analysis) const {
|
const HloAliasAnalysis& alias_analysis) const {
|
||||||
return HloInstructionSequence(computation->MakeInstructionPostOrder());
|
return HloInstructionSequence(computation->MakeInstructionPostOrder());
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<HloSchedule> DefaultMemoryScheduler::Run(
|
absl::StatusOr<HloSchedule> DefaultMemoryScheduler::Run(
|
||||||
const HloModule* module, const HloAliasAnalysis& alias_analysis,
|
const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
|
||||||
|
const HloAliasAnalysis& alias_analysis,
|
||||||
const absl::flat_hash_set<absl::string_view>& execution_threads,
|
const absl::flat_hash_set<absl::string_view>& execution_threads,
|
||||||
int64_t* peak_memory) const {
|
int64_t* peak_memory) const {
|
||||||
// We try a few schedulers and choose whichever returns a lower min-memory,
|
// We try a few schedulers and choose whichever returns a lower min-memory,
|
||||||
|
|
@ -577,22 +581,24 @@ absl::StatusOr<HloSchedule> DefaultMemoryScheduler::Run(
|
||||||
// List wins for most of our benchmarks; postorder-based schedulers win for
|
// List wins for most of our benchmarks; postorder-based schedulers win for
|
||||||
// some RNNs.
|
// some RNNs.
|
||||||
int64_t list_memory;
|
int64_t list_memory;
|
||||||
TF_ASSIGN_OR_RETURN(HloSchedule list_sequence,
|
TF_ASSIGN_OR_RETURN(
|
||||||
list_scheduler_.Run(module, alias_analysis,
|
HloSchedule list_sequence,
|
||||||
|
list_scheduler_.Run(module, points_to_analysis, alias_analysis,
|
||||||
execution_threads, &list_memory));
|
execution_threads, &list_memory));
|
||||||
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
|
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
|
||||||
|
|
||||||
int64_t dfs_memory;
|
int64_t dfs_memory;
|
||||||
TF_ASSIGN_OR_RETURN(HloSchedule dfs_sequence,
|
TF_ASSIGN_OR_RETURN(
|
||||||
dfs_scheduler_.Run(module, alias_analysis,
|
HloSchedule dfs_sequence,
|
||||||
|
dfs_scheduler_.Run(module, points_to_analysis, alias_analysis,
|
||||||
execution_threads, &dfs_memory));
|
execution_threads, &dfs_memory));
|
||||||
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
|
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
|
||||||
|
|
||||||
int64_t post_order_memory;
|
int64_t post_order_memory;
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
HloSchedule post_order_sequence,
|
HloSchedule post_order_sequence,
|
||||||
post_order_scheduler_.Run(module, alias_analysis, execution_threads,
|
post_order_scheduler_.Run(module, points_to_analysis, alias_analysis,
|
||||||
&post_order_memory));
|
execution_threads, &post_order_memory));
|
||||||
VLOG(2) << "Min-memory post order sequence: "
|
VLOG(2) << "Min-memory post order sequence: "
|
||||||
<< HumanReadableNumBytes(post_order_memory);
|
<< HumanReadableNumBytes(post_order_memory);
|
||||||
|
|
||||||
|
|
@ -624,12 +630,15 @@ absl::StatusOr<HloSchedule> ScheduleModule(
|
||||||
return absl::StrFormat("XlaMemoryScheduler:#module=%s,program_id=%d#",
|
return absl::StrFormat("XlaMemoryScheduler:#module=%s,program_id=%d#",
|
||||||
module->name(), module->unique_id());
|
module->name(), module->unique_id());
|
||||||
});
|
});
|
||||||
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
|
||||||
|
TuplePointsToAnalysis::Run(module));
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||||
HloAliasAnalysis::Run(module, algorithm.alias_info()));
|
HloAliasAnalysis::Run(module, algorithm.alias_info()));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
HloSchedule schedule,
|
HloSchedule schedule,
|
||||||
algorithm.Run(module, *alias_analysis, execution_threads, peak_memory));
|
algorithm.Run(module, *points_to_analysis, *alias_analysis,
|
||||||
|
execution_threads, peak_memory));
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(schedule.Verify());
|
TF_RETURN_IF_ERROR(schedule.Verify());
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "xla/hlo/analysis/alias_info.h"
|
#include "xla/hlo/analysis/alias_info.h"
|
||||||
#include "xla/hlo/analysis/hlo_alias_analysis.h"
|
#include "xla/hlo/analysis/hlo_alias_analysis.h"
|
||||||
|
#include "xla/hlo/analysis/tuple_points_to_analysis.h"
|
||||||
#include "xla/hlo/ir/hlo_instruction.h"
|
#include "xla/hlo/ir/hlo_instruction.h"
|
||||||
#include "xla/hlo/ir/hlo_module.h"
|
#include "xla/hlo/ir/hlo_module.h"
|
||||||
#include "xla/hlo/ir/hlo_schedule.h"
|
#include "xla/hlo/ir/hlo_schedule.h"
|
||||||
|
|
@ -38,13 +39,16 @@ namespace xla {
|
||||||
// 'module' given a points-to analysis result that describes buffer aliasing.
|
// 'module' given a points-to analysis result that describes buffer aliasing.
|
||||||
// peak_memory (may be nullptr) is set to the peak memory of the resulting
|
// peak_memory (may be nullptr) is set to the peak memory of the resulting
|
||||||
// schedule according to the HeapSimulator.
|
// schedule according to the HeapSimulator.
|
||||||
|
//
|
||||||
|
// TODO(yunxing): Cleanup usage of TuplePointsToAnalysis.
|
||||||
class ModuleSchedulerAlgorithm {
|
class ModuleSchedulerAlgorithm {
|
||||||
public:
|
public:
|
||||||
explicit ModuleSchedulerAlgorithm(const AliasInfo* alias_info)
|
explicit ModuleSchedulerAlgorithm(const AliasInfo* alias_info)
|
||||||
: alias_info_(alias_info) {}
|
: alias_info_(alias_info) {}
|
||||||
virtual ~ModuleSchedulerAlgorithm() = default;
|
virtual ~ModuleSchedulerAlgorithm() = default;
|
||||||
virtual absl::StatusOr<HloSchedule> Run(
|
virtual absl::StatusOr<HloSchedule> Run(
|
||||||
const HloModule* module, const HloAliasAnalysis& alias_analysis,
|
const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
|
||||||
|
const HloAliasAnalysis& alias_analysis,
|
||||||
const absl::flat_hash_set<absl::string_view>& execution_threads,
|
const absl::flat_hash_set<absl::string_view>& execution_threads,
|
||||||
int64_t* peak_memory) const = 0;
|
int64_t* peak_memory) const = 0;
|
||||||
|
|
||||||
|
|
@ -68,9 +72,11 @@ class ComputationSchedulerAlgorithm : public ModuleSchedulerAlgorithm {
|
||||||
public:
|
public:
|
||||||
virtual absl::StatusOr<HloInstructionSequence> Run(
|
virtual absl::StatusOr<HloInstructionSequence> Run(
|
||||||
HloComputation* computation,
|
HloComputation* computation,
|
||||||
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const HloAliasAnalysis& alias_analysis) const = 0;
|
const HloAliasAnalysis& alias_analysis) const = 0;
|
||||||
absl::StatusOr<HloSchedule> Run(
|
absl::StatusOr<HloSchedule> Run(
|
||||||
const HloModule* module, const HloAliasAnalysis& alias_analysis,
|
const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
|
||||||
|
const HloAliasAnalysis& alias_analysis,
|
||||||
const absl::flat_hash_set<absl::string_view>& execution_threads,
|
const absl::flat_hash_set<absl::string_view>& execution_threads,
|
||||||
int64_t* peak_memory) const override;
|
int64_t* peak_memory) const override;
|
||||||
|
|
||||||
|
|
@ -99,6 +105,7 @@ class ListMemoryScheduler : public ComputationSchedulerAlgorithm {
|
||||||
using ModuleSchedulerAlgorithm::Run;
|
using ModuleSchedulerAlgorithm::Run;
|
||||||
absl::StatusOr<HloInstructionSequence> Run(
|
absl::StatusOr<HloInstructionSequence> Run(
|
||||||
HloComputation* computation,
|
HloComputation* computation,
|
||||||
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const HloAliasAnalysis& alias_analysis) const override;
|
const HloAliasAnalysis& alias_analysis) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -113,6 +120,7 @@ class DFSMemoryScheduler : public ComputationSchedulerAlgorithm {
|
||||||
using ModuleSchedulerAlgorithm::Run;
|
using ModuleSchedulerAlgorithm::Run;
|
||||||
absl::StatusOr<HloInstructionSequence> Run(
|
absl::StatusOr<HloInstructionSequence> Run(
|
||||||
HloComputation* computation,
|
HloComputation* computation,
|
||||||
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const HloAliasAnalysis& alias_analysis) const override;
|
const HloAliasAnalysis& alias_analysis) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -135,6 +143,7 @@ class BFScheduler : public ComputationSchedulerAlgorithm {
|
||||||
std::move(postprocessor)) {}
|
std::move(postprocessor)) {}
|
||||||
absl::StatusOr<HloInstructionSequence> Run(
|
absl::StatusOr<HloInstructionSequence> Run(
|
||||||
HloComputation* computation,
|
HloComputation* computation,
|
||||||
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const HloAliasAnalysis& alias_analysis) const override;
|
const HloAliasAnalysis& alias_analysis) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -149,6 +158,7 @@ class PostOrderScheduler : public ComputationSchedulerAlgorithm {
|
||||||
using ModuleSchedulerAlgorithm::Run;
|
using ModuleSchedulerAlgorithm::Run;
|
||||||
absl::StatusOr<HloInstructionSequence> Run(
|
absl::StatusOr<HloInstructionSequence> Run(
|
||||||
HloComputation* computation,
|
HloComputation* computation,
|
||||||
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const HloAliasAnalysis& alias_analysis) const override;
|
const HloAliasAnalysis& alias_analysis) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -166,7 +176,8 @@ class DefaultMemoryScheduler : public ModuleSchedulerAlgorithm {
|
||||||
dfs_scheduler_(alias_info, size_function, postprocessor),
|
dfs_scheduler_(alias_info, size_function, postprocessor),
|
||||||
post_order_scheduler_(alias_info, size_function, postprocessor) {}
|
post_order_scheduler_(alias_info, size_function, postprocessor) {}
|
||||||
absl::StatusOr<HloSchedule> Run(
|
absl::StatusOr<HloSchedule> Run(
|
||||||
const HloModule* module, const HloAliasAnalysis& alias_analysis,
|
const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
|
||||||
|
const HloAliasAnalysis& alias_analysis,
|
||||||
const absl::flat_hash_set<absl::string_view>& execution_threads,
|
const absl::flat_hash_set<absl::string_view>& execution_threads,
|
||||||
int64_t* peak_memory) const override;
|
int64_t* peak_memory) const override;
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user