Mirror of https://github.com/zebrajr/tensorflow.git (synced 2025-12-06 00:19:58 +01:00)

parent d991543c54
commit 22032a9edb
@@ -296,11 +296,12 @@ cc_library(
         "//xla:util",
         "//xla/hlo/analysis:alias_info",
         "//xla/hlo/analysis:hlo_alias_analysis",
-        "//xla/hlo/analysis:hlo_dataflow_analysis",
+        "//xla/hlo/analysis:tuple_points_to_analysis",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/pass:hlo_pass",
         "//xla/service:buffer_value",
         "//xla/service:hlo_value",
+        "//xla/service:logical_buffer",
         "//xla/service/heap_simulator",
         "//xla/tsl/platform:errors",
         "//xla/tsl/platform:logging",
@@ -36,7 +36,7 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "xla/hlo/analysis/alias_info.h"
 #include "xla/hlo/analysis/hlo_alias_analysis.h"
-#include "xla/hlo/analysis/hlo_dataflow_analysis.h"
+#include "xla/hlo/analysis/tuple_points_to_analysis.h"
 #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -45,6 +45,7 @@ limitations under the License.
 #include "xla/service/buffer_value.h"
 #include "xla/service/heap_simulator/heap_simulator.h"
 #include "xla/service/hlo_value.h"
+#include "xla/service/logical_buffer.h"
 #include "xla/shape_util.h"
 #include "xla/tsl/platform/errors.h"
 #include "xla/tsl/platform/logging.h"
@@ -92,9 +93,10 @@ class ListScheduler {
   // Construct and return a memory-minimizing sequence of HLO instructions
   // containing the given HLO computation.
   static absl::StatusOr<HloInstructionSequence> Run(
-      HloComputation* computation, const HloAliasAnalysis& alias_analysis,
+      HloComputation* computation,
+      const TuplePointsToAnalysis& points_to_analysis,
       const BufferValue::SizeFunction& size_function) {
-    ListScheduler scheduler(computation, alias_analysis, size_function);
+    ListScheduler scheduler(computation, points_to_analysis, size_function);
     return scheduler.CreateSchedule();
   }
 
@@ -114,57 +116,56 @@ class ListScheduler {
   using Priority = std::pair<int64_t, int64_t>;
 
   ListScheduler(HloComputation* computation,
-                const HloAliasAnalysis& alias_analysis,
+                const TuplePointsToAnalysis& points_to_analysis,
                 const BufferValue::SizeFunction& size_function)
       : computation_(computation),
-        alias_analysis_(alias_analysis),
+        points_to_analysis_(points_to_analysis),
         size_function_(size_function) {
-    // Create a map containing the HloValue uses for each HLO instruction. An
-    // HLO instruction "uses" a HloValue if the HloValue is in an operand of the
-    // instruction as indicated by HloDataflowAnalysis.
-    const HloDataflowAnalysis& dataflow_analysis =
-        alias_analysis.dataflow_analysis();
-    absl::flat_hash_set<const HloValue*> computation_values;
+    // Create a map containing the LogicalBuffer uses for each HLO
+    // instruction. An HLO instruction "uses" a LogicalBuffer if the
+    // LogicalBuffer is in an operand of the instruction as indicated by
+    // points-to analysis.
     for (auto* instruction : computation->instructions()) {
-      dataflow_analysis.GetInstructionValueSet(instruction)
-          .ForEachElement(
-              [&](const ShapeIndex& /*index*/, const HloValueSet& value_set) {
-                computation_values.insert(value_set.values().begin(),
-                                          value_set.values().end());
-              });
-      absl::flat_hash_set<const HloValue*> instr_uses;
+      absl::flat_hash_set<const LogicalBuffer*> instr_uses;
       for (auto* operand : instruction->operands()) {
-        dataflow_analysis.GetInstructionValueSet(operand).ForEachElement(
-            [&](const ShapeIndex& /*index*/, const HloValueSet& value_set) {
-              instr_uses.insert(value_set.values().begin(),
-                                value_set.values().end());
+        points_to_analysis.GetPointsToSet(operand).ForEachElement(
+            [&](const ShapeIndex& /*index*/,
+                const PointsToSet::BufferList& buffers) {
+              instr_uses.insert(buffers.begin(), buffers.end());
             });
       }
-      value_uses_[instruction] =
-          std::vector<const HloValue*>(instr_uses.begin(), instr_uses.end());
+      buffer_uses_[instruction] = std::vector<const LogicalBuffer*>(
+          instr_uses.begin(), instr_uses.end());
     }
 
     // Create map containing the number of unscheduled uses (hlo instructions)
-    // of each HloValue.
-    unscheduled_use_count_.reserve(computation_values.size());
-    for (const HloValue* value : computation_values) {
-      // HloValues live out of the computation have an implicit use at the end
-      // of the computation. Therefore we initialize `unscheduled_use_count_` to
-      // 1 in such cases.
-      unscheduled_use_count_[value] =
-          alias_analysis.ValueLivesOut(*value) ? 1 : 0;
+    // of each logical buffer.
+    unscheduled_use_count_.reserve(points_to_analysis.num_logical_buffers());
+    for (auto* instruction : computation->instructions()) {
+      for (auto* buffer :
+           points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
+        unscheduled_use_count_[buffer] = 0;
+      }
     }
     for (auto* instruction : computation->instructions()) {
-      for (const HloValue* value : value_uses_.at(instruction)) {
-        ++unscheduled_use_count_[value];
+      for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
+        ++unscheduled_use_count_[buffer];
       }
     }
+
+    // Buffers live out of the computation have an implicit use at the end of
+    // the computation.
+    for (const LogicalBuffer* live_out_buffer :
+         points_to_analysis.GetPointsToSet(computation->root_instruction())
+             .CreateFlattenedSet()) {
+      ++unscheduled_use_count_[live_out_buffer];
+    }
   }
 
-  // Returns whether the memory used by the given HloValue should be ignored by
+  // Returns whether the memory used by the given buffer should be ignored by
   // the scheduling heuristic.
-  static bool IgnoreValue(const HloValue& value) {
-    return IgnoreInstruction(*value.instruction());
+  static bool IgnoreBuffer(const LogicalBuffer& buffer) {
+    return IgnoreInstruction(*buffer.instruction());
   }
 
   // An entry in the worklist used by CreateSchedule. Corresponds to one
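The new constructor collects, for every instruction, the de-duplicated set of LogicalBuffers reachable from its operands via TuplePointsToAnalysis, then counts how many still-unscheduled instructions use each buffer, with buffers that flow out of the computation getting one extra implicit use. Below is a minimal standalone sketch of that counting idea; the Buffer and Instruction types and every name in it are hypothetical stand-ins built on plain standard-library containers, not XLA's API.

#include <cstdio>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-ins for HloInstruction / LogicalBuffer.
struct Buffer { std::string name; };
struct Instruction {
  std::string name;
  std::vector<const Buffer*> operands;  // buffers this instruction reads
  std::vector<const Buffer*> defines;   // buffers this instruction produces
};

int main() {
  Buffer a{"a"}, b{"b"}, c{"c"};
  Instruction p{"param", {}, {&a}};
  Instruction add{"add", {&a, &a}, {&b}};  // reads `a` twice, counted once
  Instruction root{"root", {&a, &b}, {&c}};
  std::vector<const Instruction*> computation = {&p, &add, &root};
  const Buffer* live_out = &c;  // the root's output escapes the computation

  // buffer_uses: de-duplicated set of buffers each instruction reads.
  std::unordered_map<const Instruction*, std::vector<const Buffer*>> buffer_uses;
  for (const Instruction* instr : computation) {
    std::unordered_set<const Buffer*> uses(instr->operands.begin(),
                                           instr->operands.end());
    buffer_uses[instr].assign(uses.begin(), uses.end());
  }

  // unscheduled_use_count: how many not-yet-scheduled instructions use a buffer.
  std::unordered_map<const Buffer*, long> unscheduled_use_count;
  for (const Instruction* instr : computation) {
    for (const Buffer* buf : instr->defines) unscheduled_use_count[buf] = 0;
  }
  for (const Instruction* instr : computation) {
    for (const Buffer* buf : buffer_uses[instr]) ++unscheduled_use_count[buf];
  }
  // A buffer that lives out of the computation gets one extra implicit use.
  ++unscheduled_use_count[live_out];

  for (const auto& [buf, count] : unscheduled_use_count) {
    std::printf("%s: %ld unscheduled uses\n", buf->name.c_str(), count);
  }
}

Running it prints one line per buffer; `c` ends up with one use even though no instruction reads it, because it is the live-out buffer.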
@@ -180,7 +181,7 @@ class ListScheduler {
     // U is the number of uses of B that have not yet been scheduled. This pair
     // is a pointer into the unscheduled_use_count_ map, so it gets updated for
     // free when we update counts in the map.
-    std::vector<const std::pair<const HloValue* const, int64_t>*>
+    std::vector<const std::pair<const LogicalBuffer* const, int64_t>*>
         used_buffer_unscheduled_use_counts;
   };
 
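The comment above describes the central trick of the ready-list entry: it stores raw pointers to the (buffer, count) pairs inside unscheduled_use_count_, so every decrement made in the map is immediately visible to all cached entries without rebuilding them. A small sketch of that idea follows, using a std::unordered_map keyed by hypothetical buffer names (element addresses in std::unordered_map stay valid as long as nothing is erased); it is an illustration, not XLA's data structure.

#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  // The count map is fully populated up front and only updated afterwards,
  // so pointers taken into it remain valid.
  std::unordered_map<std::string, int64_t> unscheduled_use_count = {
      {"buf_a", 2}, {"buf_b", 1}};

  // A "ready list entry" caches pointers into the map instead of copies.
  std::vector<const std::pair<const std::string, int64_t>*> cached;
  cached.push_back(&*unscheduled_use_count.find("buf_a"));
  cached.push_back(&*unscheduled_use_count.find("buf_b"));

  // Scheduling an instruction decrements counts in the map...
  --unscheduled_use_count["buf_a"];

  // ...and the cached entry sees the new value without being rebuilt.
  assert(cached[0]->second == 1);
  assert(cached[1]->second == 1);
  return 0;
}

In the constructor shown earlier, the map is fully populated before any entry pointers are taken, which is what makes holding such pointers workable.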
@@ -190,20 +191,18 @@ class ListScheduler {
     entry.instruction = instruction;
 
     entry.bytes_defined = 0;
-    HloValueSet value_set =
-        alias_analysis_.dataflow_analysis().GetFlattenedValueSet(instruction);
-    for (const HloValue* value : value_set.values()) {
-      if (!IgnoreInstruction(*instruction) &&
-          value->instruction() == instruction) {
-        entry.bytes_defined += size_function_(*value);
+    for (auto* buffer :
+         points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
+      if (!IgnoreBuffer(*buffer)) {
+        entry.bytes_defined += size_function_(*buffer);
       }
     }
 
-    for (auto* value : value_uses_.at(instruction)) {
-      if (IgnoreValue(*value)) {
+    for (auto* buffer : buffer_uses_.at(instruction)) {
+      if (IgnoreBuffer(*buffer)) {
        continue;
      }
-      auto unscheduled_use_count_it = unscheduled_use_count_.find(value);
+      auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
       CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
       entry.used_buffer_unscheduled_use_counts.push_back(
           &*unscheduled_use_count_it);
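This function caches two things per candidate instruction: the bytes of the buffers it would newly define (minus ignored ones) and a pointer to the live use count of every buffer it reads. Together these are exactly what a greedy "bytes freed if scheduled" priority needs: a used buffer whose remaining count is 1 would die here, while newly defined bytes are a cost. The toy function below illustrates that kind of priority; it is an illustration of the idea only, not necessarily the exact formula used elsewhere in this file.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Each pair is (size of a buffer this instruction reads,
//               number of unscheduled instructions still using it).
int64_t BytesFreedIfScheduled(
    int64_t bytes_defined,
    const std::vector<std::pair<int64_t, int64_t>>& used_buffers) {
  int64_t freed = 0;
  for (const auto& [size, remaining_uses] : used_buffers) {
    // If this instruction is the last unscheduled user, the buffer dies here.
    if (remaining_uses == 1) freed += size;
  }
  // Newly defined bytes count against the instruction.
  return freed - bytes_defined;
}

int main() {
  // Defines 16 bytes; frees a 64-byte buffer (last use) but not a 32-byte one.
  std::printf("%lld\n", static_cast<long long>(
                            BytesFreedIfScheduled(16, {{64, 1}, {32, 3}})));
  return 0;
}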
@@ -309,9 +308,9 @@ class ListScheduler {
     scheduled_instructions_.insert(best);
 
     bool adjust_ready_queue = false;
-    // Update the unscheduled uses of the HloValues.
-    for (const HloValue* value : value_uses_.at(best)) {
-      int64_t& count = unscheduled_use_count_[value];
+    // Update the unscheduled uses of the logical buffers.
+    for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
+      int64_t& count = unscheduled_use_count_[buffer];
       CHECK_GT(count, 0);
       --count;
       if (count == 1) {
@@ -335,7 +334,7 @@ class ListScheduler {
     for (HloInstruction* succ : best->control_successors()) {
       update_pred_count(succ);
     }
-    // The unscheduled use count for a HloValue has changed to 1, so the
+    // The unscheduled use count for a buffer has changed to 1, so the
     // priorities of some ready instructions may go up. We update them in the
     // ready queue, so that they can appear earlier.
     if (adjust_ready_queue) {
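When a use count tracked above drops to 1, an instruction already sitting in the ready queue may now free that buffer, so its priority improves and its queue entry has to move. With an ordered container as the queue, the usual re-queue step is to erase the stale entry and insert it again under the new priority. The sketch below shows that step with a std::multimap and made-up instruction names; the actual queue type is not part of this diff.

#include <cstdio>
#include <map>
#include <string>
#include <utility>

int main() {
  // Ready queue ordered by priority; higher priority is scheduled sooner.
  std::multimap<int, std::string> ready_queue;
  auto it = ready_queue.emplace(10, "mul");  // remember where "mul" sits
  ready_queue.emplace(25, "add");

  // A use count just dropped to 1, so "mul" would now free a buffer:
  // erase the stale entry and re-insert it with its improved priority.
  std::string name = it->second;
  ready_queue.erase(it);
  it = ready_queue.emplace(30, std::move(name));

  // The highest-priority entry is picked first.
  std::printf("next: %s (priority %d)\n", ready_queue.rbegin()->second.c_str(),
              ready_queue.rbegin()->first);
  return 0;
}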
@@ -361,24 +360,23 @@ class ListScheduler {
         }
       }
     }
-    CHECK_EQ(schedule.size(), computation_->instruction_count())
-        << "There could be a cycle in the HLO graph";
+    CHECK_EQ(schedule.size(), computation_->instruction_count());
     CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());
 
     return schedule;
   }
 
   HloComputation* computation_;
-  const HloAliasAnalysis& alias_analysis_;
+  const TuplePointsToAnalysis& points_to_analysis_;
   const BufferValue::SizeFunction& size_function_;
 
-  // A map containing the HloValue that each instruction uses.
-  absl::flat_hash_map<const HloInstruction*, std::vector<const HloValue*>>
-      value_uses_;
+  // A map containing the LogicalBuffers that each instruction uses.
+  absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
+      buffer_uses_;
 
   // A map containing the count of unscheduled HLOs which using a particular
-  // HloValue.
-  absl::flat_hash_map<const HloValue*, int64_t> unscheduled_use_count_;
+  // LogicalBuffer.
+  absl::flat_hash_map<const LogicalBuffer*, int64_t> unscheduled_use_count_;
 
   // Set of instructions which have been scheduled.
   absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
@@ -398,7 +396,8 @@ int64_t SumBufferSizes(const HloInstruction* hlo, const HloValueSet& value_set,
 }  // namespace
 
 absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(
-    const HloModule* module, const HloAliasAnalysis& alias_analysis,
+    const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+    const HloAliasAnalysis& alias_analysis,
     const absl::flat_hash_set<absl::string_view>& execution_threads,
     int64_t* peak_memory) const {
   HloSchedule schedule(module);
@@ -406,7 +405,7 @@ absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(
        module->MakeComputationPostOrder(execution_threads)) {
     if (!computation->IsFusionComputation()) {
       TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
-                          Run(computation, alias_analysis));
+                          Run(computation, points_to_analysis, alias_analysis));
       if (postprocessor_) {
         computation_sequence = postprocessor_(computation_sequence);
       }
@@ -423,6 +422,7 @@ absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(
 
 absl::StatusOr<HloInstructionSequence> DFSMemoryScheduler::Run(
     HloComputation* computation,
+    const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis) const {
   // These variables are a hack to prevent overflows.
   int64_t cumulative_total_size = 0;
@@ -501,6 +501,7 @@ absl::StatusOr<HloInstructionSequence> DFSMemoryScheduler::Run(
 
 absl::StatusOr<HloInstructionSequence> BFScheduler::Run(
     HloComputation* computation,
+    const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis) const {
   // Index of HloInstruction in the `computation`.
   absl::flat_hash_map<const HloInstruction*, int64_t> inst_index;
@@ -554,18 +555,21 @@ absl::StatusOr<HloInstructionSequence> BFScheduler::Run(
 
 absl::StatusOr<HloInstructionSequence> ListMemoryScheduler::Run(
     HloComputation* computation,
+    const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis) const {
-  return ListScheduler::Run(computation, alias_analysis, size_function_);
+  return ListScheduler::Run(computation, points_to_analysis, size_function_);
 }
 
 absl::StatusOr<HloInstructionSequence> PostOrderScheduler::Run(
     HloComputation* computation,
+    const TuplePointsToAnalysis& points_to_analysis,
     const HloAliasAnalysis& alias_analysis) const {
   return HloInstructionSequence(computation->MakeInstructionPostOrder());
 }
 
 absl::StatusOr<HloSchedule> DefaultMemoryScheduler::Run(
-    const HloModule* module, const HloAliasAnalysis& alias_analysis,
+    const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+    const HloAliasAnalysis& alias_analysis,
     const absl::flat_hash_set<absl::string_view>& execution_threads,
     int64_t* peak_memory) const {
   // We try a few schedulers and choose whichever returns a lower min-memory,
@@ -577,22 +581,24 @@ absl::StatusOr<HloSchedule> DefaultMemoryScheduler::Run(
   // List wins for most of our benchmarks; postorder-based schedulers win for
   // some RNNs.
   int64_t list_memory;
-  TF_ASSIGN_OR_RETURN(HloSchedule list_sequence,
-                      list_scheduler_.Run(module, alias_analysis,
-                                          execution_threads, &list_memory));
+  TF_ASSIGN_OR_RETURN(
+      HloSchedule list_sequence,
+      list_scheduler_.Run(module, points_to_analysis, alias_analysis,
+                          execution_threads, &list_memory));
   VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
 
   int64_t dfs_memory;
-  TF_ASSIGN_OR_RETURN(HloSchedule dfs_sequence,
-                      dfs_scheduler_.Run(module, alias_analysis,
-                                         execution_threads, &dfs_memory));
+  TF_ASSIGN_OR_RETURN(
+      HloSchedule dfs_sequence,
+      dfs_scheduler_.Run(module, points_to_analysis, alias_analysis,
+                         execution_threads, &dfs_memory));
   VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
 
   int64_t post_order_memory;
   TF_ASSIGN_OR_RETURN(
       HloSchedule post_order_sequence,
-      post_order_scheduler_.Run(module, alias_analysis, execution_threads,
-                                &post_order_memory));
+      post_order_scheduler_.Run(module, points_to_analysis, alias_analysis,
+                                execution_threads, &post_order_memory));
   VLOG(2) << "Min-memory post order sequence: "
           << HumanReadableNumBytes(post_order_memory);
 
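DefaultMemoryScheduler::Run keeps its original structure: it produces a list schedule, a DFS schedule, and a post-order schedule, takes the peak memory reported for each, and returns whichever candidate needs the least. A stripped-down sketch of that final selection step follows, with invented peak-memory numbers standing in for the simulator results.

#include <cstdint>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main() {
  // (scheduler name, simulated peak memory in bytes): illustrative values only.
  std::vector<std::pair<std::string, int64_t>> candidates = {
      {"list", 1 << 20}, {"dfs", 3 << 20}, {"post_order", 2 << 20}};

  // Keep whichever schedule needs the least peak memory.
  const auto* best = &candidates.front();
  for (const auto& candidate : candidates) {
    if (candidate.second < best->second) best = &candidate;
  }
  std::printf("chose %s with peak %lld bytes\n", best->first.c_str(),
              static_cast<long long>(best->second));
  return 0;
}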
@@ -624,12 +630,15 @@ absl::StatusOr<HloSchedule> ScheduleModule(
     return absl::StrFormat("XlaMemoryScheduler:#module=%s,program_id=%d#",
                            module->name(), module->unique_id());
   });
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
+                      TuplePointsToAnalysis::Run(module));
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                       HloAliasAnalysis::Run(module, algorithm.alias_info()));
 
   TF_ASSIGN_OR_RETURN(
       HloSchedule schedule,
-      algorithm.Run(module, *alias_analysis, execution_threads, peak_memory));
+      algorithm.Run(module, *points_to_analysis, *alias_analysis,
+                    execution_threads, peak_memory));
 
   TF_RETURN_IF_ERROR(schedule.Verify());
 
@@ -26,6 +26,7 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "xla/hlo/analysis/alias_info.h"
 #include "xla/hlo/analysis/hlo_alias_analysis.h"
+#include "xla/hlo/analysis/tuple_points_to_analysis.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_schedule.h"
@@ -38,13 +39,16 @@ namespace xla {
 // 'module' given a points-to analysis result that describes buffer aliasing.
 // peak_memory (may be nullptr) is set to the peak memory of the resulting
 // schedule according to the HeapSimulator.
+//
+// TODO(yunxing): Cleanup usage of TuplePointsToAnalysis.
 class ModuleSchedulerAlgorithm {
  public:
   explicit ModuleSchedulerAlgorithm(const AliasInfo* alias_info)
       : alias_info_(alias_info) {}
   virtual ~ModuleSchedulerAlgorithm() = default;
   virtual absl::StatusOr<HloSchedule> Run(
-      const HloModule* module, const HloAliasAnalysis& alias_analysis,
+      const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+      const HloAliasAnalysis& alias_analysis,
       const absl::flat_hash_set<absl::string_view>& execution_threads,
       int64_t* peak_memory) const = 0;
 
@@ -68,9 +72,11 @@ class ComputationSchedulerAlgorithm : public ModuleSchedulerAlgorithm {
  public:
   virtual absl::StatusOr<HloInstructionSequence> Run(
       HloComputation* computation,
+      const TuplePointsToAnalysis& points_to_analysis,
       const HloAliasAnalysis& alias_analysis) const = 0;
   absl::StatusOr<HloSchedule> Run(
-      const HloModule* module, const HloAliasAnalysis& alias_analysis,
+      const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+      const HloAliasAnalysis& alias_analysis,
       const absl::flat_hash_set<absl::string_view>& execution_threads,
       int64_t* peak_memory) const override;
 
@@ -99,6 +105,7 @@ class ListMemoryScheduler : public ComputationSchedulerAlgorithm {
   using ModuleSchedulerAlgorithm::Run;
   absl::StatusOr<HloInstructionSequence> Run(
       HloComputation* computation,
+      const TuplePointsToAnalysis& points_to_analysis,
       const HloAliasAnalysis& alias_analysis) const override;
 };
 
@@ -113,6 +120,7 @@ class DFSMemoryScheduler : public ComputationSchedulerAlgorithm {
   using ModuleSchedulerAlgorithm::Run;
   absl::StatusOr<HloInstructionSequence> Run(
       HloComputation* computation,
+      const TuplePointsToAnalysis& points_to_analysis,
       const HloAliasAnalysis& alias_analysis) const override;
 };
 
@@ -135,6 +143,7 @@ class BFScheduler : public ComputationSchedulerAlgorithm {
                                       std::move(postprocessor)) {}
   absl::StatusOr<HloInstructionSequence> Run(
       HloComputation* computation,
+      const TuplePointsToAnalysis& points_to_analysis,
       const HloAliasAnalysis& alias_analysis) const override;
 };
 
@@ -149,6 +158,7 @@ class PostOrderScheduler : public ComputationSchedulerAlgorithm {
   using ModuleSchedulerAlgorithm::Run;
   absl::StatusOr<HloInstructionSequence> Run(
       HloComputation* computation,
+      const TuplePointsToAnalysis& points_to_analysis,
       const HloAliasAnalysis& alias_analysis) const override;
 };
 
@@ -166,7 +176,8 @@ class DefaultMemoryScheduler : public ModuleSchedulerAlgorithm {
         dfs_scheduler_(alias_info, size_function, postprocessor),
         post_order_scheduler_(alias_info, size_function, postprocessor) {}
   absl::StatusOr<HloSchedule> Run(
-      const HloModule* module, const HloAliasAnalysis& alias_analysis,
+      const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+      const HloAliasAnalysis& alias_analysis,
       const absl::flat_hash_set<absl::string_view>& execution_threads,
       int64_t* peak_memory) const override;
 