Roll back a broken CL

Reverts 5bddc28fd5

PiperOrigin-RevId: 800789392
A. Unique TensorFlower 2025-08-29 01:07:08 -07:00 committed by TensorFlower Gardener
parent d991543c54
commit 22032a9edb
3 changed files with 97 additions and 76 deletions
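
At a glance, the revert threads TuplePointsToAnalysis back through every scheduler entry point, alongside HloAliasAnalysis. A condensed sketch of the two restored virtual Run signatures, copied from the header hunks further down (these are not the complete class definitions):

class ModuleSchedulerAlgorithm {
 public:
  // Schedules a whole module; peak_memory (may be nullptr) receives the peak
  // memory of the resulting schedule according to the HeapSimulator.
  virtual absl::StatusOr<HloSchedule> Run(
      const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
      const HloAliasAnalysis& alias_analysis,
      const absl::flat_hash_set<absl::string_view>& execution_threads,
      int64_t* peak_memory) const = 0;
};

class ComputationSchedulerAlgorithm : public ModuleSchedulerAlgorithm {
 public:
  // Per-computation entry point; the module-level Run is implemented on top
  // of it.
  virtual absl::StatusOr<HloInstructionSequence> Run(
      HloComputation* computation,
      const TuplePointsToAnalysis& points_to_analysis,
      const HloAliasAnalysis& alias_analysis) const = 0;
};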


@@ -296,11 +296,12 @@ cc_library(
"//xla:util",
"//xla/hlo/analysis:alias_info",
"//xla/hlo/analysis:hlo_alias_analysis",
"//xla/hlo/analysis:hlo_dataflow_analysis",
"//xla/hlo/analysis:tuple_points_to_analysis",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/service:buffer_value",
"//xla/service:hlo_value",
"//xla/service:logical_buffer",
"//xla/service/heap_simulator",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",


@@ -36,7 +36,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "xla/hlo/analysis/alias_info.h"
#include "xla/hlo/analysis/hlo_alias_analysis.h"
#include "xla/hlo/analysis/hlo_dataflow_analysis.h"
#include "xla/hlo/analysis/tuple_points_to_analysis.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
@@ -45,6 +45,7 @@ limitations under the License.
#include "xla/service/buffer_value.h"
#include "xla/service/heap_simulator/heap_simulator.h"
#include "xla/service/hlo_value.h"
#include "xla/service/logical_buffer.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
@@ -92,9 +93,10 @@ class ListScheduler {
// Construct and return a memory-minimizing sequence of HLO instructions
// containing the given HLO computation.
static absl::StatusOr<HloInstructionSequence> Run(
HloComputation* computation, const HloAliasAnalysis& alias_analysis,
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_function) {
ListScheduler scheduler(computation, alias_analysis, size_function);
ListScheduler scheduler(computation, points_to_analysis, size_function);
return scheduler.CreateSchedule();
}
@@ -114,57 +116,56 @@ class ListScheduler {
using Priority = std::pair<int64_t, int64_t>;
ListScheduler(HloComputation* computation,
const HloAliasAnalysis& alias_analysis,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_function)
: computation_(computation),
alias_analysis_(alias_analysis),
points_to_analysis_(points_to_analysis),
size_function_(size_function) {
// Create a map containing the HloValue uses for each HLO instruction. An
// HLO instruction "uses" a HloValue if the HloValue is in an operand of the
// instruction as indicated by HloDataflowAnalysis.
const HloDataflowAnalysis& dataflow_analysis =
alias_analysis.dataflow_analysis();
absl::flat_hash_set<const HloValue*> computation_values;
// Create a map containing the LogicalBuffer uses for each HLO
// instruction. An HLO instruction "uses" a LogicalBuffer if the
// LogicalBuffer is in an operand of the instruction as indicated by
// points-to analysis.
for (auto* instruction : computation->instructions()) {
dataflow_analysis.GetInstructionValueSet(instruction)
.ForEachElement(
[&](const ShapeIndex& /*index*/, const HloValueSet& value_set) {
computation_values.insert(value_set.values().begin(),
value_set.values().end());
});
absl::flat_hash_set<const HloValue*> instr_uses;
absl::flat_hash_set<const LogicalBuffer*> instr_uses;
for (auto* operand : instruction->operands()) {
dataflow_analysis.GetInstructionValueSet(operand).ForEachElement(
[&](const ShapeIndex& /*index*/, const HloValueSet& value_set) {
instr_uses.insert(value_set.values().begin(),
value_set.values().end());
points_to_analysis.GetPointsToSet(operand).ForEachElement(
[&](const ShapeIndex& /*index*/,
const PointsToSet::BufferList& buffers) {
instr_uses.insert(buffers.begin(), buffers.end());
});
}
value_uses_[instruction] =
std::vector<const HloValue*>(instr_uses.begin(), instr_uses.end());
buffer_uses_[instruction] = std::vector<const LogicalBuffer*>(
instr_uses.begin(), instr_uses.end());
}
// Create map containing the number of unscheduled uses (hlo instructions)
// of each HloValue.
unscheduled_use_count_.reserve(computation_values.size());
for (const HloValue* value : computation_values) {
// HloValues live out of the computation have an implicit use at the end
// of the computation. Therefore we initialize `unscheduled_use_count_` to
// 1 in such cases.
unscheduled_use_count_[value] =
alias_analysis.ValueLivesOut(*value) ? 1 : 0;
// of each logical buffer.
unscheduled_use_count_.reserve(points_to_analysis.num_logical_buffers());
for (auto* instruction : computation->instructions()) {
for (auto* buffer :
points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
unscheduled_use_count_[buffer] = 0;
}
}
for (auto* instruction : computation->instructions()) {
for (const HloValue* value : value_uses_.at(instruction)) {
++unscheduled_use_count_[value];
}
for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
++unscheduled_use_count_[buffer];
}
}
// Returns whether the memory used by the given HloValue should be ignored by
// Buffers live out of the computation have an implicit use at the end of
// the computation.
for (const LogicalBuffer* live_out_buffer :
points_to_analysis.GetPointsToSet(computation->root_instruction())
.CreateFlattenedSet()) {
++unscheduled_use_count_[live_out_buffer];
}
}
// Returns whether the memory used by the given buffer should be ignored by
// the scheduling heuristic.
static bool IgnoreValue(const HloValue& value) {
return IgnoreInstruction(*value.instruction());
static bool IgnoreBuffer(const LogicalBuffer& buffer) {
return IgnoreInstruction(*buffer.instruction());
}
// An entry in the worklist used by CreateSchedule. Corresponds to one
@@ -180,7 +181,7 @@ class ListScheduler {
// U is the number of uses of B that have not yet been scheduled. This pair
// is a pointer into the unscheduled_use_count_ map, so it gets updated for
// free when we update counts in the map.
std::vector<const std::pair<const HloValue* const, int64_t>*>
std::vector<const std::pair<const LogicalBuffer* const, int64_t>*>
used_buffer_unscheduled_use_counts;
};
@@ -190,20 +191,18 @@ class ListScheduler {
entry.instruction = instruction;
entry.bytes_defined = 0;
HloValueSet value_set =
alias_analysis_.dataflow_analysis().GetFlattenedValueSet(instruction);
for (const HloValue* value : value_set.values()) {
if (!IgnoreInstruction(*instruction) &&
value->instruction() == instruction) {
entry.bytes_defined += size_function_(*value);
for (auto* buffer :
points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
if (!IgnoreBuffer(*buffer)) {
entry.bytes_defined += size_function_(*buffer);
}
}
for (auto* value : value_uses_.at(instruction)) {
if (IgnoreValue(*value)) {
for (auto* buffer : buffer_uses_.at(instruction)) {
if (IgnoreBuffer(*buffer)) {
continue;
}
auto unscheduled_use_count_it = unscheduled_use_count_.find(value);
auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
entry.used_buffer_unscheduled_use_counts.push_back(
&*unscheduled_use_count_it);
@@ -309,9 +308,9 @@ class ListScheduler {
scheduled_instructions_.insert(best);
bool adjust_ready_queue = false;
// Update the unscheduled uses of the HloValues.
for (const HloValue* value : value_uses_.at(best)) {
int64_t& count = unscheduled_use_count_[value];
// Update the unscheduled uses of the logical buffers.
for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
int64_t& count = unscheduled_use_count_[buffer];
CHECK_GT(count, 0);
--count;
if (count == 1) {
@@ -335,7 +334,7 @@ class ListScheduler {
for (HloInstruction* succ : best->control_successors()) {
update_pred_count(succ);
}
// The unscheduled use count for a HloValue has changed to 1, so the
// The unscheduled use count for a buffer has changed to 1, so the
// priorities of some ready instructions may go up. We update them in the
// ready queue, so that they can appear earlier.
if (adjust_ready_queue) {
@@ -361,24 +360,23 @@ class ListScheduler {
}
}
}
CHECK_EQ(schedule.size(), computation_->instruction_count())
<< "There could be a cycle in the HLO graph";
CHECK_EQ(schedule.size(), computation_->instruction_count());
CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());
return schedule;
}
HloComputation* computation_;
const HloAliasAnalysis& alias_analysis_;
const TuplePointsToAnalysis& points_to_analysis_;
const BufferValue::SizeFunction& size_function_;
// A map containing the HloValue that each instruction uses.
absl::flat_hash_map<const HloInstruction*, std::vector<const HloValue*>>
value_uses_;
// A map containing the LogicalBuffers that each instruction uses.
absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
buffer_uses_;
// A map containing the count of unscheduled HLOs which using a particular
// HloValue.
absl::flat_hash_map<const HloValue*, int64_t> unscheduled_use_count_;
// LogicalBuffer.
absl::flat_hash_map<const LogicalBuffer*, int64_t> unscheduled_use_count_;
// Set of instructions which have been scheduled.
absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
@@ -398,7 +396,8 @@ int64_t SumBufferSizes(const HloInstruction* hlo, const HloValueSet& value_set,
} // namespace
absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(
const HloModule* module, const HloAliasAnalysis& alias_analysis,
const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const absl::flat_hash_set<absl::string_view>& execution_threads,
int64_t* peak_memory) const {
HloSchedule schedule(module);
@@ -406,7 +405,7 @@ absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(
module->MakeComputationPostOrder(execution_threads)) {
if (!computation->IsFusionComputation()) {
TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
Run(computation, alias_analysis));
Run(computation, points_to_analysis, alias_analysis));
if (postprocessor_) {
computation_sequence = postprocessor_(computation_sequence);
}
@@ -423,6 +422,7 @@ absl::StatusOr<HloSchedule> ComputationSchedulerAlgorithm::Run(
absl::StatusOr<HloInstructionSequence> DFSMemoryScheduler::Run(
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis) const {
// These variables are a hack to prevent overflows.
int64_t cumulative_total_size = 0;
@@ -501,6 +501,7 @@ absl::StatusOr<HloInstructionSequence> DFSMemoryScheduler::Run(
absl::StatusOr<HloInstructionSequence> BFScheduler::Run(
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis) const {
// Index of HloInstruction in the `computation`.
absl::flat_hash_map<const HloInstruction*, int64_t> inst_index;
@@ -554,18 +555,21 @@ absl::StatusOr<HloInstructionSequence> BFScheduler::Run(
absl::StatusOr<HloInstructionSequence> ListMemoryScheduler::Run(
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis) const {
return ListScheduler::Run(computation, alias_analysis, size_function_);
return ListScheduler::Run(computation, points_to_analysis, size_function_);
}
absl::StatusOr<HloInstructionSequence> PostOrderScheduler::Run(
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis) const {
return HloInstructionSequence(computation->MakeInstructionPostOrder());
}
absl::StatusOr<HloSchedule> DefaultMemoryScheduler::Run(
const HloModule* module, const HloAliasAnalysis& alias_analysis,
const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const absl::flat_hash_set<absl::string_view>& execution_threads,
int64_t* peak_memory) const {
// We try a few schedulers and choose whichever returns a lower min-memory,
@@ -577,22 +581,24 @@ absl::StatusOr<HloSchedule> DefaultMemoryScheduler::Run(
// List wins for most of our benchmarks; postorder-based schedulers win for
// some RNNs.
int64_t list_memory;
TF_ASSIGN_OR_RETURN(HloSchedule list_sequence,
list_scheduler_.Run(module, alias_analysis,
TF_ASSIGN_OR_RETURN(
HloSchedule list_sequence,
list_scheduler_.Run(module, points_to_analysis, alias_analysis,
execution_threads, &list_memory));
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
int64_t dfs_memory;
TF_ASSIGN_OR_RETURN(HloSchedule dfs_sequence,
dfs_scheduler_.Run(module, alias_analysis,
TF_ASSIGN_OR_RETURN(
HloSchedule dfs_sequence,
dfs_scheduler_.Run(module, points_to_analysis, alias_analysis,
execution_threads, &dfs_memory));
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
int64_t post_order_memory;
TF_ASSIGN_OR_RETURN(
HloSchedule post_order_sequence,
post_order_scheduler_.Run(module, alias_analysis, execution_threads,
&post_order_memory));
post_order_scheduler_.Run(module, points_to_analysis, alias_analysis,
execution_threads, &post_order_memory));
VLOG(2) << "Min-memory post order sequence: "
<< HumanReadableNumBytes(post_order_memory);
@@ -624,12 +630,15 @@ absl::StatusOr<HloSchedule> ScheduleModule(
return absl::StrFormat("XlaMemoryScheduler:#module=%s,program_id=%d#",
module->name(), module->unique_id());
});
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(module));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, algorithm.alias_info()));
TF_ASSIGN_OR_RETURN(
HloSchedule schedule,
algorithm.Run(module, *alias_analysis, execution_threads, peak_memory));
algorithm.Run(module, *points_to_analysis, *alias_analysis,
execution_threads, peak_memory));
TF_RETURN_IF_ERROR(schedule.Verify());
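
Because the hunks above interleave old and new lines, the restored ListScheduler bookkeeping is easier to read consolidated. The sketch below is a hedged reconstruction from the constructor shown above, not code from this CL: every LogicalBuffer defined in the computation starts at zero unscheduled uses, each using instruction adds one use per distinct buffer reachable from its operands, and buffers reachable from the root get one implicit live-out use. The free function CountUnscheduledUses is a hypothetical name used for illustration; the real logic lives inline in the constructor.

#include <cstdint>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "xla/hlo/analysis/tuple_points_to_analysis.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/logical_buffer.h"
#include "xla/shape_util.h"

namespace xla {

absl::flat_hash_map<const LogicalBuffer*, int64_t> CountUnscheduledUses(
    const HloComputation& computation,
    const TuplePointsToAnalysis& points_to_analysis) {
  absl::flat_hash_map<const LogicalBuffer*, int64_t> counts;
  // Every buffer defined in the computation starts with zero uses so that
  // buffers with no users still appear in the map.
  for (const HloInstruction* instruction : computation.instructions()) {
    for (const LogicalBuffer* buffer :
         points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
      counts[buffer] = 0;
    }
  }
  // An instruction "uses" the buffers its operands may point to; a buffer
  // reachable through several operands of the same instruction is counted
  // once, matching the instr_uses set in the constructor above.
  for (const HloInstruction* instruction : computation.instructions()) {
    absl::flat_hash_set<const LogicalBuffer*> instr_uses;
    for (const HloInstruction* operand : instruction->operands()) {
      points_to_analysis.GetPointsToSet(operand).ForEachElement(
          [&](const ShapeIndex& /*index*/,
              const PointsToSet::BufferList& buffers) {
            instr_uses.insert(buffers.begin(), buffers.end());
          });
    }
    for (const LogicalBuffer* buffer : instr_uses) {
      ++counts[buffer];
    }
  }
  // Buffers live out of the computation carry one implicit use at the end.
  for (const LogicalBuffer* live_out :
       points_to_analysis.GetPointsToSet(computation.root_instruction())
           .CreateFlattenedSet()) {
    ++counts[live_out];
  }
  return counts;
}

}  // namespace xla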


@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "xla/hlo/analysis/alias_info.h"
#include "xla/hlo/analysis/hlo_alias_analysis.h"
#include "xla/hlo/analysis/tuple_points_to_analysis.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_schedule.h"
@@ -38,13 +39,16 @@ namespace xla {
// 'module' given a points-to analysis result that describes buffer aliasing.
// peak_memory (may be nullptr) is set to the peak memory of the resulting
// schedule according to the HeapSimulator.
//
// TODO(yunxing): Cleanup usage of TuplePointsToAnalysis.
class ModuleSchedulerAlgorithm {
public:
explicit ModuleSchedulerAlgorithm(const AliasInfo* alias_info)
: alias_info_(alias_info) {}
virtual ~ModuleSchedulerAlgorithm() = default;
virtual absl::StatusOr<HloSchedule> Run(
const HloModule* module, const HloAliasAnalysis& alias_analysis,
const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const absl::flat_hash_set<absl::string_view>& execution_threads,
int64_t* peak_memory) const = 0;
@@ -68,9 +72,11 @@ class ComputationSchedulerAlgorithm : public ModuleSchedulerAlgorithm {
public:
virtual absl::StatusOr<HloInstructionSequence> Run(
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis) const = 0;
absl::StatusOr<HloSchedule> Run(
const HloModule* module, const HloAliasAnalysis& alias_analysis,
const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const absl::flat_hash_set<absl::string_view>& execution_threads,
int64_t* peak_memory) const override;
@@ -99,6 +105,7 @@ class ListMemoryScheduler : public ComputationSchedulerAlgorithm {
using ModuleSchedulerAlgorithm::Run;
absl::StatusOr<HloInstructionSequence> Run(
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis) const override;
};
@@ -113,6 +120,7 @@ class DFSMemoryScheduler : public ComputationSchedulerAlgorithm {
using ModuleSchedulerAlgorithm::Run;
absl::StatusOr<HloInstructionSequence> Run(
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis) const override;
};
@@ -135,6 +143,7 @@ class BFScheduler : public ComputationSchedulerAlgorithm {
std::move(postprocessor)) {}
absl::StatusOr<HloInstructionSequence> Run(
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis) const override;
};
@@ -149,6 +158,7 @@ class PostOrderScheduler : public ComputationSchedulerAlgorithm {
using ModuleSchedulerAlgorithm::Run;
absl::StatusOr<HloInstructionSequence> Run(
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis) const override;
};
@@ -166,7 +176,8 @@ class DefaultMemoryScheduler : public ModuleSchedulerAlgorithm {
dfs_scheduler_(alias_info, size_function, postprocessor),
post_order_scheduler_(alias_info, size_function, postprocessor) {}
absl::StatusOr<HloSchedule> Run(
const HloModule* module, const HloAliasAnalysis& alias_analysis,
const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const absl::flat_hash_set<absl::string_view>& execution_threads,
int64_t* peak_memory) const override;
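
For completeness: the comment near the top of DefaultMemoryScheduler::Run in the .cc hunks above says the pass tries a few schedulers and chooses whichever returns a lower min-memory, but the selection step itself falls outside the hunks shown. Below is a hedged sketch of what that tail plausibly looks like, reusing the variable names from the diff; it is an assumption, not code copied from this CL.

// Assumed continuation of DefaultMemoryScheduler::Run after the three
// scheduler runs above: keep whichever schedule simulated to the least peak
// memory and report that value through peak_memory when requested.
if (list_memory <= dfs_memory && list_memory <= post_order_memory) {
  if (peak_memory != nullptr) {
    *peak_memory = list_memory;
  }
  return std::move(list_sequence);
}
if (dfs_memory <= post_order_memory) {
  if (peak_memory != nullptr) {
    *peak_memory = dfs_memory;
  }
  return std::move(dfs_sequence);
}
if (peak_memory != nullptr) {
  *peak_memory = post_order_memory;
}
return std::move(post_order_sequence);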