[tf.data] Implement a new stage-based tf.data Autotune algorithm driven by stage-based timing analysis. This requires some refactoring of the stage-based timing analysis code.

PiperOrigin-RevId: 457085290
Wilsin Gosti 2022-06-24 13:41:05 -07:00 committed by TensorFlower Gardener
parent a197843703
commit cb56bd8ebd
9 changed files with 610 additions and 91 deletions

@ -73,6 +73,10 @@
small increase in memory usage due to buffering. To enable this
behavior, set `inject_prefetch=True` in
`tf.data.experimental.OptimizationOptions`.
* Added a new value to `tf.data.Options.autotune.autotune_algorithm`:
STAGE_BASED. When the autotune algorithm is set to STAGE_BASED, autotuning
runs a new algorithm that can achieve the same performance with lower
CPU and memory usage. See the usage sketch below.
* `tf.distribute`:

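The new option is user-facing, so here is a minimal usage sketch of enabling it from Python. It assumes the `AutotuneAlgorithm` enum is exposed under `tf.data.experimental`, as the `options.py` and API golden changes later in this commit indicate; the dataset itself is an arbitrary placeholder.

```python
import tensorflow as tf

# Sketch only: opt a pipeline into the new stage-based autotuning algorithm.
dataset = tf.data.Dataset.range(1_000).map(
    lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE)

options = tf.data.Options()
options.autotune.enabled = True
options.autotune.autotune_algorithm = (
    tf.data.experimental.AutotuneAlgorithm.STAGE_BASED)
dataset = dataset.with_options(options)
```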
@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/data/root_dataset.h"
#include <algorithm>
#include <functional>
#include <string>
#include <utility>
@ -160,11 +161,22 @@ class RootDataset::Iterator : public DatasetIterator<RootDataset> {
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
{
tf_shared_lock l(mu_);
if (model_ != nullptr && end_time_usec_ > 0) {
model_->RecordIteratorGapTime(ctx->env()->NowMicros() - end_time_usec_);
}
}
if (dataset()->params_.autotune) {
TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
}
return input_impl_->GetNext(IteratorContext(CreateParams(ctx)), out_tensors,
end_of_sequence);
TF_RETURN_IF_ERROR(input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
out_tensors, end_of_sequence));
{
mutex_lock l(mu_);
end_time_usec_ = std::max(ctx->env()->NowMicros(), end_time_usec_);
}
return OkStatus();
}
protected:
@ -261,6 +273,9 @@ class RootDataset::Iterator : public DatasetIterator<RootDataset> {
int64_t threadpool_size_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
// The end time of the previous `GetNextInternal` call.
uint64_t end_time_usec_ TF_GUARDED_BY(mu_) = 0;
// Must be ordered last as its execution may depend on other members.
std::unique_ptr<IteratorBase> input_impl_;
};

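The `root_dataset.cc` change above threads a gap-time measurement through `GetNextInternal`: before delegating to the input iterator it reports how much time has passed since the previous call finished, and after the call it updates the stored end time. A rough Python sketch of that bookkeeping follows; `record_iterator_gap_time` is a hypothetical stand-in for `Model::RecordIteratorGapTime`, and the guards for a null model or disabled autotuning are omitted.

```python
import time

class GapTimeRecorder:
  """Mirrors the end_time_usec_ bookkeeping in RootDataset::Iterator (sketch)."""

  def __init__(self, model):
    self._model = model
    self._end_time_usec = 0  # End time of the previous get_next() call.

  def get_next(self, iterator):
    now_usec = time.monotonic_ns() // 1_000
    if self._end_time_usec > 0:
      # Gap time: completion of the previous call to the start of this one.
      self._model.record_iterator_gap_time(now_usec - self._end_time_usec)
    value = next(iterator)
    self._end_time_usec = max(time.monotonic_ns() // 1_000, self._end_time_usec)
    return value
```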
@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
#include <memory>
#include <queue>
#include "absl/time/clock.h"
#include "tensorflow/core/framework/cancellation.h"
@ -25,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/statusor.h"
namespace tensorflow {
namespace data {
@ -35,6 +37,73 @@ constexpr int64_t Model::kOptimizationPeriodMaxMs;
namespace {
// A priority queue that holds stage roots where the top of the priority queue
// is the node with the largest total time.
class ModelTimingPriorityQueue {
public:
explicit ModelTimingPriorityQueue(ModelTiming& model_timing) {
std::vector<std::shared_ptr<Node>> stage_roots =
model_timing.GetStageRoots();
if (stage_roots.empty()) {
return;
}
for (auto& root : stage_roots) {
DCHECK(model_timing.GetTiming(root.get()) != nullptr);
stage_roots_queue_.emplace(
model_timing.GetTiming(root.get())->total_time_nsec, root.get());
}
}
// Pops the top item from the queue, i.e. the node with the largest total time.
StatusOr<std::pair<double, Node*>> PopSlowestStageRoot() {
if (stage_roots_queue_.empty()) {
return errors::Internal(
"Model timing priority queue is empty during stage-based "
"optimization");
}
std::pair<double, Node*> top_item = stage_roots_queue_.top();
stage_roots_queue_.pop();
return top_item;
}
// Pushes a node together with its total time onto the queue.
void Push(Node* node, double total_time_nsec) {
stage_roots_queue_.emplace(total_time_nsec, node);
}
private:
std::priority_queue<std::pair<double, Node*>> stage_roots_queue_;
};
// A cache that looks up the `parallelism` parameters of nodes the first time
// they are requested and saves them for subsequent requests.
class NodeParallelismParameters {
public:
NodeParallelismParameters() {}
// Returns the `parallelism` parameter given a node.
Parameter* Get(const Node* node) {
if (node_parallelism_.contains(node)) {
// Look for the `parallelism` parameter of this node in the cache.
return node_parallelism_.at(node);
}
// Find the `parallelism` parameter of this node and cache it.
Node::ModelParameters parameters = node->CollectNodeTunableParameters();
Node::ModelParameters::iterator parameter_pair = std::find_if(
parameters.begin(), parameters.end(),
[](const std::pair<std::string, std::shared_ptr<Parameter>>&
parameter) { return parameter.second->name == kParallelism; });
if (parameter_pair == parameters.end()) {
return nullptr;
}
node_parallelism_[node] = parameter_pair->second.get();
return parameter_pair->second.get();
}
private:
absl::flat_hash_map<const Node*, Parameter*> node_parallelism_;
};
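`NodeParallelismParameters` is a plain memoization layer: the scan of a node's tunable parameters for the one named `parallelism` happens once, and the resulting pointer is reused on later queries for the same node. A small sketch of the same pattern, where `collect_tunable_parameters` is a hypothetical callable standing in for `Node::CollectNodeTunableParameters`:

```python
class NodeParallelismCache:
  """Caches the per-node 'parallelism' parameter lookup (sketch)."""

  def __init__(self, collect_tunable_parameters):
    self._collect = collect_tunable_parameters  # node -> {name: parameter}
    self._cache = {}

  def get(self, node):
    if node in self._cache:
      return self._cache[node]  # Served from the cache on repeated queries.
    parameter = self._collect(node).get('parallelism')  # May be None.
    if parameter is not None:
      self._cache[node] = parameter
    return parameter
```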
// Returns true if all parameters have reached their max values.
bool AreAllParametersMax(const Model::ModelParameters& parameters) {
for (const auto& pair : parameters) {
@ -177,6 +246,10 @@ Status ModelToProtoHelper(std::shared_ptr<Node> output, ModelProto* model) {
// Recursively produces node tree rooted in `output` from the given model proto.
Status ModelFromProtoHelper(ModelProto model, std::shared_ptr<Node>* output) {
if (model.nodes().empty()) {
return errors::Internal(
"Cannot restore model from proto because it has no nodes.");
}
TF_RETURN_IF_ERROR(Node::FromProto(model.nodes().at(model.output()),
/*output=*/nullptr, output));
std::list<std::shared_ptr<Node>> to_restore_inputs = {*output};
@ -1386,6 +1459,13 @@ Node::ModelParameters Node::CollectTunableParameters() const {
return CollectTunableParametersLocked();
}
Node::ModelParameters Node::CollectNodeTunableParameters() const {
tf_shared_lock l(mu_);
Node::ModelParameters parameters;
CollectTunableParametersHelper(&parameters);
return parameters;
}
string Node::DebugString() const {
absl::flat_hash_map<string, string> debug_strings;
tf_shared_lock l(mu_);
@ -1943,6 +2023,9 @@ void Model::Optimize(AutotuneAlgorithm algorithm, int64_t cpu_budget,
OptimizeGradientDescent(snapshot, optimization_params,
cancellation_manager);
break;
case AutotuneAlgorithm::STAGE_BASED:
OptimizeStageBased(snapshot, optimization_params, cancellation_manager);
break;
default:
VLOG(2) << "Autotuning algorithm was not recognized. Aborting "
"optimization.";
@ -2022,7 +2105,15 @@ Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, int64_t cpu_budget,
}
int64_t start_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
Optimize(algorithm, cpu_budget, ram_budget, /*model_input_time=*/0,
double model_input_time = 0.0;
// Model input time is set to 0 for all optimization algorithms except the
// stage-based optimization algorithm, for historical reasons. In the
// stage-based algorithm, the model input time is used as the target
// optimization time for all stages in the pipeline.
if (algorithm == AutotuneAlgorithm::STAGE_BASED) {
model_input_time = ComputeTargetTimeNsec();
}
Optimize(algorithm, cpu_budget, ram_budget, model_input_time,
cancellation_manager);
int64_t end_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
VLOG(2) << "Optimized for " << end_ms - start_ms << " ms.";
@ -2167,6 +2258,82 @@ void Model::OptimizeHillClimbHelper(
UpdateStateValues(&parameters);
}
double Model::ComputeTargetTimeNsec() {
tf_shared_lock l(gap_mu_);
if (gap_time_count_ == 0) {
return 0.0;
}
// Convert the mean gap time from microseconds to nanoseconds.
return (static_cast<double>(gap_time_sum_usec_) /
static_cast<double>(gap_time_count_)) *
1.0e3;
}
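`ComputeTargetTimeNsec` turns the recorded gap times into the per-element target used by the stage-based algorithm: the mean gap between consecutive `GetNext()` calls, converted from microseconds to nanoseconds. A one-function sketch of the arithmetic, with made-up numbers:

```python
def compute_target_time_nsec(gap_time_sum_usec, gap_time_count):
  # Mean iterator gap time in microseconds, converted to nanoseconds.
  if gap_time_count == 0:
    return 0.0
  return gap_time_sum_usec / gap_time_count * 1e3

# Example: gaps of 40us, 60us and 50us yield a 50,000ns target per element.
assert compute_target_time_nsec(150, 3) == 50_000.0
```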
void Model::OptimizeStageBased(std::shared_ptr<Node> snapshot,
const OptimizationParams& optimization_params,
CancellationManager* cancellation_manager) {
return OptimizeStageBasedParallelism(
snapshot, optimization_params.model_input_time(), optimization_params,
cancellation_manager);
}
void Model::OptimizeStageBasedParallelism(
std::shared_ptr<Node> snapshot, double target_time_nsec,
const OptimizationParams& optimization_params,
CancellationManager* cancellation_manager) {
VLOG(2) << "Starting optimization of tunable parameters with Stage-Based "
"optimization with a target time of "
<< optimization_params.model_input_time() << " nanoseconds.";
Node::ModelParameters tunable_parameters = CollectTunableParameters(snapshot);
// Initialize the parallelism parameter values to minimal before tuning.
for (std::pair<string, std::shared_ptr<Parameter>>& pair :
tunable_parameters) {
if (pair.second->name != kParallelism) {
continue;
}
pair.second->value = pair.second->min;
}
ModelTiming model_timing(snapshot);
ModelTimingPriorityQueue priority_queue(model_timing);
StatusOr<std::pair<double, Node*>> critical_root_status =
priority_queue.PopSlowestStageRoot();
if (!critical_root_status.ok()) {
return;
}
NodeParallelismParameters node_parallelism;
std::pair<double, Node*> critical_root = critical_root_status.ValueOrDie();
while (critical_root.first > target_time_nsec) {
Parameter* parallelism_parameter =
node_parallelism.Get(critical_root.second);
// Stop optimization if the critical stage has no `parallelism` parameter or
// it has reached the max parallelism value.
if (parallelism_parameter == nullptr ||
parallelism_parameter->value >= optimization_params.cpu_budget()) {
break;
}
parallelism_parameter->value += 1.0;
if (TotalMaximumBufferedBytes(snapshot) >
optimization_params.ram_budget()) {
// Increasing the parallelism by 1 exceeded the RAM budget. Reduce it back
// and stop optimization because we cannot improve the most critical stage.
parallelism_parameter->value -= 1.0;
break;
}
// Compute the new total time and put the node back in the queue after its
// parallelism value has been increased by 1.
model_timing.ComputeNodeTotalTime(*critical_root.second);
priority_queue.Push(
critical_root.second,
model_timing.GetTiming(critical_root.second)->total_time_nsec);
// Get the next critical stage root.
critical_root_status = priority_queue.PopSlowestStageRoot();
if (!critical_root_status.ok()) {
break;
}
critical_root = critical_root_status.ValueOrDie();
}
UpdateStateValues(&tunable_parameters);
}
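The function above is the heart of the new algorithm: reset every `parallelism` parameter to its minimum, then repeatedly pop the slowest stage off the priority queue and raise its parallelism by one, stopping as soon as the slowest stage meets the target time, the critical stage has no `parallelism` parameter or has reached the CPU budget, or one more unit of parallelism would exceed the RAM budget. Below is a self-contained Python sketch of the same greedy loop; it substitutes a toy cost model (total time = work / parallelism, buffered bytes proportional to parallelism) for `ModelTiming`, so the numbers are illustrative only.

```python
import heapq

def optimize_stage_based_parallelism(stages, target_time_nsec, cpu_budget,
                                     ram_budget):
  """Greedy loop sketch mirroring Model::OptimizeStageBasedParallelism.

  Each stage is modeled (purely for illustration) by a fixed per-element
  `work_nsec`, a tunable `parallelism`, and `bytes_per_slot` of buffer per
  unit of parallelism; its total time is work_nsec / parallelism.
  """
  def total_time(stage):
    return stage['work_nsec'] / stage['parallelism']

  def total_buffered_bytes():
    return sum(s['parallelism'] * s['bytes_per_slot'] for s in stages.values())

  # Start from minimal parallelism, as the C++ code resets parameters to `min`.
  for s in stages.values():
    s['parallelism'] = s['min_parallelism']
  # Max-heap of (total time, stage name); keys are negated since heapq is a
  # min-heap.
  queue = [(-total_time(s), name) for name, s in stages.items()]
  heapq.heapify(queue)
  while queue:
    neg_time, name = heapq.heappop(queue)
    if -neg_time <= target_time_nsec:
      break  # The slowest stage already meets the target time.
    stage = stages[name]
    if stage['parallelism'] >= cpu_budget:
      break  # Cannot improve the critical stage within the CPU budget.
    stage['parallelism'] += 1
    if total_buffered_bytes() > ram_budget:
      stage['parallelism'] -= 1  # Roll back: the RAM budget would be exceeded.
      break
    heapq.heappush(queue, (-total_time(stage), name))
  return {name: s['parallelism'] for name, s in stages.items()}

stages = {
    'map_a': {'work_nsec': 25000, 'min_parallelism': 1, 'bytes_per_slot': 100},
    'map_b': {'work_nsec': 20000, 'min_parallelism': 1, 'bytes_per_slot': 100},
}
print(optimize_stage_based_parallelism(stages, target_time_nsec=5000,
                                        cpu_budget=5, ram_budget=1000))
# -> {'map_a': 5, 'map_b': 4}
```

The real implementation instead recomputes the critical stage's total time via `ModelTiming::ComputeNodeTotalTime` before pushing it back onto the queue.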
void Model::OptimizeHillClimb(std::shared_ptr<Node> snapshot,
const OptimizationParams& optimization_params,
CancellationManager* cancellation_manager) {
@ -2309,9 +2476,18 @@ std::string Model::DebugString() {
return cached_debug_string_;
}
Node::NodeVector Model::CollectNodes(
ModelTiming::ModelTiming(std::shared_ptr<Node> root) : root_(root) {
DCHECK(root_.get() != nullptr);
auto bfs_nodes = CollectNodes(root_, TraversalOrder::BFS, IsAnyNode);
auto reverse_bfs_nodes = bfs_nodes;
std::reverse(reverse_bfs_nodes.begin(), reverse_bfs_nodes.end());
ComputePipelineRatios(bfs_nodes);
ComputeTotalTimes(reverse_bfs_nodes);
}
Node::NodeVector ModelTiming::CollectNodes(
std::shared_ptr<Node> root, TraversalOrder order,
bool collect_node(const std::shared_ptr<Node>)) {
bool collect_node(const std::shared_ptr<Node>)) const {
if (root == nullptr) {
return Node::NodeVector({});
}
@ -2327,18 +2503,6 @@ Node::NodeVector Model::CollectNodes(
return nodes;
}
ModelTiming::ModelTiming(std::shared_ptr<Model> model) : model_(model) {
ComputeTiming();
}
void ModelTiming::ComputeTiming() {
auto nodes =
model_->CollectNodes(model_->output(), TraversalOrder::BFS, IsAnyNode);
ComputeTimingComponents(nodes);
std::reverse(nodes.begin(), nodes.end());
ComputeTotalTimes(nodes);
}
const ModelTiming::NodeTiming* ModelTiming::GetTiming(const Node* node) const {
if (timing_nodes_.find(node) == timing_nodes_.end()) {
return nullptr;
@ -2346,10 +2510,9 @@ const ModelTiming::NodeTiming* ModelTiming::GetTiming(const Node* node) const {
return &(timing_nodes_.at(node));
}
void ModelTiming::ComputeTimingComponents(const Node::NodeVector& bfs_nodes) {
void ModelTiming::ComputePipelineRatios(const Node::NodeVector& bfs_nodes) {
for (const auto& node : bfs_nodes) {
auto& node_timing = timing_nodes_[node.get()];
node_timing.self_time_nsec = node->ComputeSelfTime();
if (!node->autotune()) {
// These are inactive nodes marked by parallel interleave transformations.
node_timing.pipeline_ratio = 0.0;
@ -2370,31 +2533,33 @@ void ModelTiming::ComputeTimingComponents(const Node::NodeVector& bfs_nodes) {
}
}
void ModelTiming::ComputeNodeTotalTime(std::shared_ptr<Node> node) {
DCHECK(timing_nodes_.contains(node.get()));
auto& node_timing = timing_nodes_[node.get()];
void ModelTiming::ComputeNonAsyncInterleaveManyTotalTime(const Node& node) {
DCHECK(timing_nodes_.contains(&node));
auto& node_timing = timing_nodes_[&node];
double input_total_time_nsec = 0.0;
for (auto input : node->inputs()) {
for (auto input : node.inputs()) {
if (input->IsAsync()) {
continue;
}
if (!input->autotune() || input->num_elements() <= 0) {
continue;
}
DCHECK(timing_nodes_.contains(input.get()));
DCHECK(timing_nodes_.contains(input.get()))
<< "Input " << input->long_name() << " of node " << node.long_name()
<< " has no timing node.";
input_total_time_nsec += timing_nodes_[input.get()].total_time_nsec;
}
node_timing.total_time_nsec =
node_timing.self_time_nsec + input_total_time_nsec * node->Ratio();
node_timing.self_time_nsec + input_total_time_nsec * node.Ratio();
}
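`ComputeNonAsyncInterleaveManyTotalTime` implements the recurrence behind the reverse-BFS pass in the `ModelTiming` constructor: a node's total time is its self time plus the summed total times of its active, non-async inputs, scaled by the node's `Ratio()` (roughly, input elements consumed per output element). Reverse BFS guarantees each input's total time is already available. A tiny worked sketch of the recurrence with made-up numbers:

```python
def node_total_time(self_time_nsec, input_total_times_nsec, ratio):
  # total = self time + (sum of active, non-async input total times) * ratio
  return self_time_nsec + sum(input_total_times_nsec) * ratio

# Example: self time 10 ns, inputs already computed at 30 ns and 20 ns, and a
# ratio of 2 input elements per output element -> 10 + (30 + 20) * 2 = 110 ns.
assert node_total_time(10, [30, 20], 2) == 110
```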
void ModelTiming::ComputeAsyncInterleaveManyTotalTime(
std::shared_ptr<Node> node) {
DCHECK(timing_nodes_.contains(node));
auto& node_timing = timing_nodes_[node.get()];
void ModelTiming::ComputeAsyncInterleaveManyTotalTime(const Node& node) {
DCHECK(timing_nodes_.contains(&node));
auto& node_timing = timing_nodes_[&node];
double max_input_total_time_nsec = 0.0;
double sum_input_throughput = 0.0;
auto inputs = node->inputs();
auto inputs = node.inputs();
// `ParallelInterleave` is often used to interleave processing of datasets
// generated from the first input, e.g. reading from IO where the first input
// has the list of all filenames. The first input is typically not the
@ -2409,7 +2574,9 @@ void ModelTiming::ComputeAsyncInterleaveManyTotalTime(
if (!(*input)->autotune() || (*input)->num_elements() <= 0) {
continue;
}
DCHECK(timing_nodes_.contains((*input).get()));
DCHECK(timing_nodes_.contains((*input).get()))
<< "Input " << (*input)->long_name() << " of node " << node.long_name()
<< " has no timing node.";
auto input_total_time_nsec = timing_nodes_[(*input).get()].total_time_nsec;
max_input_total_time_nsec =
std::max(input_total_time_nsec, max_input_total_time_nsec);
@ -2418,14 +2585,14 @@ void ModelTiming::ComputeAsyncInterleaveManyTotalTime(
}
}
double input_total_time_nsec = 0.0;
auto deterministic = node->ParameterValue(kDeterministic);
auto deterministic = node.ParameterValue(kDeterministic);
// After cl/445005635, there should always be a `deterministic` parameter for
// an ASYNC_INTERLEAVE_MANY node. The "not-ok" check is to allow the code to
// work with protos saved and restored before that CL.
if (!deterministic.ok() || deterministic.ValueOrDie() == 1.0) {
// If deterministic = true, then the total time is `1/worst input
// throughput * cycle_length`, or `max input total time / cycle_length`.
input_total_time_nsec = max_input_total_time_nsec * node->Ratio();
input_total_time_nsec = max_input_total_time_nsec * node.Ratio();
} else if (sum_input_throughput > 0.0) {
// If deterministic = false, then the total time is
// `1/sum_input_throughput`.
@ -2437,26 +2604,31 @@ void ModelTiming::ComputeAsyncInterleaveManyTotalTime(
void ModelTiming::ComputeTotalTimes(const Node::NodeVector& reverse_bfs_nodes) {
for (const auto& node : reverse_bfs_nodes) {
if (!node->autotune() || node->num_elements() <= 0) {
continue;
ComputeNodeTotalTime(*(node.get()));
}
}
void ModelTiming::ComputeNodeTotalTime(const Node& node) {
NodeTiming& node_timing = timing_nodes_[&node];
node_timing.self_time_nsec = node.ComputeSelfTime();
if (!node.autotune() || node.num_elements() <= 0) {
return;
}
#if !defined(IS_MOBILE_PLATFORM)
// This block of code is compiled only for non-mobile platforms because the
// mobile platform lacks RTTI, i.e. support for `dynamic_cast`.
if (dynamic_cast<const AsyncInterleaveMany*>(node.get()) != nullptr) {
if (dynamic_cast<const AsyncInterleaveMany*>(&node) != nullptr) {
ComputeAsyncInterleaveManyTotalTime(node);
} else {
ComputeNodeTotalTime(node);
ComputeNonAsyncInterleaveManyTotalTime(node);
}
#else // !IS_MOBILE_PLATFORM
ComputeNodeTotalTime(node);
ComputeNonAsyncInterleaveManyTotalTime(node);
#endif // !IS_MOBILE_PLATFORM
}
}
std::vector<std::shared_ptr<Node>> ModelTiming::GetStageRoots() const {
auto bfs_nodes =
model_->CollectNodes(model_->output(), TraversalOrder::BFS, IsAnyNode);
auto bfs_nodes = CollectNodes(root_, TraversalOrder::BFS, IsAnyNode);
std::vector<std::shared_ptr<Node>> roots;
if (!bfs_nodes.empty() && !bfs_nodes[0]->IsAsync()) {
roots.push_back(bfs_nodes[0]);
@ -2470,8 +2642,8 @@ std::vector<std::shared_ptr<Node>> ModelTiming::GetStageRoots() const {
}
std::vector<std::shared_ptr<Node>> ModelTiming::GetStageNodes(
std::shared_ptr<Node> root) const {
return model_->CollectNodes(root, TraversalOrder::BFS, IsSyncNode);
std::shared_ptr<Node> stage_root) const {
return CollectNodes(stage_root, TraversalOrder::BFS, IsSyncNode);
}
} // namespace model

@ -375,6 +375,9 @@ class Node {
// Collects tunable parameters in the subtree rooted in this node.
ModelParameters CollectTunableParameters() const TF_LOCKS_EXCLUDED(mu_);
// Collects tunable parameters in this node.
ModelParameters CollectNodeTunableParameters() const TF_LOCKS_EXCLUDED(mu_);
// Returns a human-readable representation of this node.
string DebugString() const TF_LOCKS_EXCLUDED(mu_);
@ -689,7 +692,7 @@ class Model {
~Model();
// Returns a pointer to the model's output node.
const std::shared_ptr<Node> output() {
const std::shared_ptr<Node> output() const {
mutex_lock l(mu_);
return output_;
}
@ -745,9 +748,12 @@ class Model {
static Status Load(const string& fname, std::unique_ptr<Model>* model,
OptimizationParams* optimization_params);
Node::NodeVector CollectNodes(std::shared_ptr<Node> root,
TraversalOrder order,
bool collect_node(const std::shared_ptr<Node>));
// Records the gap time between consecutive `GetNext()` calls.
void RecordIteratorGapTime(uint64_t duration_usec) {
mutex_lock l(gap_mu_);
gap_time_sum_usec_ += duration_usec;
++gap_time_count_;
}
private:
// Determines whether optimization should stop given total processing time,
@ -801,6 +807,27 @@ class Model {
const OptimizationParams& optimization_params,
CancellationManager* cancellation_manager);
// This optimization starts by setting all tunable parallelism parameters to
// their minimum values. It then repeatedly increases the parallelism
// parameter of the longest stage by 1 until either the longest stage is
// faster than the target time or the memory or CPU budget is fully utilized.
// TODO(b/226910071): The second part of this algorithm optimizes the buffer
// sizes of parallel ops.
void OptimizeStageBased(std::shared_ptr<Node> snapshot,
const OptimizationParams& optimization_params,
CancellationManager* cancellation_manager);
// Computes the target time in nanoseconds to use for the `STAGE_BASED`
// autotune algorithm.
double ComputeTargetTimeNsec();
// This is the first part of the stage-based optimization that optimizes
// tunable parallelism parameters.
void OptimizeStageBasedParallelism(
std::shared_ptr<Node> snapshot, double target_time_nsec,
const OptimizationParams& optimization_params,
CancellationManager* cancellation_manager);
// Determines if we should stop the gradient descent optimization iterations
// based on number of increasable parameters, CPU budget, RAM budget and
// current resource usage.
@ -826,7 +853,7 @@ class Model {
// Used for coordination between different input pipeline threads. Exclusive
// access is required only when adding or removing nodes. Concurrent access to
// existing nodes is protected by a node mutex.
mutex mu_;
mutable mutex mu_;
// Used for coordinating the optimization loop and model modifications.
condition_variable optimize_cond_var_;
int64_t id_counter_ TF_GUARDED_BY(mu_) = 1;
@ -845,6 +872,13 @@ class Model {
// Cached result of the `DebugString()` invocation used to implement rate
limiting of the computation.
std::string cached_debug_string_ = "";
// Used to coordinate gap time updates between different threads. Gap time is
// the time between the completion of the previous `GetNext()` and the start
// of the next `GetNext()`.
mutable mutex gap_mu_;
// Sum and count of gap times between consecutive `GetNext()` calls for a model.
uint64_t gap_time_sum_usec_ TF_GUARDED_BY(gap_mu_) = 0;
uint64_t gap_time_count_ TF_GUARDED_BY(gap_mu_) = 0;
};
// Class to compute timing information for a model.
@ -863,7 +897,7 @@ class ModelTiming {
double total_time_nsec = 0.0;
};
explicit ModelTiming(std::shared_ptr<Model> model);
explicit ModelTiming(std::shared_ptr<Node> root);
// Returns the timing data for `node`.
const NodeTiming* GetTiming(const Node* node) const;
@ -873,28 +907,37 @@ class ModelTiming {
// Returns all the nodes of a stage given the stage root.
std::vector<std::shared_ptr<Node>> GetStageNodes(
std::shared_ptr<Node> root) const;
std::shared_ptr<Node> stage_root) const;
// Computes the total time for a node.
void ComputeNodeTotalTime(const Node& node);
private:
// Computes timing information for the whole model.
void ComputeTiming();
// Computes the pipeline ratio, self time for all nodes. The `bfs_nodes` are
// assumed to be a vector of model nodes in BFS manner.
void ComputeTimingComponents(const Node::NodeVector& bfs_nodes);
// Computes the pipeline ratios of all nodes.
void ComputePipelineRatios(const Node::NodeVector& bfs_nodes);
// Computes the total time for all nodes. The `reverse_bfs_nodes` are assumed
// to be a vector of model nodes in reverse BFS order.
void ComputeTotalTimes(const Node::NodeVector& reverse_bfs_nodes);
// Computes the total time for a node except when the node is an async
// interleave node.
void ComputeNodeTotalTime(std::shared_ptr<Node> node);
// Computes the total time of a node that is not an async interleave node.
void ComputeNonAsyncInterleaveManyTotalTime(const Node& node);
// Computes the total time of an async interleave node.
void ComputeAsyncInterleaveManyTotalTime(std::shared_ptr<Node> node);
void ComputeAsyncInterleaveManyTotalTime(const Node& node);
std::shared_ptr<Model> model_;
// Returns a vector of all nodes in the model. The nodes are either in
// breadth-first search or reverse breadth-first search order depending on the
// `order` argument. The nodes are collected based on the results of the
// `collect_node` predicate: if the predicate returns `false` for a given
// node, then the subtree rooted in this node is excluded. The root node
// itself is not collected.
Node::NodeVector CollectNodes(
std::shared_ptr<Node> root, TraversalOrder order,
bool collect_node(const std::shared_ptr<Node>)) const;
// Stores a pointer to the root of a model.
std::shared_ptr<Node> root_;
// Holds a mapping from node to its timing node.
absl::flat_hash_map<const Node*, NodeTiming> timing_nodes_;

@ -22,6 +22,7 @@ enum AutotuneAlgorithm {
HILL_CLIMB = 1;
GRADIENT_DESCENT = 2;
MAX_PARALLELISM = 3;
STAGE_BASED = 4;
}
// Protocol buffer representing the data used by the autotuning modeling

@ -1321,22 +1321,38 @@ TEST(RecordTimeTest, RecordTimeTest) {
EXPECT_FALSE(source->is_recording());
}
TEST(ModelTest, ModelMetrics) {
CellReader<std::string> cell_reader("/tensorflow/data/model");
model::Model model;
std::shared_ptr<Node> root = model::MakeUnknownNode({0, "unknown0", nullptr});
model.AddNode([&root](model::Node::Args args) { return root; }, root->name(),
nullptr, &root);
std::string model_id = strings::StrCat(reinterpret_cast<uint64>(&model));
EXPECT_THAT(cell_reader.Read(model_id),
AllOf(HasSubstr("key: 0"), HasSubstr("name: \"unknown0\""),
HasSubstr("autotune: true")));
}
class ModelTimingTest : public ::testing::Test {
public:
// Computes the timing given a Model text proto.
void ComputeModelTiming(const std::string& model_pbtxt) {
// Builds a Model from its text proto.
void BuildModelFromProto(const std::string& model_pbtxt) {
ModelProto model_proto;
protobuf::TextFormat::ParseFromString(model_pbtxt, &model_proto);
std::unique_ptr<Model> model;
TF_CHECK_OK(Model::FromProto(model_proto, &model));
auto nodes =
model->CollectNodes(model->output(), TraversalOrder::BFS,
[](const std::shared_ptr<Node>) { return true; });
TF_CHECK_OK(Model::FromProto(model_proto, &model_));
auto nodes = model_->output()->CollectNodes(
TraversalOrder::BFS, [](const std::shared_ptr<Node>) { return true; });
node_map_.clear();
node_map_[model_->output()->id()] = model_->output().get();
for (const auto& node : nodes) {
node_map_[node->id()] = node.get();
}
model_timing_ = absl::make_unique<ModelTiming>(std::move(model));
}
// Computes the timing given a Model text proto.
void ComputeModelTiming(const std::string& model_pbtxt) {
BuildModelFromProto(model_pbtxt);
model_timing_ = std::make_unique<ModelTiming>(model_->output());
}
// Gets the timing information of a node given its id.
@ -1344,7 +1360,11 @@ class ModelTimingTest : public ::testing::Test {
return model_timing_->GetTiming(node_map_.at(node_id));
}
// Gets the node given its id.
const Node* GetNode(int64_t node_id) const { return node_map_.at(node_id); }
protected:
std::unique_ptr<Model> model_;
std::unique_ptr<ModelTiming> model_timing_;
absl::flat_hash_map<int64_t, const Node*> node_map_;
};
@ -1883,16 +1903,263 @@ TEST_F(ModelTimingTest, ParallelInterleave_Batch_ParallelMap) {
EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/7)->total_time_nsec);
}
TEST(ModelTest, ModelMetrics) {
CellReader<std::string> cell_reader("/tensorflow/data/model");
model::Model model;
std::shared_ptr<Node> root = model::MakeUnknownNode({0, "unknown0", nullptr});
model.AddNode([&root](model::Node::Args args) { return root; }, root->name(),
nullptr, &root);
std::string model_id = strings::StrCat(reinterpret_cast<uint64>(&model));
EXPECT_THAT(cell_reader.Read(model_id),
AllOf(HasSubstr("key: 0"), HasSubstr("name: \"unknown0\""),
HasSubstr("autotune: true")));
TEST_F(ModelTimingTest, OptimizeGreedy_OneStage) {
BuildModelFromProto(R"pb(
nodes: {
key: 1
value: {
id: 1
name: "ParallelMapV2"
autotune: true
num_elements: 100
processing_time: 5000
bytes_produced: 10000
node_class: ASYNC_KNOWN_RATIO
ratio: 1
inputs: 2
parameters: {
name: "parallelism"
value: 4
min: 1
max: 16
tunable: true
}
}
}
nodes: {
key: 2
value: {
id: 2
name: "Map"
autotune: true
num_elements: 100
processing_time: 3000
node_class: KNOWN_RATIO
ratio: 1
inputs: 3
}
}
nodes: {
key: 3
value: {
id: 3
name: "SSTable"
autotune: true
num_elements: 100
processing_time: 1000
node_class: KNOWN_RATIO
ratio: 2
}
}
output: 1
)pb");
CancellationManager cancellation_manager;
model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 5, 1000, 50,
&cancellation_manager);
EXPECT_EQ(5, GetNode(/*node_id=*/1)->parameter_value("parallelism"));
}
TEST_F(ModelTimingTest, OptimizeGreedy_TwoStages) {
BuildModelFromProto(R"pb(
nodes: {
key: 1
value: {
id: 1
name: "ParallelMapV2"
autotune: true
num_elements: 100
processing_time: 25000
bytes_produced: 10000
node_class: ASYNC_KNOWN_RATIO
ratio: 1
inputs: 2
parameters: {
name: "parallelism"
value: 4
min: 1
max: 16
tunable: true
}
}
}
nodes: {
key: 2
value: {
id: 2
name: "ParallelMapV2"
autotune: true
num_elements: 100
processing_time: 20000
bytes_produced: 10000
node_class: ASYNC_KNOWN_RATIO
ratio: 1
inputs: 3
parameters: {
name: "parallelism"
value: 4
min: 1
max: 16
tunable: true
}
}
}
nodes: {
key: 3
value: {
id: 3
name: "SSTable"
autotune: true
num_elements: 100
processing_time: 1000
node_class: KNOWN_RATIO
ratio: 2
}
}
output: 1
)pb");
CancellationManager cancellation_manager;
model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 5, 1000, 50,
&cancellation_manager);
EXPECT_EQ(5, GetNode(/*node_id=*/1)->parameter_value("parallelism"));
EXPECT_EQ(5, GetNode(/*node_id=*/2)->parameter_value("parallelism"));
}
TEST_F(ModelTimingTest, OptimizeGreedy_TwoStages_RamBudgetExceeded) {
BuildModelFromProto(R"pb(
nodes: {
key: 1
value: {
id: 1
name: "ParallelMapV2"
autotune: true
num_elements: 100
processing_time: 25000
bytes_produced: 10000
node_class: ASYNC_KNOWN_RATIO
ratio: 1
inputs: 2
parameters: {
name: "parallelism"
value: 4
min: 1
max: 16
tunable: true
}
}
}
nodes: {
key: 2
value: {
id: 2
name: "ParallelMapV2"
autotune: true
num_elements: 100
processing_time: 20000
bytes_produced: 10000
node_class: ASYNC_KNOWN_RATIO
ratio: 1
inputs: 3
parameters: {
name: "parallelism"
value: 4
min: 1
max: 16
tunable: true
}
}
}
nodes: {
key: 3
value: {
id: 3
name: "SSTable"
autotune: true
num_elements: 100
processing_time: 1000
node_class: KNOWN_RATIO
ratio: 2
}
}
output: 1
)pb");
CancellationManager cancellation_manager;
model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 5, 800, 50,
&cancellation_manager);
EXPECT_EQ(4, GetNode(/*node_id=*/1)->parameter_value("parallelism"));
EXPECT_EQ(4, GetNode(/*node_id=*/2)->parameter_value("parallelism"));
}
TEST_F(ModelTimingTest, OptimizeGreedy_TwoStages_CpuBudgetExceeded) {
BuildModelFromProto(R"pb(
nodes: {
key: 1
value: {
id: 1
name: "ParallelMapV2"
autotune: true
num_elements: 100
processing_time: 25000
bytes_produced: 10000
node_class: ASYNC_KNOWN_RATIO
ratio: 1
inputs: 2
parameters: {
name: "parallelism"
value: 4
min: 1
max: 16
tunable: true
}
}
}
nodes: {
key: 2
value: {
id: 2
name: "ParallelMapV2"
autotune: true
num_elements: 100
processing_time: 20000
bytes_produced: 10000
node_class: ASYNC_KNOWN_RATIO
ratio: 1
inputs: 3
parameters: {
name: "parallelism"
value: 4
min: 1
max: 16
tunable: true
}
}
}
nodes: {
key: 3
value: {
id: 3
name: "SSTable"
autotune: true
num_elements: 100
processing_time: 1000
node_class: KNOWN_RATIO
ratio: 2
}
}
output: 1
)pb");
CancellationManager cancellation_manager;
model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 3, 1000, 50,
&cancellation_manager);
EXPECT_EQ(3, GetNode(/*node_id=*/1)->parameter_value("parallelism"));
EXPECT_EQ(3, GetNode(/*node_id=*/2)->parameter_value("parallelism"));
}
} // namespace

@ -40,11 +40,15 @@ class AutotuneAlgorithm(enum.Enum):
MAX_PARALLELISM: Similar to HILL_CLIMB but uses a relaxed stopping condition,
allowing the optimization to oversubscribe the CPU.
STAGE_BASED: In each optimization step, this algorithm chooses the worst
bottleneck parameter and increases its value by 1.
"""
DEFAULT = 0
HILL_CLIMB = 1
GRADIENT_DESCENT = 2
MAX_PARALLELISM = 3
STAGE_BASED = 4
@classmethod
def _to_proto(cls, obj):
@ -56,9 +60,11 @@ class AutotuneAlgorithm(enum.Enum):
return model_pb2.AutotuneAlgorithm.GRADIENT_DESCENT
if obj == cls.MAX_PARALLELISM:
return model_pb2.AutotuneAlgorithm.MAX_PARALLELISM
if obj == cls.STAGE_BASED:
return model_pb2.AutotuneAlgorithm.STAGE_BASED
raise ValueError(
f"Invalid `obj.` Supported values include `DEFAULT`, `HILL_CLIMB` and "
f"`GRADIENT_DESCENT`. Got {obj.name}.")
f"Invalid `obj.` Supported values include `DEFAULT`, `HILL_CLIMB` "
f"`GRADIENT_DESCENT`, and `STAGE_BASED`. Got {obj.name}.")
@classmethod
def _from_proto(cls, pb):
@ -70,8 +76,11 @@ class AutotuneAlgorithm(enum.Enum):
return cls.GRADIENT_DESCENT
if pb == model_pb2.AutotuneAlgorithm.MAX_PARALLELISM:
return cls.MAX_PARALLELISM
raise ValueError(f"Invalid `pb.` Supported values include `DEFAULT`, "
f"`HILL_CLIMB` and `GRADIENT_DESCENT`. Got {pb}.")
if pb == model_pb2.AutotuneAlgorithm.STAGE_BASED:
return cls.STAGE_BASED
raise ValueError(
f"Invalid `pb.` Supported values include `DEFAULT`, `HILL_CLIMB`, "
f"`GRADIENT_DESCENT` and `STAGE_BASED`. Got {pb}.")
@tf_export("data.experimental.AutoShardPolicy")

@ -17,4 +17,8 @@ tf_class {
name: "MAX_PARALLELISM"
mtype: "<enum \'AutotuneAlgorithm\'>"
}
member {
name: "STAGE_BASED"
mtype: "<enum \'AutotuneAlgorithm\'>"
}
}

@ -17,4 +17,8 @@ tf_class {
name: "MAX_PARALLELISM"
mtype: "<enum \'AutotuneAlgorithm\'>"
}
member {
name: "STAGE_BASED"
mtype: "<enum \'AutotuneAlgorithm\'>"
}
}