From cb56bd8ebdf606cb0b1acf4b71740fc09d736148 Mon Sep 17 00:00:00 2001
From: Wilsin Gosti
Date: Fri, 24 Jun 2022 13:41:05 -0700
Subject: [PATCH] [tf.data] Implement a new tf.data autotune algorithm based on
 stage-based analysis.

This requires some refactoring of the stage-based timing analysis code.

PiperOrigin-RevId: 457085290
---
 RELEASE.md                                    |   4 +
 tensorflow/core/data/root_dataset.cc          |  19 +-
 tensorflow/core/framework/model.cc            | 270 +++++++++++++---
 tensorflow/core/framework/model.h             |  79 +++--
 tensorflow/core/framework/model.proto         |   1 +
 tensorflow/core/framework/model_test.cc       | 303 ++++++++++++++++--
 tensorflow/python/data/ops/options.py         |  17 +-
 ...ata.experimental.-autotune-algorithm.pbtxt |   4 +
 ...ata.experimental.-autotune-algorithm.pbtxt |   4 +
 9 files changed, 610 insertions(+), 91 deletions(-)

diff --git a/RELEASE.md b/RELEASE.md
index 310b84a1492..80739382dd4 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -73,6 +73,10 @@
         small increase in memory usage due to buffering. To enable this
         behavior, set `inject_prefetch=True` in
         `tf.data.experimental.OptimizationOptions`.
+    *   Added a new value to `tf.data.Options.autotune.autotune_algorithm`:
+        `STAGE_BASED`. If the autotune algorithm is set to `STAGE_BASED`, it
+        uses a new algorithm that can achieve the same performance with lower
+        CPU and memory usage.
 
 *   `tf.distribute`:
 
diff --git a/tensorflow/core/data/root_dataset.cc b/tensorflow/core/data/root_dataset.cc
index 0af9cf1b706..d131e2f42d1 100644
--- a/tensorflow/core/data/root_dataset.cc
+++ b/tensorflow/core/data/root_dataset.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/data/root_dataset.h"
 
+#include <algorithm>
 #include <memory>
 #include <string>
 #include <utility>
@@ -160,11 +161,22 @@ class RootDataset::Iterator : public DatasetIterator<RootDataset> {
   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                          bool* end_of_sequence) override {
+    {
+      tf_shared_lock l(mu_);
+      if (model_ != nullptr && end_time_usec_ > 0) {
+        model_->RecordIteratorGapTime(ctx->env()->NowMicros() -
+                                      end_time_usec_);
+      }
+    }
     if (dataset()->params_.autotune) {
       TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
     }
-    return input_impl_->GetNext(IteratorContext(CreateParams(ctx)), out_tensors,
-                                end_of_sequence);
+    TF_RETURN_IF_ERROR(input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
+                                            out_tensors, end_of_sequence));
+    {
+      mutex_lock l(mu_);
+      end_time_usec_ = std::max(ctx->env()->NowMicros(), end_time_usec_);
+    }
+    return OkStatus();
   }
 
  protected:
@@ -261,6 +273,9 @@ class RootDataset::Iterator : public DatasetIterator<RootDataset> {
   int64_t threadpool_size_;
   std::unique_ptr<thread::ThreadPool> thread_pool_;
 
+  // The end time of the previous `GetNextInternal` call.
+  uint64_t end_time_usec_ TF_GUARDED_BY(mu_) = 0;
+
   // Must be ordered last as its execution may depend on other members.
   std::unique_ptr<IteratorBase> input_impl_;
 };
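Taken together, the two hunks above bracket every `GetNextInternal` call: the time between the end of one call and the start of the next is reported to the model as an iterator gap, which `ComputeTargetTimeNsec` later averages into the stage-based target time. A minimal, self-contained sketch of that bookkeeping (the `GapClock` type and free-standing counters are illustrative, not part of the patch; the real iterator guards these fields with `mu_` and reads the clock from `ctx->env()->NowMicros()`):

    #include <algorithm>
    #include <cstdint>

    // Illustrative sketch of the gap-time bookkeeping added above.
    class GapClock {
     public:
      // Called on entry to GetNext(): the time since the previous call
      // returned is an "iterator gap", i.e. time the consumer spent outside
      // the input pipeline.
      void OnGetNextStart(uint64_t now_usec, uint64_t* gap_sum_usec,
                          uint64_t* gap_count) {
        if (end_time_usec_ > 0) {
          *gap_sum_usec += now_usec - end_time_usec_;
          ++*gap_count;
        }
      }
      // Called on exit from GetNext(); std::max() keeps the latest end time
      // when several callers race.
      void OnGetNextEnd(uint64_t now_usec) {
        end_time_usec_ = std::max(now_usec, end_time_usec_);
      }

     private:
      uint64_t end_time_usec_ = 0;  // End time of the previous call.
    };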
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/statusor.h" namespace tensorflow { namespace data { @@ -35,6 +37,73 @@ constexpr int64_t Model::kOptimizationPeriodMaxMs; namespace { +// A priority queue that holds stage roots where the top of the priority queue +// is the node with the largest total time. +class ModelTimingPriorityQueue { + public: + explicit ModelTimingPriorityQueue(ModelTiming& model_timing) { + std::vector> stage_roots = + model_timing.GetStageRoots(); + if (stage_roots.empty()) { + return; + } + for (auto& root : stage_roots) { + DCHECK(model_timing.GetTiming(root.get()) != nullptr); + stage_roots_queue_.emplace( + model_timing.GetTiming(root.get())->total_time_nsec, root.get()); + } + } + + // Pops the top item from the queue, i.e. node with the largest total time. + StatusOr> PopSlowestStageRoot() { + if (stage_roots_queue_.empty()) { + return errors::Internal( + "Model timing priority queue is empty during stage-based " + "optimization"); + } + std::pair top_item = stage_roots_queue_.top(); + stage_roots_queue_.pop(); + return top_item; + } + + // Push a node together with its total time onto the queue. + void Push(Node* node, double total_time_nsec) { + stage_roots_queue_.emplace(total_time_nsec, node); + } + + private: + std::priority_queue> stage_roots_queue_; +}; + +// A cache that looks up the `parallelism` parameters of nodes the first time +// they are requested and saves them for subsequent requests. +class NodeParallelismParameters { + public: + NodeParallelismParameters() {} + + // Returns the `parallelism` parameter given a node. + Parameter* Get(const Node* node) { + if (node_parallelism_.contains(node)) { + // Look for the `parallelism` parameter of this node in the cache. + return node_parallelism_.at(node); + } + // Find the `parallelism` parameter of this node and cache it. + Node::ModelParameters parameters = node->CollectNodeTunableParameters(); + Node::ModelParameters::iterator parameter_pair = std::find_if( + parameters.begin(), parameters.end(), + [](const std::pair>& + parameter) { return parameter.second->name == kParallelism; }); + if (parameter_pair == parameters.end()) { + return nullptr; + } + node_parallelism_[node] = parameter_pair->second.get(); + return parameter_pair->second.get(); + } + + private: + absl::flat_hash_map node_parallelism_; +}; + // Returns true if all parameters have reached their max values. bool AreAllParametersMax(const Model::ModelParameters& parameters) { for (const auto& pair : parameters) { @@ -177,6 +246,10 @@ Status ModelToProtoHelper(std::shared_ptr output, ModelProto* model) { // Recursively produces node tree rooted in `output` from the given model proto. 
@@ -177,6 +246,10 @@ Status ModelToProtoHelper(std::shared_ptr<Node> output, ModelProto* model) {
 
 // Recursively produces node tree rooted in `output` from the given model proto.
 Status ModelFromProtoHelper(ModelProto model, std::shared_ptr<Node>* output) {
+  if (model.nodes().empty()) {
+    return errors::Internal(
+        "Cannot restore model from proto because it has no nodes.");
+  }
   TF_RETURN_IF_ERROR(Node::FromProto(model.nodes().at(model.output()),
                                      /*output=*/nullptr, output));
   std::list<std::shared_ptr<Node>> to_restore_inputs = {*output};
@@ -1386,6 +1459,13 @@ Node::ModelParameters Node::CollectTunableParameters() const {
   return CollectTunableParametersLocked();
 }
 
+Node::ModelParameters Node::CollectNodeTunableParameters() const {
+  tf_shared_lock l(mu_);
+  Node::ModelParameters parameters;
+  CollectTunableParametersHelper(&parameters);
+  return parameters;
+}
+
 string Node::DebugString() const {
   absl::flat_hash_map<string, string> debug_strings;
   tf_shared_lock l(mu_);
@@ -1943,6 +2023,9 @@ void Model::Optimize(AutotuneAlgorithm algorithm, int64_t cpu_budget,
       OptimizeGradientDescent(snapshot, optimization_params,
                               cancellation_manager);
       break;
+    case AutotuneAlgorithm::STAGE_BASED:
+      OptimizeStageBased(snapshot, optimization_params, cancellation_manager);
+      break;
     default:
       VLOG(2) << "Autotuning algorithm was not recognized. Aborting "
                  "optimization.";
@@ -2022,7 +2105,15 @@ Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, int64_t cpu_budget,
     }
 
     int64_t start_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
-    Optimize(algorithm, cpu_budget, ram_budget, /*model_input_time=*/0,
+    double model_input_time = 0.0;
+    // For historical reasons, the model input time is set to 0 for all
+    // optimization algorithms except the stage-based algorithm, which uses it
+    // as the target optimization time for all stages in the pipeline.
+    if (algorithm == AutotuneAlgorithm::STAGE_BASED) {
+      model_input_time = ComputeTargetTimeNsec();
+    }
+    Optimize(algorithm, cpu_budget, ram_budget, model_input_time,
              cancellation_manager);
     int64_t end_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
     VLOG(2) << "Optimized for " << end_ms - start_ms << " ms.";
@@ -2167,6 +2258,82 @@ void Model::OptimizeHillClimbHelper(
   UpdateStateValues(&parameters);
 }
 
+double Model::ComputeTargetTimeNsec() {
+  tf_shared_lock l(gap_mu_);
+  if (gap_time_count_ == 0) {
+    return 0.0;
+  }
+  // Convert the average gap time from microseconds to nanoseconds.
+  return (static_cast<double>(gap_time_sum_usec_) /
+          static_cast<double>(gap_time_count_)) *
+         1.0e3;
+}
+
+void Model::OptimizeStageBased(std::shared_ptr<Node> snapshot,
+                               const OptimizationParams& optimization_params,
+                               CancellationManager* cancellation_manager) {
+  return OptimizeStageBasedParallelism(
+      snapshot, optimization_params.model_input_time(), optimization_params,
+      cancellation_manager);
+}
+
+void Model::OptimizeStageBasedParallelism(
+    std::shared_ptr<Node> snapshot, double target_time_nsec,
+    const OptimizationParams& optimization_params,
+    CancellationManager* cancellation_manager) {
+  VLOG(2) << "Starting stage-based optimization of tunable parameters with a "
+             "target time of "
+          << target_time_nsec << " nanoseconds.";
+  Node::ModelParameters tunable_parameters = CollectTunableParameters(snapshot);
+  // Initialize the parallelism parameter values to their minimums before
+  // tuning.
+  for (std::pair<string, std::shared_ptr<Parameter>>& pair :
+       tunable_parameters) {
+    if (pair.second->name != kParallelism) {
+      continue;
+    }
+    pair.second->value = pair.second->min;
+  }
+  ModelTiming model_timing(snapshot);
+  ModelTimingPriorityQueue priority_queue(model_timing);
+  StatusOr<std::pair<double, Node*>> critical_root_status =
+      priority_queue.PopSlowestStageRoot();
+  if (!critical_root_status.ok()) {
+    return;
+  }
+  NodeParallelismParameters node_parallelism;
+  std::pair<double, Node*> critical_root = critical_root_status.ValueOrDie();
+  while (critical_root.first > target_time_nsec) {
+    Parameter* parallelism_parameter =
+        node_parallelism.Get(critical_root.second);
+    // Stop the optimization if the critical stage has no `parallelism`
+    // parameter or if its parallelism has already reached the CPU budget.
+    if (parallelism_parameter == nullptr ||
+        parallelism_parameter->value >= optimization_params.cpu_budget()) {
+      break;
+    }
+    parallelism_parameter->value += 1.0;
+    if (TotalMaximumBufferedBytes(snapshot) >
+        optimization_params.ram_budget()) {
+      // Increasing the parallelism by 1 would exceed the RAM budget. Undo the
+      // increase and stop the optimization because we cannot improve the most
+      // critical stage any further.
+      parallelism_parameter->value -= 1.0;
+      break;
+    }
+    // Compute the new total time and put the node back in the queue after its
+    // parallelism value has been increased by 1.
+    model_timing.ComputeNodeTotalTime(*critical_root.second);
+    priority_queue.Push(
+        critical_root.second,
+        model_timing.GetTiming(critical_root.second)->total_time_nsec);
+    // Get the next critical stage root.
+    critical_root_status = priority_queue.PopSlowestStageRoot();
+    if (!critical_root_status.ok()) {
+      break;
+    }
+    critical_root = critical_root_status.ValueOrDie();
+  }
+  UpdateStateValues(&tunable_parameters);
+}
+
 void Model::OptimizeHillClimb(std::shared_ptr<Node> snapshot,
                               const OptimizationParams& optimization_params,
                               CancellationManager* cancellation_manager) {
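The loop above distills to the following sketch, shown here under a deliberately simplified cost model in which a stage's total time is its per-element work divided by its parallelism. `Stage` and `GreedyStageBased` are illustrative stand-ins, not the patch's code; the real loop recomputes total time through `ModelTiming` and also reverts and stops on RAM-budget violations, as noted in the comments:

    #include <queue>
    #include <utility>
    #include <vector>

    struct Stage {
      double work_nsec = 0.0;    // Per-element work of the stage.
      double parallelism = 1.0;  // Starts at the parameter's minimum.
      double TotalTime() const { return work_nsec / parallelism; }
    };

    void GreedyStageBased(std::vector<Stage>& stages, double target_time_nsec,
                          double cpu_budget) {
      // Max-heap keyed by total time, mirroring ModelTimingPriorityQueue.
      std::priority_queue<std::pair<double, Stage*>> queue;
      for (Stage& stage : stages) queue.emplace(stage.TotalTime(), &stage);
      while (!queue.empty()) {
        auto [total_time, stage] = queue.top();
        queue.pop();
        if (total_time <= target_time_nsec) break;    // Slowest stage is fast
                                                      // enough; we are done.
        if (stage->parallelism >= cpu_budget) break;  // CPU budget reached.
        stage->parallelism += 1.0;  // The real loop undoes this step and stops
                                    // if it would exceed the RAM budget.
        queue.emplace(stage->TotalTime(), stage);  // Re-queue with new time.
      }
    }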
@@ -2309,9 +2476,18 @@ std::string Model::DebugString() {
   return cached_debug_string_;
 }
 
-Node::NodeVector Model::CollectNodes(
+ModelTiming::ModelTiming(std::shared_ptr<Node> root) : root_(root) {
+  DCHECK(root_.get() != nullptr);
+  auto bfs_nodes = CollectNodes(root_, TraversalOrder::BFS, IsAnyNode);
+  auto reverse_bfs_nodes = bfs_nodes;
+  std::reverse(reverse_bfs_nodes.begin(), reverse_bfs_nodes.end());
+  ComputePipelineRatios(bfs_nodes);
+  ComputeTotalTimes(reverse_bfs_nodes);
+}
+
+Node::NodeVector ModelTiming::CollectNodes(
     std::shared_ptr<Node> root, TraversalOrder order,
-    bool collect_node(const std::shared_ptr<Node>)) {
+    bool collect_node(const std::shared_ptr<Node>)) const {
   if (root == nullptr) {
     return Node::NodeVector({});
   }
@@ -2327,18 +2503,6 @@ Node::NodeVector ModelTiming::CollectNodes(
   return nodes;
 }
 
-ModelTiming::ModelTiming(std::shared_ptr<Node> model) : model_(model) {
-  ComputeTiming();
-}
-
-void ModelTiming::ComputeTiming() {
-  auto nodes =
-      model_->CollectNodes(model_->output(), TraversalOrder::BFS, IsAnyNode);
-  ComputeTimingComponents(nodes);
-  std::reverse(nodes.begin(), nodes.end());
-  ComputeTotalTimes(nodes);
-}
-
 const ModelTiming::NodeTiming* ModelTiming::GetTiming(const Node* node) const {
   if (timing_nodes_.find(node) == timing_nodes_.end()) {
     return nullptr;
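The constructor above fixes the traversal-order contract used throughout this class: pipeline ratios depend only on a node's ancestors, so they are computed parents-first over the BFS vector, while total times depend only on a node's inputs, so they are computed children-first over the reversed vector. An illustrative sketch of that two-pass structure (the `TimedNode` type and the exact ratio propagation rule are simplifications, not the patch's code):

    #include <vector>

    struct TimedNode {
      TimedNode* parent = nullptr;
      std::vector<TimedNode*> inputs;
      double ratio = 1.0;            // Elements consumed per element produced.
      double self_time_nsec = 0.0;
      double pipeline_ratio = 1.0;   // Filled in the parents-first pass.
      double total_time_nsec = 0.0;  // Filled in the children-first pass.
    };

    void ComputeTiming(const std::vector<TimedNode*>& bfs_nodes) {
      for (TimedNode* node : bfs_nodes) {  // Parents first.
        if (node->parent != nullptr) {
          node->pipeline_ratio =
              node->parent->pipeline_ratio * node->parent->ratio;
        }
      }
      std::vector<TimedNode*> reverse_bfs_nodes(bfs_nodes.rbegin(),
                                                bfs_nodes.rend());
      for (TimedNode* node : reverse_bfs_nodes) {  // Children first.
        double input_total_nsec = 0.0;
        for (TimedNode* input : node->inputs) {
          input_total_nsec += input->total_time_nsec;
        }
        node->total_time_nsec =
            node->self_time_nsec + input_total_nsec * node->ratio;
      }
    }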
@@ -2346,10 +2510,9 @@ const ModelTiming::NodeTiming* ModelTiming::GetTiming(const Node* node) const {
   return &(timing_nodes_.at(node));
 }
 
-void ModelTiming::ComputeTimingComponents(const Node::NodeVector& bfs_nodes) {
+void ModelTiming::ComputePipelineRatios(const Node::NodeVector& bfs_nodes) {
   for (const auto& node : bfs_nodes) {
     auto& node_timing = timing_nodes_[node.get()];
-    node_timing.self_time_nsec = node->ComputeSelfTime();
     if (!node->autotune()) {
       // These are inactive nodes marked by parallel interleave
       // transformations.
       node_timing.pipeline_ratio = 0.0;
@@ -2370,31 +2533,33 @@ void ModelTiming::ComputeTimingComponents(const Node::NodeVector& bfs_nodes) {
   }
 }
 
-void ModelTiming::ComputeNodeTotalTime(std::shared_ptr<Node> node) {
-  DCHECK(timing_nodes_.contains(node.get()));
-  auto& node_timing = timing_nodes_[node.get()];
+void ModelTiming::ComputeNonAsyncInterleaveManyTotalTime(const Node& node) {
+  DCHECK(timing_nodes_.contains(&node));
+  auto& node_timing = timing_nodes_[&node];
   double input_total_time_nsec = 0.0;
-  for (auto input : node->inputs()) {
+  for (auto input : node.inputs()) {
     if (input->IsAsync()) {
       continue;
    }
     if (!input->autotune() || input->num_elements() <= 0) {
       continue;
     }
-    DCHECK(timing_nodes_.contains(input.get()));
+    DCHECK(timing_nodes_.contains(input.get()))
+        << "Input " << input->long_name() << " of node " << node.long_name()
+        << " has no timing node.";
+
     input_total_time_nsec += timing_nodes_[input.get()].total_time_nsec;
   }
   node_timing.total_time_nsec =
-      node_timing.self_time_nsec + input_total_time_nsec * node->Ratio();
+      node_timing.self_time_nsec + input_total_time_nsec * node.Ratio();
 }
 
-void ModelTiming::ComputeAsyncInterleaveManyTotalTime(
-    std::shared_ptr<Node> node) {
-  DCHECK(timing_nodes_.contains(node));
-  auto& node_timing = timing_nodes_[node.get()];
+void ModelTiming::ComputeAsyncInterleaveManyTotalTime(const Node& node) {
+  DCHECK(timing_nodes_.contains(&node));
+  auto& node_timing = timing_nodes_[&node];
   double max_input_total_time_nsec = 0.0;
   double sum_input_throughput = 0.0;
-  auto inputs = node->inputs();
+  auto inputs = node.inputs();
   // `ParallelInterleave` is often used to interleave processing of datasets
   // generated from the first input, e.g. reading from IO where the first input
   // has the list of all filenames. The first input is typically not the
@@ -2409,7 +2574,9 @@ void ModelTiming::ComputeAsyncInterleaveManyTotalTime(
     if (!(*input)->autotune() || (*input)->num_elements() <= 0) {
       continue;
     }
-    DCHECK(timing_nodes_.contains((*input).get()));
+    DCHECK(timing_nodes_.contains((*input).get()))
+        << "Input " << (*input)->long_name() << " of node " << node.long_name()
+        << " has no timing node.";
     auto input_total_time_nsec = timing_nodes_[(*input).get()].total_time_nsec;
     max_input_total_time_nsec =
         std::max(input_total_time_nsec, max_input_total_time_nsec);
@@ -2418,14 +2585,14 @@ void ModelTiming::ComputeAsyncInterleaveManyTotalTime(
     }
   }
   double input_total_time_nsec = 0.0;
-  auto deterministic = node->ParameterValue(kDeterministic);
+  auto deterministic = node.ParameterValue(kDeterministic);
   // After cl/445005635, there should always be a `deterministic` parameter on
   // an ASYNC_INTERLEAVE_MANY node. The "not-ok" check is to allow the code to
   // work with protos saved and restored before that CL.
   if (!deterministic.ok() || deterministic.ValueOrDie() == 1.0) {
     // If deterministic = true, then the total time is
     // `1 / (worst input throughput * cycle_length)`, i.e.
-    // `max input total time / cycle_length`.
-    input_total_time_nsec = max_input_total_time_nsec * node->Ratio();
+    // `max input total time / cycle_length`.
+    input_total_time_nsec = max_input_total_time_nsec * node.Ratio();
   } else if (sum_input_throughput > 0.0) {
     // If deterministic = false, then the total time is
     // `1/sum_input_throughput`.
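To make the two formulas above concrete with the numbers used by the tests later in this patch: a `Map` node with a 30 nsec self time, `Ratio()` of 1, and a single `SSTable` input whose total time is 10 nsec gets a total time of 30 + 1 x 10 = 40 nsec. For a deterministic interleave, only the slowest input counts: assuming, as the comment above implies, that an interleave node's `Ratio()` is the reciprocal of its cycle length, a slowest input of 60 nsec under cycle length 2 contributes 60 x 0.5 = 30 nsec of input time.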
@@ -2437,26 +2604,31 @@ void ModelTiming::ComputeAsyncInterleaveManyTotalTime(
 
 void ModelTiming::ComputeTotalTimes(const Node::NodeVector& reverse_bfs_nodes) {
   for (const auto& node : reverse_bfs_nodes) {
-    if (!node->autotune() || node->num_elements() <= 0) {
-      continue;
-    }
-#if !defined(IS_MOBILE_PLATFORM)
-    // This block of code is defined only for non-mobile platforms because
-    // mobile platforms lack RTTI, i.e. the use of `dynamic_cast`.
-    if (dynamic_cast<AsyncInterleaveMany*>(node.get()) != nullptr) {
-      ComputeAsyncInterleaveManyTotalTime(node);
-    } else {
-      ComputeNodeTotalTime(node);
-    }
-#else   // !IS_MOBILE_PLATFORM
-    ComputeNodeTotalTime(node);
-#endif  // !IS_MOBILE_PLATFORM
+    ComputeNodeTotalTime(*(node.get()));
   }
 }
 
+void ModelTiming::ComputeNodeTotalTime(const Node& node) {
+  NodeTiming& node_timing = timing_nodes_[&node];
+  node_timing.self_time_nsec = node.ComputeSelfTime();
+  if (!node.autotune() || node.num_elements() <= 0) {
+    return;
+  }
+#if !defined(IS_MOBILE_PLATFORM)
+  // This block of code is defined only for non-mobile platforms because
+  // mobile platforms lack RTTI, i.e. the use of `dynamic_cast`.
+  if (dynamic_cast<const AsyncInterleaveMany*>(&node) != nullptr) {
+    ComputeAsyncInterleaveManyTotalTime(node);
+  } else {
+    ComputeNonAsyncInterleaveManyTotalTime(node);
+  }
+#else   // !IS_MOBILE_PLATFORM
+  ComputeNonAsyncInterleaveManyTotalTime(node);
+#endif  // !IS_MOBILE_PLATFORM
+}
+
 std::vector<std::shared_ptr<Node>> ModelTiming::GetStageRoots() const {
-  auto bfs_nodes =
-      model_->CollectNodes(model_->output(), TraversalOrder::BFS, IsAnyNode);
+  auto bfs_nodes = CollectNodes(root_, TraversalOrder::BFS, IsAnyNode);
   std::vector<std::shared_ptr<Node>> roots;
   if (!bfs_nodes.empty() && !bfs_nodes[0]->IsAsync()) {
     roots.push_back(bfs_nodes[0]);
@@ -2470,8 +2642,8 @@ std::vector<std::shared_ptr<Node>> ModelTiming::GetStageRoots() const {
 }
 
 std::vector<std::shared_ptr<Node>> ModelTiming::GetStageNodes(
-    std::shared_ptr<Node> root) const {
-  return model_->CollectNodes(root, TraversalOrder::BFS, IsSyncNode);
+    std::shared_ptr<Node> stage_root) const {
+  return CollectNodes(stage_root, TraversalOrder::BFS, IsSyncNode);
 }
 
 }  // namespace model
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index 8db0b441643..f4b977363b5 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -375,6 +375,9 @@ class Node {
   // Collects tunable parameters in the subtree rooted in this node.
   ModelParameters CollectTunableParameters() const TF_LOCKS_EXCLUDED(mu_);
 
+  // Collects the tunable parameters of this node only.
+  ModelParameters CollectNodeTunableParameters() const TF_LOCKS_EXCLUDED(mu_);
+
   // Returns a human-readable representation of this node.
   string DebugString() const TF_LOCKS_EXCLUDED(mu_);
 
@@ -689,7 +692,7 @@ class Model {
   ~Model();
 
   // Returns a pointer to the model's output node.
-  const std::shared_ptr<Node> output() {
+  const std::shared_ptr<Node> output() const {
     mutex_lock l(mu_);
     return output_;
   }
@@ -745,9 +748,12 @@ class Model {
   static Status Load(const string& fname, std::unique_ptr<Model>* model,
                      OptimizationParams* optimization_params);
 
-  Node::NodeVector CollectNodes(
-      std::shared_ptr<Node> root, TraversalOrder order,
-      bool collect_node(const std::shared_ptr<Node>));
+  // Records the gap time between consecutive `GetNext()` calls.
+  void RecordIteratorGapTime(uint64_t duration_usec) {
+    mutex_lock l(gap_mu_);
+    gap_time_sum_usec_ += duration_usec;
+    ++gap_time_count_;
+  }
 
  private:
   // Determines whether optimization should stop given total processing time,
@@ -801,6 +807,27 @@ class Model {
                          const OptimizationParams& optimization_params,
                          CancellationManager* cancellation_manager);
 
+  // This optimization starts by setting all tunable parallelism parameters to
+  // their minimum values. It then repeatedly increases the parallelism
+  // parameter of the slowest stage by 1 until either the slowest stage is
+  // faster than the target time or the CPU or RAM budget is fully utilized.
+  // TODO(b/226910071): Add the second part of this algorithm, which optimizes
+  // the buffer sizes of parallel ops.
+  void OptimizeStageBased(std::shared_ptr<Node> snapshot,
+                          const OptimizationParams& optimization_params,
+                          CancellationManager* cancellation_manager);
+
+  // Computes the target time in nanoseconds to use for the `STAGE_BASED`
+  // autotune algorithm.
+  double ComputeTargetTimeNsec();
+
+  // This is the first part of the stage-based optimization, which optimizes
+  // the tunable parallelism parameters.
+  void OptimizeStageBasedParallelism(
+      std::shared_ptr<Node> snapshot, double target_time_nsec,
+      const OptimizationParams& optimization_params,
+      CancellationManager* cancellation_manager);
+
   // Determines if we should stop the gradient descent optimization iterations
   // based on number of increasable parameters, CPU budget, RAM budget and
   // current resource usage.
@@ -826,7 +853,7 @@ class Model {
   // Used for coordination between different input pipeline threads. Exclusive
   // access is required only when adding or removing nodes. Concurrent access
   // to existing nodes is protected by a node mutex.
-  mutex mu_;
+  mutable mutex mu_;
   // Used for coordinating the optimization loop and model modifications.
   condition_variable optimize_cond_var_;
   int64_t id_counter_ TF_GUARDED_BY(mu_) = 1;
@@ -845,6 +872,13 @@ class Model {
   // Cached result of the `DebugString()` invocation used to implement rate
   // limiting of the computation.
   std::string cached_debug_string_ = "";
+  // Used to coordinate gap time updates between different threads. Gap time
+  // is the time between the completion of the previous `GetNext()` call and
+  // the start of the next `GetNext()` call.
+  mutable mutex gap_mu_;
+  // Sum and count of the gap times recorded for the model.
+  uint64_t gap_time_sum_usec_ TF_GUARDED_BY(gap_mu_) = 0;
+  uint64_t gap_time_count_ TF_GUARDED_BY(gap_mu_) = 0;
 };
 
 // Class to compute timing information for a model.
@@ -863,7 +897,7 @@ class ModelTiming {
     double total_time_nsec = 0.0;
   };
 
-  explicit ModelTiming(std::shared_ptr<Node> model);
+  explicit ModelTiming(std::shared_ptr<Node> root);
 
   // Returns the timing data for `node`.
   const NodeTiming* GetTiming(const Node* node) const;
@@ -873,28 +907,37 @@ class ModelTiming {
 
   // Returns all the nodes of a stage given the stage root.
   std::vector<std::shared_ptr<Node>> GetStageNodes(
-      std::shared_ptr<Node> root) const;
+      std::shared_ptr<Node> stage_root) const;
+
+  // Computes the total time for a node.
+  void ComputeNodeTotalTime(const Node& node);
 
  private:
-  // Computes timing information for the whole model.
-  void ComputeTiming();
-
-  // Computes the pipeline ratio and self time for all nodes. The `bfs_nodes`
-  // are assumed to be a vector of model nodes in BFS order.
-  void ComputeTimingComponents(const Node::NodeVector& bfs_nodes);
+  // Computes the pipeline ratios of all nodes.
+  void ComputePipelineRatios(const Node::NodeVector& bfs_nodes);
 
   // Computes the total time for all nodes. The `reverse_bfs_nodes` are
   // assumed to be a vector of model nodes in reverse BFS order.
   void ComputeTotalTimes(const Node::NodeVector& reverse_bfs_nodes);
 
-  // Computes the total time for a node except when the node is an async
-  // interleave node.
-  void ComputeNodeTotalTime(std::shared_ptr<Node> node);
+  // Computes the total time of a node that is not an async interleave node.
+  void ComputeNonAsyncInterleaveManyTotalTime(const Node& node);
 
   // Computes the total time of an async interleave node.
-  void ComputeAsyncInterleaveManyTotalTime(std::shared_ptr<Node> node);
+  void ComputeAsyncInterleaveManyTotalTime(const Node& node);
+
+  // Returns a vector of all nodes in the subtree rooted in `root`, either in
+  // breadth-first search or reverse breadth-first search order depending on
+  // the `order` argument. The nodes are collected based on the results of the
+  // `collect_node` predicate: if the predicate returns `false` for a given
+  // node, then the subtree rooted in this node is excluded. The root node
+  // itself is not collected.
+  Node::NodeVector CollectNodes(
+      std::shared_ptr<Node> root, TraversalOrder order,
+      bool collect_node(const std::shared_ptr<Node>)) const;
 
-  std::shared_ptr<Model> model_;
+  // Stores a pointer to the root of the model.
+  std::shared_ptr<Node> root_;
 
   // Holds a mapping from node to its timing node.
   absl::flat_hash_map<const Node*, NodeTiming> timing_nodes_;
diff --git a/tensorflow/core/framework/model.proto b/tensorflow/core/framework/model.proto
index 5bfbf5dfeb1..1e1f4e5b267 100644
--- a/tensorflow/core/framework/model.proto
+++ b/tensorflow/core/framework/model.proto
@@ -22,6 +22,7 @@ enum AutotuneAlgorithm {
   HILL_CLIMB = 1;
   GRADIENT_DESCENT = 2;
   MAX_PARALLELISM = 3;
+  STAGE_BASED = 4;
 }
 
 // Protocol buffer representing the data used by the autotuning modeling
diff --git a/tensorflow/core/framework/model_test.cc b/tensorflow/core/framework/model_test.cc
index acf3dcf6d72..2106996960c 100644
--- a/tensorflow/core/framework/model_test.cc
+++ b/tensorflow/core/framework/model_test.cc
@@ -1321,22 +1321,38 @@ TEST(RecordTimeTest, RecordTimeTest) {
   EXPECT_FALSE(source->is_recording());
 }
 
+TEST(ModelTest, ModelMetrics) {
+  CellReader<std::string> cell_reader("/tensorflow/data/model");
+  model::Model model;
+  std::shared_ptr<model::Node> root =
+      model::MakeUnknownNode({0, "unknown0", nullptr});
+  model.AddNode([&root](model::Node::Args args) { return root; }, root->name(),
+                nullptr, &root);
+  std::string model_id = strings::StrCat(reinterpret_cast<uint64>(&model));
+  EXPECT_THAT(cell_reader.Read(model_id),
+              AllOf(HasSubstr("key: 0"), HasSubstr("name: \"unknown0\""),
+                    HasSubstr("autotune: true")));
+}
+
 class ModelTimingTest : public ::testing::Test {
  public:
-  // Computes the timing given a Model text proto.
-  void ComputeModelTiming(const std::string& model_pbtxt) {
+  // Builds a model from its text proto `model_pbtxt`.
+  void BuildModelFromProto(const std::string& model_pbtxt) {
     ModelProto model_proto;
     protobuf::TextFormat::ParseFromString(model_pbtxt, &model_proto);
-    std::unique_ptr<Model> model;
-    TF_CHECK_OK(Model::FromProto(model_proto, &model));
-    auto nodes =
-        model->CollectNodes(model->output(), TraversalOrder::BFS,
-                            [](const std::shared_ptr<Node>) { return true; });
+    TF_CHECK_OK(Model::FromProto(model_proto, &model_));
+    auto nodes = model_->output()->CollectNodes(
+        TraversalOrder::BFS, [](const std::shared_ptr<Node>) { return true; });
     node_map_.clear();
+    node_map_[model_->output()->id()] = model_->output().get();
     for (const auto& node : nodes) {
       node_map_[node->id()] = node.get();
     }
-    model_timing_ = absl::make_unique<ModelTiming>(std::move(model));
+  }
+
+  // Computes the timing given a model text proto.
+  void ComputeModelTiming(const std::string& model_pbtxt) {
+    BuildModelFromProto(model_pbtxt);
+    model_timing_ = std::make_unique<ModelTiming>(model_->output());
   }
 
   // Gets the timing information of a node given its id.
@@ -1344,7 +1360,11 @@ class ModelTimingTest : public ::testing::Test {
     return model_timing_->GetTiming(node_map_.at(node_id));
   }
 
+  // Gets the node given its id.
+  const Node* GetNode(int64_t node_id) const { return node_map_.at(node_id); }
+
  protected:
+  std::unique_ptr<Model> model_;
   std::unique_ptr<ModelTiming> model_timing_;
   absl::flat_hash_map<int64_t, Node*> node_map_;
 };
@@ -1883,16 +1903,263 @@ TEST_F(ModelTimingTest, ParallelInterleave_Batch_ParallelMap) {
   EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/7)->total_time_nsec);
 }
 
-TEST(ModelTest, ModelMetrics) {
-  CellReader<std::string> cell_reader("/tensorflow/data/model");
-  model::Model model;
-  std::shared_ptr<model::Node> root =
-      model::MakeUnknownNode({0, "unknown0", nullptr});
-  model.AddNode([&root](model::Node::Args args) { return root; }, root->name(),
-                nullptr, &root);
-  std::string model_id = strings::StrCat(reinterpret_cast<uint64>(&model));
-  EXPECT_THAT(cell_reader.Read(model_id),
-              AllOf(HasSubstr("key: 0"), HasSubstr("name: \"unknown0\""),
-                    HasSubstr("autotune: true")));
+TEST_F(ModelTimingTest, OptimizeGreedy_OneStage) {
+  BuildModelFromProto(R"pb(
+    nodes: {
+      key: 1
+      value: {
+        id: 1
+        name: "ParallelMapV2"
+        autotune: true
+        num_elements: 100
+        processing_time: 5000
+        bytes_produced: 10000
+        node_class: ASYNC_KNOWN_RATIO
+        ratio: 1
+        inputs: 2
+        parameters: {
+          name: "parallelism"
+          value: 4
+          min: 1
+          max: 16
+          tunable: true
+        }
+      }
+    }
+    nodes: {
+      key: 2
+      value: {
+        id: 2
+        name: "Map"
+        autotune: true
+        num_elements: 100
+        processing_time: 3000
+        node_class: KNOWN_RATIO
+        ratio: 1
+        inputs: 3
+      }
+    }
+    nodes: {
+      key: 3
+      value: {
+        id: 3
+        name: "SSTable"
+        autotune: true
+        num_elements: 100
+        processing_time: 1000
+        node_class: KNOWN_RATIO
+        ratio: 2
+      }
+    }
+    output: 1
+  )pb");
+
+  CancellationManager cancellation_manager;
+  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, /*cpu_budget=*/5,
+                   /*ram_budget=*/1000, /*model_input_time=*/50,
+                   &cancellation_manager);
+
+  EXPECT_EQ(5, GetNode(/*node_id=*/1)->parameter_value("parallelism"));
+}
"parallelism" + value: 4 + min: 1 + max: 16 + tunable: true + } + } + } + nodes: { + key: 3 + value: { + id: 3 + name: "SSTable" + autotune: true + num_elements: 100 + processing_time: 1000 + node_class: KNOWN_RATIO + ratio: 2 + } + } + output: 1 + )pb"); + + CancellationManager cancellation_manager; + model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 5, 1000, 50, + &cancellation_manager); + + EXPECT_EQ(5, GetNode(/*node_id=*/1)->parameter_value("parallelism")); + EXPECT_EQ(5, GetNode(/*node_id=*/2)->parameter_value("parallelism")); +} + +TEST_F(ModelTimingTest, OptimizeGreedy_TwoStages_RamBudgetExceeded) { + BuildModelFromProto(R"pb( + nodes: { + key: 1 + value: { + id: 1 + name: "ParallelMapV2" + autotune: true + num_elements: 100 + processing_time: 25000 + bytes_produced: 10000 + node_class: ASYNC_KNOWN_RATIO + ratio: 1 + inputs: 2 + parameters: { + name: "parallelism" + value: 4 + min: 1 + max: 16 + tunable: true + } + } + } + nodes: { + key: 2 + value: { + id: 2 + name: "ParallelMapV2" + autotune: true + num_elements: 100 + processing_time: 20000 + bytes_produced: 10000 + node_class: ASYNC_KNOWN_RATIO + ratio: 1 + inputs: 3 + parameters: { + name: "parallelism" + value: 4 + min: 1 + max: 16 + tunable: true + } + } + } + nodes: { + key: 3 + value: { + id: 3 + name: "SSTable" + autotune: true + num_elements: 100 + processing_time: 1000 + node_class: KNOWN_RATIO + ratio: 2 + } + } + output: 1 + )pb"); + + CancellationManager cancellation_manager; + model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 5, 800, 50, + &cancellation_manager); + + EXPECT_EQ(4, GetNode(/*node_id=*/1)->parameter_value("parallelism")); + EXPECT_EQ(4, GetNode(/*node_id=*/2)->parameter_value("parallelism")); +} + +TEST_F(ModelTimingTest, OptimizeGreedy_TwoStages_CpuBudgetExceeded) { + BuildModelFromProto(R"pb( + nodes: { + key: 1 + value: { + id: 1 + name: "ParallelMapV2" + autotune: true + num_elements: 100 + processing_time: 25000 + bytes_produced: 10000 + node_class: ASYNC_KNOWN_RATIO + ratio: 1 + inputs: 2 + parameters: { + name: "parallelism" + value: 4 + min: 1 + max: 16 + tunable: true + } + } + } + nodes: { + key: 2 + value: { + id: 2 + name: "ParallelMapV2" + autotune: true + num_elements: 100 + processing_time: 20000 + bytes_produced: 10000 + node_class: ASYNC_KNOWN_RATIO + ratio: 1 + inputs: 3 + parameters: { + name: "parallelism" + value: 4 + min: 1 + max: 16 + tunable: true + } + } + } + nodes: { + key: 3 + value: { + id: 3 + name: "SSTable" + autotune: true + num_elements: 100 + processing_time: 1000 + node_class: KNOWN_RATIO + ratio: 2 + } + } + output: 1 + )pb"); + + CancellationManager cancellation_manager; + model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 3, 1000, 50, + &cancellation_manager); + + EXPECT_EQ(3, GetNode(/*node_id=*/1)->parameter_value("parallelism")); + EXPECT_EQ(3, GetNode(/*node_id=*/2)->parameter_value("parallelism")); } } // namespace diff --git a/tensorflow/python/data/ops/options.py b/tensorflow/python/data/ops/options.py index 2474f379626..c2fef8b6039 100644 --- a/tensorflow/python/data/ops/options.py +++ b/tensorflow/python/data/ops/options.py @@ -40,11 +40,15 @@ class AutotuneAlgorithm(enum.Enum): MAX_PARALLELISM: Similar to HILL_CLIMB but uses a relaxed stopping condition, allowing the optimization to oversubscribe the CPU. + + STAGE_BASED: In each optimization step, this algorithm chooses the worst + bottleneck parameter and increases its value by 1. 
""" DEFAULT = 0 HILL_CLIMB = 1 GRADIENT_DESCENT = 2 MAX_PARALLELISM = 3 + STAGE_BASED = 4 @classmethod def _to_proto(cls, obj): @@ -56,9 +60,11 @@ class AutotuneAlgorithm(enum.Enum): return model_pb2.AutotuneAlgorithm.GRADIENT_DESCENT if obj == cls.MAX_PARALLELISM: return model_pb2.AutotuneAlgorithm.MAX_PARALLELISM + if obj == cls.STAGE_BASED: + return model_pb2.AutotuneAlgorithm.STAGE_BASED raise ValueError( - f"Invalid `obj.` Supported values include `DEFAULT`, `HILL_CLIMB` and " - f"`GRADIENT_DESCENT`. Got {obj.name}.") + f"Invalid `obj.` Supported values include `DEFAULT`, `HILL_CLIMB` " + f"`GRADIENT_DESCENT`, and `STAGE_BASED`. Got {obj.name}.") @classmethod def _from_proto(cls, pb): @@ -70,8 +76,11 @@ class AutotuneAlgorithm(enum.Enum): return cls.GRADIENT_DESCENT if pb == model_pb2.AutotuneAlgorithm.MAX_PARALLELISM: return cls.MAX_PARALLELISM - raise ValueError(f"Invalid `pb.` Supported values include `DEFAULT`, " - f"`HILL_CLIMB` and `GRADIENT_DESCENT`. Got {pb}.") + if pb == model_pb2.AutotuneAlgorithm.STAGE_BASED: + return cls.STAGE_BASED + raise ValueError( + f"Invalid `pb.` Supported values include `DEFAULT`, `HILL_CLIMB`, " + f"`GRADIENT_DESCENT` and `STAGE_BASED`. Got {pb}.") @tf_export("data.experimental.AutoShardPolicy") diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-autotune-algorithm.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-autotune-algorithm.pbtxt index 11ba9be8370..ae809d37231 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-autotune-algorithm.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-autotune-algorithm.pbtxt @@ -17,4 +17,8 @@ tf_class { name: "MAX_PARALLELISM" mtype: "" } + member { + name: "STAGE_BASED" + mtype: "" + } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-autotune-algorithm.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-autotune-algorithm.pbtxt index 11ba9be8370..ae809d37231 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-autotune-algorithm.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-autotune-algorithm.pbtxt @@ -17,4 +17,8 @@ tf_class { name: "MAX_PARALLELISM" mtype: "" } + member { + name: "STAGE_BASED" + mtype: "" + } }