Adds interface to update target specific states in latency hiding scheduler.

PiperOrigin-RevId: 769511028
This commit is contained in:
A. Unique TensorFlower 2025-06-10 01:36:24 -07:00 committed by TensorFlower Gardener
parent d4068777e7
commit 9e30af5b16
2 changed files with 27 additions and 6 deletions

View File

@ -2150,6 +2150,9 @@ absl::StatusOr<HloGraphNode::TimeCost> DefaultSchedulerCore::ScheduleNode(
sched_state->selective_resource_releasers.push_back(&edge.Target());
}
}
async_tracker_->UpdateTargetDefinedStates(n->GetInstr(),
&sched_state->sched_graph,
latency_estimator_, current_time);
++sched_state->scheduled_count;
for (auto& resource : n->GetResources()) {
if (resource.second == ResourceUsageType::kResourceRelease) {
@ -3109,7 +3112,10 @@ absl::StatusOr<bool> LatencyHidingScheduler::Run(
for (HloComputation* computation : computations_to_schedule_) {
TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> new_schedule,
scheduler_core_->ScheduleComputation(computation));
// Update target specific states that may include altering the computation.
async_tracker_->UpdateTargetDefinedStates(computation);
saved_schedules[computation] = std::move(new_schedule);
async_tracker_->ResetTargetDefinedStates();
}
uint64_t initial_memory_limit = scheduler_core_->GetMemoryLimit();
for (int64_t iter = 0;
@ -3127,7 +3133,9 @@ absl::StatusOr<bool> LatencyHidingScheduler::Run(
for (HloComputation* computation : computations_to_schedule_) {
TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> new_schedule,
scheduler_core_->ScheduleComputation(computation));
async_tracker_->UpdateTargetDefinedStates(computation);
saved_schedules[computation] = std::move(new_schedule);
async_tracker_->ResetTargetDefinedStates();
}
}
LOG(INFO) << "[" << name() << "]"

View File

@ -317,6 +317,18 @@ class AsyncTracker {
return get_canonical_async_op_(hlo);
}
// Updates target defined states after scheduling a node.
virtual void UpdateTargetDefinedStates(
const HloInstruction& hlo, const HloScheduleGraph* schedule_graph,
const LatencyEstimator* latency_estimator,
LatencyEstimator::TimeCost current_time) {}
// Updates target defined states after scheduling a computation.
virtual void UpdateTargetDefinedStates(HloComputation* computation) {}
// Resets target defined states after scheduling a computation.
virtual void ResetTargetDefinedStates() {}
explicit AsyncTracker(
const SchedulerConfig& config,
GetCanonicalAsyncOpFunc func = DefaultGetCanonicalAsyncOp)
@ -1222,8 +1234,9 @@ class DefaultSchedulerCore : public SchedulerCore {
// instructions.
const LatencyEstimator* latency_estimator;
// Class used to track which instructions are async instructions and which
// async instructions computations contain.
const AsyncTracker* async_tracker;
// async instructions computations contain. It also tracks target defined
// states related to the async instructions.
AsyncTracker* async_tracker;
// Tracker of memory pressure for the computation.
MemoryPressureTracker* memory_pressure_tracker;
// Vector containing a list of nodes that aren't ready to schedule yet in
@ -1260,7 +1273,7 @@ class DefaultSchedulerCore : public SchedulerCore {
SchedulingState(const HloInstructionSequence* instr_sequence,
HloAliasAnalysis* alias_analysis,
const LatencyEstimator* latency_estimator,
const AsyncTracker* async_tracker,
AsyncTracker* async_tracker,
MemoryPressureTracker* memory_pressure_tracker,
const SchedulerConfig& config)
: sched_graph(&instr_sequence->instructions(), alias_analysis,
@ -1277,8 +1290,8 @@ class DefaultSchedulerCore : public SchedulerCore {
DefaultSchedulerCore(
HloCostAnalysis::ShapeSizeFunction shape_size_bytes,
const AsyncTracker* async_tracker,
const LatencyEstimator* latency_estimator, const SchedulerConfig& config,
AsyncTracker* async_tracker, const LatencyEstimator* latency_estimator,
const SchedulerConfig& config,
TargetSchedulingRule target_scheduling_rule = nullptr,
TargetSchedulingRule early_target_scheduling_rule = nullptr,
PostProcessingFn post_processing_fn = nullptr,
@ -1366,7 +1379,7 @@ class DefaultSchedulerCore : public SchedulerCore {
HloCostAnalysis::ShapeSizeFunction shape_size_bytes_;
std::unique_ptr<ModulePressureState> module_pressure_state_;
std::unique_ptr<HloAliasAnalysis> alias_analysis_;
const AsyncTracker* async_tracker_;
AsyncTracker* async_tracker_;
const LatencyEstimator* latency_estimator_;
SchedulerConfig config_;
TargetSchedulingRule target_scheduling_rule_ = nullptr;