mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Adds interface to update target specific states in latency hiding scheduler.
PiperOrigin-RevId: 769511028
This commit is contained in:
parent
d4068777e7
commit
9e30af5b16
|
|
@ -2150,6 +2150,9 @@ absl::StatusOr<HloGraphNode::TimeCost> DefaultSchedulerCore::ScheduleNode(
|
|||
sched_state->selective_resource_releasers.push_back(&edge.Target());
|
||||
}
|
||||
}
|
||||
async_tracker_->UpdateTargetDefinedStates(n->GetInstr(),
|
||||
&sched_state->sched_graph,
|
||||
latency_estimator_, current_time);
|
||||
++sched_state->scheduled_count;
|
||||
for (auto& resource : n->GetResources()) {
|
||||
if (resource.second == ResourceUsageType::kResourceRelease) {
|
||||
|
|
@ -3109,7 +3112,10 @@ absl::StatusOr<bool> LatencyHidingScheduler::Run(
|
|||
for (HloComputation* computation : computations_to_schedule_) {
|
||||
TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> new_schedule,
|
||||
scheduler_core_->ScheduleComputation(computation));
|
||||
// Update target specific states that may include altering the computation.
|
||||
async_tracker_->UpdateTargetDefinedStates(computation);
|
||||
saved_schedules[computation] = std::move(new_schedule);
|
||||
async_tracker_->ResetTargetDefinedStates();
|
||||
}
|
||||
uint64_t initial_memory_limit = scheduler_core_->GetMemoryLimit();
|
||||
for (int64_t iter = 0;
|
||||
|
|
@ -3127,7 +3133,9 @@ absl::StatusOr<bool> LatencyHidingScheduler::Run(
|
|||
for (HloComputation* computation : computations_to_schedule_) {
|
||||
TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> new_schedule,
|
||||
scheduler_core_->ScheduleComputation(computation));
|
||||
async_tracker_->UpdateTargetDefinedStates(computation);
|
||||
saved_schedules[computation] = std::move(new_schedule);
|
||||
async_tracker_->ResetTargetDefinedStates();
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "[" << name() << "]"
|
||||
|
|
|
|||
|
|
@ -317,6 +317,18 @@ class AsyncTracker {
|
|||
return get_canonical_async_op_(hlo);
|
||||
}
|
||||
|
||||
// Updates target defined states after scheduling a node.
|
||||
virtual void UpdateTargetDefinedStates(
|
||||
const HloInstruction& hlo, const HloScheduleGraph* schedule_graph,
|
||||
const LatencyEstimator* latency_estimator,
|
||||
LatencyEstimator::TimeCost current_time) {}
|
||||
|
||||
// Updates target defined states after scheduling a computation.
|
||||
virtual void UpdateTargetDefinedStates(HloComputation* computation) {}
|
||||
|
||||
// Resets target defined states after scheduling a computation.
|
||||
virtual void ResetTargetDefinedStates() {}
|
||||
|
||||
explicit AsyncTracker(
|
||||
const SchedulerConfig& config,
|
||||
GetCanonicalAsyncOpFunc func = DefaultGetCanonicalAsyncOp)
|
||||
|
|
@ -1222,8 +1234,9 @@ class DefaultSchedulerCore : public SchedulerCore {
|
|||
// instructions.
|
||||
const LatencyEstimator* latency_estimator;
|
||||
// Class used to track which instructions are async instructions and which
|
||||
// async instructions computations contain.
|
||||
const AsyncTracker* async_tracker;
|
||||
// async instructions computations contain. It also tracks target defined
|
||||
// states related to the async instructions.
|
||||
AsyncTracker* async_tracker;
|
||||
// Tracker of memory pressure for the computation.
|
||||
MemoryPressureTracker* memory_pressure_tracker;
|
||||
// Vector containing a list of nodes that aren't ready to schedule yet in
|
||||
|
|
@ -1260,7 +1273,7 @@ class DefaultSchedulerCore : public SchedulerCore {
|
|||
SchedulingState(const HloInstructionSequence* instr_sequence,
|
||||
HloAliasAnalysis* alias_analysis,
|
||||
const LatencyEstimator* latency_estimator,
|
||||
const AsyncTracker* async_tracker,
|
||||
AsyncTracker* async_tracker,
|
||||
MemoryPressureTracker* memory_pressure_tracker,
|
||||
const SchedulerConfig& config)
|
||||
: sched_graph(&instr_sequence->instructions(), alias_analysis,
|
||||
|
|
@ -1277,8 +1290,8 @@ class DefaultSchedulerCore : public SchedulerCore {
|
|||
|
||||
DefaultSchedulerCore(
|
||||
HloCostAnalysis::ShapeSizeFunction shape_size_bytes,
|
||||
const AsyncTracker* async_tracker,
|
||||
const LatencyEstimator* latency_estimator, const SchedulerConfig& config,
|
||||
AsyncTracker* async_tracker, const LatencyEstimator* latency_estimator,
|
||||
const SchedulerConfig& config,
|
||||
TargetSchedulingRule target_scheduling_rule = nullptr,
|
||||
TargetSchedulingRule early_target_scheduling_rule = nullptr,
|
||||
PostProcessingFn post_processing_fn = nullptr,
|
||||
|
|
@ -1366,7 +1379,7 @@ class DefaultSchedulerCore : public SchedulerCore {
|
|||
HloCostAnalysis::ShapeSizeFunction shape_size_bytes_;
|
||||
std::unique_ptr<ModulePressureState> module_pressure_state_;
|
||||
std::unique_ptr<HloAliasAnalysis> alias_analysis_;
|
||||
const AsyncTracker* async_tracker_;
|
||||
AsyncTracker* async_tracker_;
|
||||
const LatencyEstimator* latency_estimator_;
|
||||
SchedulerConfig config_;
|
||||
TargetSchedulingRule target_scheduling_rule_ = nullptr;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user