[tf.data] Implement a new stage-based tf.data autotune algorithm. This requires some refactoring of the stage-based timing analysis code.
PiperOrigin-RevId: 457085290
This commit is contained in:
parent a197843703
commit cb56bd8ebd
RELEASE.md
@@ -73,6 +73,10 @@
       small increase in memory usage due to buffering. To enable this
       behavior, set `inject_prefetch=True` in
       `tf.data.experimental.OptimizationOptions`.
+    * Added a new value to `tf.data.Options.autotune.autotune_algorithm`:
+      STAGE_BASED. If the autotune algorithm is set to STAGE_BASED, then it
+      runs a new algorithm that can get the same performance with lower
+      CPU/memory usage.
 
 *   `tf.distribute`:
 
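For context, opting in to the new algorithm from user code is a one-line option change. A minimal Python sketch (the dataset here is illustrative; `AutotuneAlgorithm.STAGE_BASED` is the value added by this change):

import tensorflow as tf

options = tf.data.Options()
options.autotune.autotune_algorithm = (
    tf.data.experimental.AutotuneAlgorithm.STAGE_BASED)
dataset = tf.data.Dataset.range(1000).map(lambda x: x + 1)
dataset = dataset.with_options(options)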
tensorflow/core/data/root_dataset.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/data/root_dataset.h"
 
+#include <algorithm>
 #include <functional>
 #include <string>
 #include <utility>
@@ -160,11 +161,22 @@ class RootDataset::Iterator : public DatasetIterator<RootDataset> {
 
   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                          bool* end_of_sequence) override {
+    {
+      tf_shared_lock l(mu_);
+      if (model_ != nullptr && end_time_usec_ > 0) {
+        model_->RecordIteratorGapTime(ctx->env()->NowMicros() - end_time_usec_);
+      }
+    }
     if (dataset()->params_.autotune) {
       TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
     }
-    return input_impl_->GetNext(IteratorContext(CreateParams(ctx)), out_tensors,
-                                end_of_sequence);
+    TF_RETURN_IF_ERROR(input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
+                                            out_tensors, end_of_sequence));
+    {
+      mutex_lock l(mu_);
+      end_time_usec_ = std::max(ctx->env()->NowMicros(), end_time_usec_);
+    }
+    return OkStatus();
   }
 
  protected:
@@ -261,6 +273,9 @@ class RootDataset::Iterator : public DatasetIterator<RootDataset> {
   int64_t threadpool_size_;
   std::unique_ptr<thread::ThreadPool> thread_pool_;
 
+  // The end time of the previous `GetNextInternal` call.
+  uint64_t end_time_usec_ TF_GUARDED_BY(mu_) = 0;
+
   // Must be ordered last as its execution may depend on other members.
   std::unique_ptr<IteratorBase> input_impl_;
 };
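The change above records, under the iterator's lock, the gap between the end of one `GetNextInternal` call and the start of the next, and reports it to the model via `RecordIteratorGapTime`. A minimal Python sketch of the same bookkeeping (the class and its names are illustrative, not TensorFlow APIs):

import time

class GapTimeTracker:
    def __init__(self):
        self.end_time = 0.0   # end of the previous get_next() call
        self.gap_sum = 0.0
        self.gap_count = 0

    def get_next(self, produce):
        now = time.monotonic()
        if self.end_time > 0:
            # Gap = time between the previous call finishing and this one
            # starting, i.e. time the consumer spent outside the pipeline.
            self.gap_sum += now - self.end_time
            self.gap_count += 1
        value = produce()
        self.end_time = max(time.monotonic(), self.end_time)
        return value

    def average_gap(self):
        return self.gap_sum / self.gap_count if self.gap_count else 0.0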
tensorflow/core/framework/model.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <algorithm>
 #include <memory>
+#include <queue>
 
 #include "absl/time/clock.h"
 #include "tensorflow/core/framework/cancellation.h"
@@ -25,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/host_info.h"
 #include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/statusor.h"
 
 namespace tensorflow {
 namespace data {
@@ -35,6 +37,73 @@ constexpr int64_t Model::kOptimizationPeriodMaxMs;
 
 namespace {
 
+// A priority queue that holds stage roots where the top of the priority queue
+// is the node with the largest total time.
+class ModelTimingPriorityQueue {
+ public:
+  explicit ModelTimingPriorityQueue(ModelTiming& model_timing) {
+    std::vector<std::shared_ptr<Node>> stage_roots =
+        model_timing.GetStageRoots();
+    if (stage_roots.empty()) {
+      return;
+    }
+    for (auto& root : stage_roots) {
+      DCHECK(model_timing.GetTiming(root.get()) != nullptr);
+      stage_roots_queue_.emplace(
+          model_timing.GetTiming(root.get())->total_time_nsec, root.get());
+    }
+  }
+
+  // Pops the top item from the queue, i.e. the node with the largest total
+  // time.
+  StatusOr<std::pair<double, Node*>> PopSlowestStageRoot() {
+    if (stage_roots_queue_.empty()) {
+      return errors::Internal(
+          "Model timing priority queue is empty during stage-based "
+          "optimization");
+    }
+    std::pair<double, Node*> top_item = stage_roots_queue_.top();
+    stage_roots_queue_.pop();
+    return top_item;
+  }
+
+  // Pushes a node together with its total time onto the queue.
+  void Push(Node* node, double total_time_nsec) {
+    stage_roots_queue_.emplace(total_time_nsec, node);
+  }
+
+ private:
+  std::priority_queue<std::pair<double, Node*>> stage_roots_queue_;
+};
+
+// A cache that looks up the `parallelism` parameters of nodes the first time
+// they are requested and saves them for subsequent requests.
+class NodeParallelismParameters {
+ public:
+  NodeParallelismParameters() {}
+
+  // Returns the `parallelism` parameter of the given node.
+  Parameter* Get(const Node* node) {
+    if (node_parallelism_.contains(node)) {
+      // Look for the `parallelism` parameter of this node in the cache.
+      return node_parallelism_.at(node);
+    }
+    // Find the `parallelism` parameter of this node and cache it.
+    Node::ModelParameters parameters = node->CollectNodeTunableParameters();
+    Node::ModelParameters::iterator parameter_pair = std::find_if(
+        parameters.begin(), parameters.end(),
+        [](const std::pair<std::string, std::shared_ptr<Parameter>>&
+               parameter) { return parameter.second->name == kParallelism; });
+    if (parameter_pair == parameters.end()) {
+      return nullptr;
+    }
+    node_parallelism_[node] = parameter_pair->second.get();
+    return parameter_pair->second.get();
+  }
+
+ private:
+  absl::flat_hash_map<const Node*, Parameter*> node_parallelism_;
+};
+
 // Returns true if all parameters have reached their max values.
 bool AreAllParametersMax(const Model::ModelParameters& parameters) {
   for (const auto& pair : parameters) {
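`std::priority_queue` over `std::pair<double, Node*>` compares pairs lexicographically, so the queue above is a max-heap keyed by total time and the slowest stage root is always on top. A rough Python equivalent of its behavior (`heapq` is a min-heap, hence the negated keys; all names are illustrative):

import heapq

class StageTimingQueue:
    def __init__(self, stage_times):
        # stage_times: iterable of (total_time_nsec, node) pairs.
        self._heap = [(-t, node) for t, node in stage_times]
        heapq.heapify(self._heap)

    def pop_slowest_stage_root(self):
        if not self._heap:
            raise RuntimeError("queue is empty during stage-based optimization")
        neg_t, node = heapq.heappop(self._heap)
        return -neg_t, node

    def push(self, node, total_time_nsec):
        heapq.heappush(self._heap, (-total_time_nsec, node))

q = StageTimingQueue([(120.0, "map"), (300.0, "interleave"), (80.0, "batch")])
assert q.pop_slowest_stage_root() == (300.0, "interleave")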
@@ -177,6 +246,10 @@ Status ModelToProtoHelper(std::shared_ptr<Node> output, ModelProto* model) {
 
 // Recursively produces the node tree rooted in `output` from the given model proto.
 Status ModelFromProtoHelper(ModelProto model, std::shared_ptr<Node>* output) {
+  if (model.nodes().empty()) {
+    return errors::Internal(
+        "Cannot restore model from proto because it has no nodes.");
+  }
   TF_RETURN_IF_ERROR(Node::FromProto(model.nodes().at(model.output()),
                                      /*output=*/nullptr, output));
   std::list<std::shared_ptr<Node>> to_restore_inputs = {*output};
@@ -1386,6 +1459,13 @@ Node::ModelParameters Node::CollectTunableParameters() const {
   return CollectTunableParametersLocked();
 }
 
+Node::ModelParameters Node::CollectNodeTunableParameters() const {
+  tf_shared_lock l(mu_);
+  Node::ModelParameters parameters;
+  CollectTunableParametersHelper(&parameters);
+  return parameters;
+}
+
 string Node::DebugString() const {
   absl::flat_hash_map<string, string> debug_strings;
   tf_shared_lock l(mu_);
@@ -1943,6 +2023,9 @@ void Model::Optimize(AutotuneAlgorithm algorithm, int64_t cpu_budget,
       OptimizeGradientDescent(snapshot, optimization_params,
                               cancellation_manager);
       break;
+    case AutotuneAlgorithm::STAGE_BASED:
+      OptimizeStageBased(snapshot, optimization_params, cancellation_manager);
+      break;
     default:
       VLOG(2) << "Autotuning algorithm was not recognized. Aborting "
                  "optimization.";
@@ -2022,7 +2105,15 @@ Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, int64_t cpu_budget,
   }
 
   int64_t start_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
-  Optimize(algorithm, cpu_budget, ram_budget, /*model_input_time=*/0,
+  double model_input_time = 0.0;
+  // Model input time is set to 0 for all optimization algorithms except the
+  // stage-based optimization algorithm, for historical reasons. In the
+  // stage-based optimization algorithm, the model input time is used as the
+  // target optimization time of all stages in the pipeline.
+  if (algorithm == AutotuneAlgorithm::STAGE_BASED) {
+    model_input_time = ComputeTargetTimeNsec();
+  }
+  Optimize(algorithm, cpu_budget, ram_budget, model_input_time,
            cancellation_manager);
   int64_t end_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
   VLOG(2) << "Optimized for " << end_ms - start_ms << " ms.";
@@ -2167,6 +2258,82 @@ void Model::OptimizeHillClimbHelper(
   UpdateStateValues(&parameters);
 }
 
+double Model::ComputeTargetTimeNsec() {
+  tf_shared_lock l(gap_mu_);
+  if (gap_time_count_ == 0) {
+    return 0.0;
+  }
+  return (static_cast<double>(gap_time_sum_usec_) /
+          static_cast<double>(gap_time_count_)) *
+         1.0e3;
+}
+
+void Model::OptimizeStageBased(std::shared_ptr<Node> snapshot,
+                               const OptimizationParams& optimization_params,
+                               CancellationManager* cancellation_manager) {
+  return OptimizeStageBasedParallelism(
+      snapshot, optimization_params.model_input_time(), optimization_params,
+      cancellation_manager);
+}
+
+void Model::OptimizeStageBasedParallelism(
+    std::shared_ptr<Node> snapshot, double target_time_nsec,
+    const OptimizationParams& optimization_params,
+    CancellationManager* cancellation_manager) {
+  VLOG(2) << "Starting optimization of tunable parameters with stage-based "
+             "optimization with a target time of "
+          << optimization_params.model_input_time() << " nanoseconds.";
+  Node::ModelParameters tunable_parameters = CollectTunableParameters(snapshot);
+  // Initialize the parallelism parameter values to their minimums before
+  // tuning.
+  for (std::pair<string, std::shared_ptr<Parameter>>& pair :
+       tunable_parameters) {
+    if (pair.second->name != kParallelism) {
+      continue;
+    }
+    pair.second->value = pair.second->min;
+  }
+  ModelTiming model_timing(snapshot);
+  ModelTimingPriorityQueue priority_queue(model_timing);
+  StatusOr<std::pair<double, Node*>> critical_root_status =
+      priority_queue.PopSlowestStageRoot();
+  if (!critical_root_status.ok()) {
+    return;
+  }
+  NodeParallelismParameters node_parallelism;
+  std::pair<double, Node*> critical_root = critical_root_status.ValueOrDie();
+  while (critical_root.first > target_time_nsec) {
+    Parameter* parallelism_parameter =
+        node_parallelism.Get(critical_root.second);
+    // Stop optimization if the critical stage has no `parallelism` parameter
+    // or it has reached the max parallelism value.
+    if (parallelism_parameter == nullptr ||
+        parallelism_parameter->value >= optimization_params.cpu_budget()) {
+      break;
+    }
+    parallelism_parameter->value += 1.0;
+    if (TotalMaximumBufferedBytes(snapshot) >
+        optimization_params.ram_budget()) {
+      // Increasing the parallelism by 1 exceeded the ram budget. Reduce it
+      // back and stop optimization because we cannot improve the most
+      // critical stage.
+      parallelism_parameter->value -= 1.0;
+      break;
+    }
+    // Compute the new total time and put the node back in the queue after its
+    // parallelism value has been increased by 1.
+    model_timing.ComputeNodeTotalTime(*critical_root.second);
+    priority_queue.Push(
+        critical_root.second,
+        model_timing.GetTiming(critical_root.second)->total_time_nsec);
+    // Get the next critical stage root.
+    critical_root_status = priority_queue.PopSlowestStageRoot();
+    if (!critical_root_status.ok()) {
+      break;
+    }
+    critical_root = critical_root_status.ValueOrDie();
+  }
+  UpdateStateValues(&tunable_parameters);
+}
+
 void Model::OptimizeHillClimb(std::shared_ptr<Node> snapshot,
                               const OptimizationParams& optimization_params,
                               CancellationManager* cancellation_manager) {
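Taken together, `ComputeTargetTimeNsec` and `OptimizeStageBasedParallelism` implement a greedy scheme: the target time is the average gap between consecutive `GetNext()` calls, and the slowest stage repeatedly receives one more unit of parallelism until every stage beats the target or a budget is hit. A toy Python sketch of the control flow (the inverse-scaling timing model and all names are illustrative; the RAM-budget check is omitted for brevity):

import heapq

def optimize_stage_based(stages, target_time_nsec, cpu_budget):
    # stages: list of dicts with 'name' and 'work' (total work in nsec).
    def total_time(stage):
        # Toy model: a stage's total time scales inversely with parallelism.
        return stage['work'] / stage['parallelism']

    for s in stages:
        s['parallelism'] = 1  # start every parallelism parameter at its minimum
    heap = [(-total_time(s), s['name'], s) for s in stages]
    heapq.heapify(heap)
    while heap:
        neg_t, _, slowest = heapq.heappop(heap)  # most critical stage
        if -neg_t <= target_time_nsec:
            break  # every stage already meets the target time
        if slowest['parallelism'] >= cpu_budget:
            break  # cannot improve the most critical stage any further
        slowest['parallelism'] += 1
        heapq.heappush(heap, (-total_time(slowest), slowest['name'], slowest))
    return {s['name']: s['parallelism'] for s in stages}

# The slowest stage absorbs parallelism until it meets the 120ns target:
print(optimize_stage_based(
    [{'name': 'map', 'work': 400.0}, {'name': 'io', 'work': 100.0}],
    target_time_nsec=120.0, cpu_budget=8))  # {'map': 4, 'io': 1}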
@@ -2309,9 +2476,18 @@ std::string Model::DebugString() {
   return cached_debug_string_;
 }
 
-Node::NodeVector Model::CollectNodes(
+ModelTiming::ModelTiming(std::shared_ptr<Node> root) : root_(root) {
+  DCHECK(root_.get() != nullptr);
+  auto bfs_nodes = CollectNodes(root_, TraversalOrder::BFS, IsAnyNode);
+  auto reverse_bfs_nodes = bfs_nodes;
+  std::reverse(reverse_bfs_nodes.begin(), reverse_bfs_nodes.end());
+  ComputePipelineRatios(bfs_nodes);
+  ComputeTotalTimes(reverse_bfs_nodes);
+}
+
+Node::NodeVector ModelTiming::CollectNodes(
     std::shared_ptr<Node> root, TraversalOrder order,
-    bool collect_node(const std::shared_ptr<Node>)) {
+    bool collect_node(const std::shared_ptr<Node>)) const {
   if (root == nullptr) {
     return Node::NodeVector({});
   }
@@ -2327,18 +2503,6 @@ Node::NodeVector Model::CollectNodes(
   return nodes;
 }
 
-ModelTiming::ModelTiming(std::shared_ptr<Model> model) : model_(model) {
-  ComputeTiming();
-}
-
-void ModelTiming::ComputeTiming() {
-  auto nodes =
-      model_->CollectNodes(model_->output(), TraversalOrder::BFS, IsAnyNode);
-  ComputeTimingComponents(nodes);
-  std::reverse(nodes.begin(), nodes.end());
-  ComputeTotalTimes(nodes);
-}
-
 const ModelTiming::NodeTiming* ModelTiming::GetTiming(const Node* node) const {
   if (timing_nodes_.find(node) == timing_nodes_.end()) {
     return nullptr;
@@ -2346,10 +2510,9 @@ const ModelTiming::NodeTiming* ModelTiming::GetTiming(const Node* node) const {
   return &(timing_nodes_.at(node));
 }
 
-void ModelTiming::ComputeTimingComponents(const Node::NodeVector& bfs_nodes) {
+void ModelTiming::ComputePipelineRatios(const Node::NodeVector& bfs_nodes) {
   for (const auto& node : bfs_nodes) {
     auto& node_timing = timing_nodes_[node.get()];
-    node_timing.self_time_nsec = node->ComputeSelfTime();
     if (!node->autotune()) {
       // These are inactive nodes marked by parallel interleave transformations.
      node_timing.pipeline_ratio = 0.0;
|
|||
}
|
||||
}
|
||||
|
||||
void ModelTiming::ComputeNodeTotalTime(std::shared_ptr<Node> node) {
|
||||
DCHECK(timing_nodes_.contains(node.get()));
|
||||
auto& node_timing = timing_nodes_[node.get()];
|
||||
void ModelTiming::ComputeNonAsyncInterleaveManyTotalTime(const Node& node) {
|
||||
DCHECK(timing_nodes_.contains(&node));
|
||||
auto& node_timing = timing_nodes_[&node];
|
||||
double input_total_time_nsec = 0.0;
|
||||
for (auto input : node->inputs()) {
|
||||
for (auto input : node.inputs()) {
|
||||
if (input->IsAsync()) {
|
||||
continue;
|
||||
}
|
||||
if (!input->autotune() || input->num_elements() <= 0) {
|
||||
continue;
|
||||
}
|
||||
DCHECK(timing_nodes_.contains(input.get()));
|
||||
DCHECK(timing_nodes_.contains(input.get()))
|
||||
<< "Input " << input->long_name() << " of node " << node.long_name()
|
||||
<< " has no timing node.";
|
||||
|
||||
input_total_time_nsec += timing_nodes_[input.get()].total_time_nsec;
|
||||
}
|
||||
node_timing.total_time_nsec =
|
||||
node_timing.self_time_nsec + input_total_time_nsec * node->Ratio();
|
||||
node_timing.self_time_nsec + input_total_time_nsec * node.Ratio();
|
||||
}
|
||||
|
||||
void ModelTiming::ComputeAsyncInterleaveManyTotalTime(
|
||||
std::shared_ptr<Node> node) {
|
||||
DCHECK(timing_nodes_.contains(node));
|
||||
auto& node_timing = timing_nodes_[node.get()];
|
||||
void ModelTiming::ComputeAsyncInterleaveManyTotalTime(const Node& node) {
|
||||
DCHECK(timing_nodes_.contains(&node));
|
||||
auto& node_timing = timing_nodes_[&node];
|
||||
double max_input_total_time_nsec = 0.0;
|
||||
double sum_input_throughput = 0.0;
|
||||
auto inputs = node->inputs();
|
||||
auto inputs = node.inputs();
|
||||
// `ParallelInterleave` is often used to interleave processing of datasets
|
||||
// generated from the first input, e.g. reading from IO where the first input
|
||||
// has the list of all filenames. The first input is typically not the
|
||||
|
|
@ -2409,7 +2574,9 @@ void ModelTiming::ComputeAsyncInterleaveManyTotalTime(
|
|||
if (!(*input)->autotune() || (*input)->num_elements() <= 0) {
|
||||
continue;
|
||||
}
|
||||
DCHECK(timing_nodes_.contains((*input).get()));
|
||||
DCHECK(timing_nodes_.contains((*input).get()))
|
||||
<< "Input " << (*input)->long_name() << " of node " << node.long_name()
|
||||
<< " has no timing node.";
|
||||
auto input_total_time_nsec = timing_nodes_[(*input).get()].total_time_nsec;
|
||||
max_input_total_time_nsec =
|
||||
std::max(input_total_time_nsec, max_input_total_time_nsec);
|
||||
|
|
@ -2418,14 +2585,14 @@ void ModelTiming::ComputeAsyncInterleaveManyTotalTime(
|
|||
}
|
||||
}
|
||||
double input_total_time_nsec = 0.0;
|
||||
auto deterministic = node->ParameterValue(kDeterministic);
|
||||
auto deterministic = node.ParameterValue(kDeterministic);
|
||||
// After cl/445005635, there should always be `deterministic` parameter for an
|
||||
// ASYNC_INTERLEAVE_MANY node. The "not-ok" check is to allow the code to work
|
||||
// with protos saved and restored before that CL.
|
||||
if (!deterministic.ok() || deterministic.ValueOrDie() == 1.0) {
|
||||
// If deterministic = true, then the total time is `1/worst input
|
||||
// throughput * cycle_length`, or `max input total time / cycle_length`.
|
||||
input_total_time_nsec = max_input_total_time_nsec * node->Ratio();
|
||||
input_total_time_nsec = max_input_total_time_nsec * node.Ratio();
|
||||
} else if (sum_input_throughput > 0.0) {
|
||||
// If deterministic = false, then the total time is
|
||||
// `1/sum_input_throughput`.
|
||||
|
|
@ -2437,26 +2604,31 @@ void ModelTiming::ComputeAsyncInterleaveManyTotalTime(
|
|||
|
||||
void ModelTiming::ComputeTotalTimes(const Node::NodeVector& reverse_bfs_nodes) {
|
||||
for (const auto& node : reverse_bfs_nodes) {
|
||||
if (!node->autotune() || node->num_elements() <= 0) {
|
||||
continue;
|
||||
ComputeNodeTotalTime(*(node.get()));
|
||||
}
|
||||
}
|
||||
|
||||
void ModelTiming::ComputeNodeTotalTime(const Node& node) {
|
||||
NodeTiming& node_timing = timing_nodes_[&node];
|
||||
node_timing.self_time_nsec = node.ComputeSelfTime();
|
||||
if (!node.autotune() || node.num_elements() <= 0) {
|
||||
return;
|
||||
}
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
// This block of code is defined only for non-mobile platform because mobile
|
||||
// platform lacks RTTI, i.e. the use of `dynamic_cast`.
|
||||
if (dynamic_cast<const AsyncInterleaveMany*>(node.get()) != nullptr) {
|
||||
if (dynamic_cast<const AsyncInterleaveMany*>(&node) != nullptr) {
|
||||
ComputeAsyncInterleaveManyTotalTime(node);
|
||||
} else {
|
||||
ComputeNodeTotalTime(node);
|
||||
ComputeNonAsyncInterleaveManyTotalTime(node);
|
||||
}
|
||||
#else // !IS_MOBILE_PLATFORM
|
||||
ComputeNodeTotalTime(node);
|
||||
ComputeNonAsyncInterleaveManyTotalTime(node);
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<Node>> ModelTiming::GetStageRoots() const {
|
||||
auto bfs_nodes =
|
||||
model_->CollectNodes(model_->output(), TraversalOrder::BFS, IsAnyNode);
|
||||
auto bfs_nodes = CollectNodes(root_, TraversalOrder::BFS, IsAnyNode);
|
||||
std::vector<std::shared_ptr<Node>> roots;
|
||||
if (!bfs_nodes.empty() && !bfs_nodes[0]->IsAsync()) {
|
||||
roots.push_back(bfs_nodes[0]);
|
||||
|
|
@ -2470,8 +2642,8 @@ std::vector<std::shared_ptr<Node>> ModelTiming::GetStageRoots() const {
|
|||
}
|
||||
|
||||
std::vector<std::shared_ptr<Node>> ModelTiming::GetStageNodes(
|
||||
std::shared_ptr<Node> root) const {
|
||||
return model_->CollectNodes(root, TraversalOrder::BFS, IsSyncNode);
|
||||
std::shared_ptr<Node> stage_root) const {
|
||||
return CollectNodes(stage_root, TraversalOrder::BFS, IsSyncNode);
|
||||
}
|
||||
|
||||
} // namespace model
|
||||
|
|
|
|||
|
|
tensorflow/core/framework/model.h
@@ -375,6 +375,9 @@ class Node {
   // Collects tunable parameters in the subtree rooted in this node.
   ModelParameters CollectTunableParameters() const TF_LOCKS_EXCLUDED(mu_);
 
+  // Collects tunable parameters in this node.
+  ModelParameters CollectNodeTunableParameters() const TF_LOCKS_EXCLUDED(mu_);
+
   // Returns a human-readable representation of this node.
   string DebugString() const TF_LOCKS_EXCLUDED(mu_);
 
@@ -689,7 +692,7 @@ class Model {
   ~Model();
 
   // Returns a pointer to the model's output node.
-  const std::shared_ptr<Node> output() {
+  const std::shared_ptr<Node> output() const {
     mutex_lock l(mu_);
     return output_;
   }
@@ -745,9 +748,12 @@
   static Status Load(const string& fname, std::unique_ptr<Model>* model,
                      OptimizationParams* optimization_params);
 
-  Node::NodeVector CollectNodes(std::shared_ptr<Node> root,
-                                TraversalOrder order,
-                                bool collect_node(const std::shared_ptr<Node>));
+  // Records the gap time between consecutive `GetNext()` calls.
+  void RecordIteratorGapTime(uint64_t duration_usec) {
+    mutex_lock l(gap_mu_);
+    gap_time_sum_usec_ += duration_usec;
+    ++gap_time_count_;
+  }
 
  private:
   // Determines whether optimization should stop given total processing time,
@@ -801,6 +807,27 @@
                              const OptimizationParams& optimization_params,
                              CancellationManager* cancellation_manager);
 
+  // This optimization starts by setting all tunable parallelism parameters to
+  // their minimum values. It then repeatedly increases the parallelism
+  // parameter of the longest stage by 1 until either the longest stage is
+  // faster than the target time or the memory or CPU budget is fully utilized.
+  // TODO(b/226910071): The second part of this algorithm optimizes the buffer
+  // sizes of parallel ops.
+  void OptimizeStageBased(std::shared_ptr<Node> snapshot,
+                          const OptimizationParams& optimization_params,
+                          CancellationManager* cancellation_manager);
+
+  // Computes the target time in nsecs to use for the `STAGE_BASED` autotune
+  // algorithm.
+  double ComputeTargetTimeNsec();
+
+  // This is the first part of the stage-based optimization that optimizes
+  // tunable parallelism parameters.
+  void OptimizeStageBasedParallelism(
+      std::shared_ptr<Node> snapshot, double target_time_nsec,
+      const OptimizationParams& optimization_params,
+      CancellationManager* cancellation_manager);
+
   // Determines if we should stop the gradient descent optimization iterations
   // based on the number of increasable parameters, CPU budget, RAM budget and
   // current resource usage.
@@ -826,7 +853,7 @@
   // Used for coordination between different input pipeline threads. Exclusive
   // access is required only when adding or removing nodes. Concurrent access
   // to existing nodes is protected by a node mutex.
-  mutex mu_;
+  mutable mutex mu_;
   // Used for coordinating the optimization loop and model modifications.
   condition_variable optimize_cond_var_;
   int64_t id_counter_ TF_GUARDED_BY(mu_) = 1;
@@ -845,6 +872,13 @@
   // Cached result of the `DebugString()` invocation used to implement rate
   // limiting of the computation.
   std::string cached_debug_string_ = "";
+  // Used to coordinate gap time updates between different threads. Gap time is
+  // the time between the completion of the previous `GetNext()` and the start
+  // of the next `GetNext()`.
+  mutable mutex gap_mu_;
+  // Gap time between consecutive `GetNext()` calls for a model.
+  uint64_t gap_time_sum_usec_ TF_GUARDED_BY(gap_mu_) = 0;
+  uint64_t gap_time_count_ TF_GUARDED_BY(gap_mu_) = 0;
 };
 
 // Class to compute timing information for a model.
@@ -863,7 +897,7 @@ class ModelTiming {
     double total_time_nsec = 0.0;
   };
 
-  explicit ModelTiming(std::shared_ptr<Model> model);
+  explicit ModelTiming(std::shared_ptr<Node> root);
 
   // Returns the timing data for `node`.
   const NodeTiming* GetTiming(const Node* node) const;
@@ -873,28 +907,37 @@
 
   // Returns all the nodes of a stage given the stage root.
   std::vector<std::shared_ptr<Node>> GetStageNodes(
-      std::shared_ptr<Node> root) const;
+      std::shared_ptr<Node> stage_root) const;
+
+  // Computes the total time for a node.
+  void ComputeNodeTotalTime(const Node& node);
 
  private:
-  // Computes timing information for the whole model.
-  void ComputeTiming();
-
-  // Computes the pipeline ratio and self time for all nodes. The `bfs_nodes`
-  // are assumed to be a vector of model nodes in BFS order.
-  void ComputeTimingComponents(const Node::NodeVector& bfs_nodes);
+  // Computes the pipeline ratios of all nodes.
+  void ComputePipelineRatios(const Node::NodeVector& bfs_nodes);
 
   // Computes the total time for all nodes. The `reverse_bfs_nodes` are assumed
   // to be a vector of model nodes in reverse BFS order.
   void ComputeTotalTimes(const Node::NodeVector& reverse_bfs_nodes);
 
-  // Computes the total time for a node except when the node is an async
-  // interleave node.
-  void ComputeNodeTotalTime(std::shared_ptr<Node> node);
+  // Computes the total time of a node that is not an async interleave node.
+  void ComputeNonAsyncInterleaveManyTotalTime(const Node& node);
 
   // Computes the total time of an async interleave node.
-  void ComputeAsyncInterleaveManyTotalTime(std::shared_ptr<Node> node);
+  void ComputeAsyncInterleaveManyTotalTime(const Node& node);
 
-  std::shared_ptr<Model> model_;
+  // Returns a vector of all nodes in the model. The nodes are either in
+  // breadth-first search or reverse breadth-first search order depending on
+  // the `order` argument. The nodes are collected based on the results of the
+  // `collect_node` predicate: if the predicate returns `false` for a given
+  // node, then the subtree rooted in this node is excluded. The root node
+  // itself is not collected.
+  Node::NodeVector CollectNodes(
+      std::shared_ptr<Node> root, TraversalOrder order,
+      bool collect_node(const std::shared_ptr<Node>)) const;
+
+  // Stores a pointer to the root of a model.
+  std::shared_ptr<Node> root_;
 
   // Holds a mapping from node to its timing node.
   absl::flat_hash_map<const Node*, NodeTiming> timing_nodes_;
tensorflow/core/framework/model.proto
@@ -22,6 +22,7 @@ enum AutotuneAlgorithm {
   HILL_CLIMB = 1;
   GRADIENT_DESCENT = 2;
   MAX_PARALLELISM = 3;
+  STAGE_BASED = 4;
 }
 
 // Protocol buffer representing the data used by the autotuning modeling
tensorflow/core/framework/model_test.cc
@@ -1321,22 +1321,38 @@ TEST(RecordTimeTest, RecordTimeTest) {
   EXPECT_FALSE(source->is_recording());
 }
 
+TEST(ModelTest, ModelMetrics) {
+  CellReader<std::string> cell_reader("/tensorflow/data/model");
+  model::Model model;
+  std::shared_ptr<Node> root = model::MakeUnknownNode({0, "unknown0", nullptr});
+  model.AddNode([&root](model::Node::Args args) { return root; }, root->name(),
+                nullptr, &root);
+  std::string model_id = strings::StrCat(reinterpret_cast<uint64>(&model));
+  EXPECT_THAT(cell_reader.Read(model_id),
+              AllOf(HasSubstr("key: 0"), HasSubstr("name: \"unknown0\""),
+                    HasSubstr("autotune: true")));
+}
+
 class ModelTimingTest : public ::testing::Test {
  public:
-  // Computes the timing given a Model text proto.
-  void ComputeModelTiming(const std::string& model_pbtxt) {
+  // Builds a Model from its text proto.
+  void BuildModelFromProto(const std::string& model_pbtxt) {
     ModelProto model_proto;
     protobuf::TextFormat::ParseFromString(model_pbtxt, &model_proto);
-    std::unique_ptr<Model> model;
-    TF_CHECK_OK(Model::FromProto(model_proto, &model));
-    auto nodes =
-        model->CollectNodes(model->output(), TraversalOrder::BFS,
-                            [](const std::shared_ptr<Node>) { return true; });
+    TF_CHECK_OK(Model::FromProto(model_proto, &model_));
+    auto nodes = model_->output()->CollectNodes(
+        TraversalOrder::BFS, [](const std::shared_ptr<Node>) { return true; });
+    node_map_.clear();
+    node_map_[model_->output()->id()] = model_->output().get();
     for (const auto& node : nodes) {
      node_map_[node->id()] = node.get();
     }
-    model_timing_ = absl::make_unique<ModelTiming>(std::move(model));
+  }
+
+  // Computes the timing given a Model text proto.
+  void ComputeModelTiming(const std::string& model_pbtxt) {
+    BuildModelFromProto(model_pbtxt);
+    model_timing_ = std::make_unique<ModelTiming>(model_->output());
   }
 
   // Gets the timing information of a node given its id.
@@ -1344,7 +1360,11 @@ class ModelTimingTest : public ::testing::Test {
     return model_timing_->GetTiming(node_map_.at(node_id));
   }
 
+  // Gets the node given its id.
+  const Node* GetNode(int64_t node_id) const { return node_map_.at(node_id); }
+
  protected:
+  std::unique_ptr<Model> model_;
   std::unique_ptr<ModelTiming> model_timing_;
   absl::flat_hash_map<int64_t, const Node*> node_map_;
 };
@@ -1883,16 +1903,263 @@ TEST_F(ModelTimingTest, ParallelInterleave_Batch_ParallelMap) {
   EXPECT_DOUBLE_EQ(10, GetNodeTiming(/*node_id=*/7)->total_time_nsec);
 }
 
-TEST(ModelTest, ModelMetrics) {
-  CellReader<std::string> cell_reader("/tensorflow/data/model");
-  model::Model model;
-  std::shared_ptr<Node> root = model::MakeUnknownNode({0, "unknown0", nullptr});
-  model.AddNode([&root](model::Node::Args args) { return root; }, root->name(),
-                nullptr, &root);
-  std::string model_id = strings::StrCat(reinterpret_cast<uint64>(&model));
-  EXPECT_THAT(cell_reader.Read(model_id),
-              AllOf(HasSubstr("key: 0"), HasSubstr("name: \"unknown0\""),
-                    HasSubstr("autotune: true")));
-}
-
+TEST_F(ModelTimingTest, OptimizeGreedy_OneStage) {
+  BuildModelFromProto(R"pb(
+    nodes: {
+      key: 1
+      value: {
+        id: 1
+        name: "ParallelMapV2"
+        autotune: true
+        num_elements: 100
+        processing_time: 5000
+        bytes_produced: 10000
+        node_class: ASYNC_KNOWN_RATIO
+        ratio: 1
+        inputs: 2
+        parameters: {
+          name: "parallelism"
+          value: 4
+          min: 1
+          max: 16
+          tunable: true
+        }
+      }
+    }
+    nodes: {
+      key: 2
+      value: {
+        id: 2
+        name: "Map"
+        autotune: true
+        num_elements: 100
+        processing_time: 3000
+        node_class: KNOWN_RATIO
+        ratio: 1
+        inputs: 3
+      }
+    }
+    nodes: {
+      key: 3
+      value: {
+        id: 3
+        name: "SSTable"
+        autotune: true
+        num_elements: 100
+        processing_time: 1000
+        node_class: KNOWN_RATIO
+        ratio: 2
+      }
+    }
+    output: 1
+  )pb");
+
+  CancellationManager cancellation_manager;
+  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 5, 1000, 50,
+                   &cancellation_manager);
+
+  EXPECT_EQ(5, GetNode(/*node_id=*/1)->parameter_value("parallelism"));
+}
+
+TEST_F(ModelTimingTest, OptimizeGreedy_TwoStages) {
+  BuildModelFromProto(R"pb(
+    nodes: {
+      key: 1
+      value: {
+        id: 1
+        name: "ParallelMapV2"
+        autotune: true
+        num_elements: 100
+        processing_time: 25000
+        bytes_produced: 10000
+        node_class: ASYNC_KNOWN_RATIO
+        ratio: 1
+        inputs: 2
+        parameters: {
+          name: "parallelism"
+          value: 4
+          min: 1
+          max: 16
+          tunable: true
+        }
+      }
+    }
+    nodes: {
+      key: 2
+      value: {
+        id: 2
+        name: "ParallelMapV2"
+        autotune: true
+        num_elements: 100
+        processing_time: 20000
+        bytes_produced: 10000
+        node_class: ASYNC_KNOWN_RATIO
+        ratio: 1
+        inputs: 3
+        parameters: {
+          name: "parallelism"
+          value: 4
+          min: 1
+          max: 16
+          tunable: true
+        }
+      }
+    }
+    nodes: {
+      key: 3
+      value: {
+        id: 3
+        name: "SSTable"
+        autotune: true
+        num_elements: 100
+        processing_time: 1000
+        node_class: KNOWN_RATIO
+        ratio: 2
+      }
+    }
+    output: 1
+  )pb");
+
+  CancellationManager cancellation_manager;
+  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 5, 1000, 50,
+                   &cancellation_manager);
+
+  EXPECT_EQ(5, GetNode(/*node_id=*/1)->parameter_value("parallelism"));
+  EXPECT_EQ(5, GetNode(/*node_id=*/2)->parameter_value("parallelism"));
+}
+
+TEST_F(ModelTimingTest, OptimizeGreedy_TwoStages_RamBudgetExceeded) {
+  BuildModelFromProto(R"pb(
+    nodes: {
+      key: 1
+      value: {
+        id: 1
+        name: "ParallelMapV2"
+        autotune: true
+        num_elements: 100
+        processing_time: 25000
+        bytes_produced: 10000
+        node_class: ASYNC_KNOWN_RATIO
+        ratio: 1
+        inputs: 2
+        parameters: {
+          name: "parallelism"
+          value: 4
+          min: 1
+          max: 16
+          tunable: true
+        }
+      }
+    }
+    nodes: {
+      key: 2
+      value: {
+        id: 2
+        name: "ParallelMapV2"
+        autotune: true
+        num_elements: 100
+        processing_time: 20000
+        bytes_produced: 10000
+        node_class: ASYNC_KNOWN_RATIO
+        ratio: 1
+        inputs: 3
+        parameters: {
+          name: "parallelism"
+          value: 4
+          min: 1
+          max: 16
+          tunable: true
+        }
+      }
+    }
+    nodes: {
+      key: 3
+      value: {
+        id: 3
+        name: "SSTable"
+        autotune: true
+        num_elements: 100
+        processing_time: 1000
+        node_class: KNOWN_RATIO
+        ratio: 2
+      }
+    }
+    output: 1
+  )pb");
+
+  CancellationManager cancellation_manager;
+  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 5, 800, 50,
+                   &cancellation_manager);
+
+  EXPECT_EQ(4, GetNode(/*node_id=*/1)->parameter_value("parallelism"));
+  EXPECT_EQ(4, GetNode(/*node_id=*/2)->parameter_value("parallelism"));
+}
+
+TEST_F(ModelTimingTest, OptimizeGreedy_TwoStages_CpuBudgetExceeded) {
+  BuildModelFromProto(R"pb(
+    nodes: {
+      key: 1
+      value: {
+        id: 1
+        name: "ParallelMapV2"
+        autotune: true
+        num_elements: 100
+        processing_time: 25000
+        bytes_produced: 10000
+        node_class: ASYNC_KNOWN_RATIO
+        ratio: 1
+        inputs: 2
+        parameters: {
+          name: "parallelism"
+          value: 4
+          min: 1
+          max: 16
+          tunable: true
+        }
+      }
+    }
+    nodes: {
+      key: 2
+      value: {
+        id: 2
+        name: "ParallelMapV2"
+        autotune: true
+        num_elements: 100
+        processing_time: 20000
+        bytes_produced: 10000
+        node_class: ASYNC_KNOWN_RATIO
+        ratio: 1
+        inputs: 3
+        parameters: {
+          name: "parallelism"
+          value: 4
+          min: 1
+          max: 16
+          tunable: true
+        }
+      }
+    }
+    nodes: {
+      key: 3
+      value: {
+        id: 3
+        name: "SSTable"
+        autotune: true
+        num_elements: 100
+        processing_time: 1000
+        node_class: KNOWN_RATIO
+        ratio: 2
+      }
+    }
+    output: 1
+  )pb");
+
+  CancellationManager cancellation_manager;
+  model_->Optimize(AutotuneAlgorithm::STAGE_BASED, 3, 1000, 50,
+                   &cancellation_manager);
+
+  EXPECT_EQ(3, GetNode(/*node_id=*/1)->parameter_value("parallelism"));
+  EXPECT_EQ(3, GetNode(/*node_id=*/2)->parameter_value("parallelism"));
+}
+
 }  // namespace
tensorflow/python/data/ops/options.py
@@ -40,11 +40,15 @@ class AutotuneAlgorithm(enum.Enum):
 
   MAX_PARALLELISM: Similar to HILL_CLIMB but uses a relaxed stopping condition,
   allowing the optimization to oversubscribe the CPU.
+
+  STAGE_BASED: In each optimization step, this algorithm chooses the worst
+  bottleneck parameter and increases its value by 1.
   """
   DEFAULT = 0
   HILL_CLIMB = 1
   GRADIENT_DESCENT = 2
   MAX_PARALLELISM = 3
+  STAGE_BASED = 4
 
   @classmethod
   def _to_proto(cls, obj):
@@ -56,9 +60,11 @@ class AutotuneAlgorithm(enum.Enum):
       return model_pb2.AutotuneAlgorithm.GRADIENT_DESCENT
     if obj == cls.MAX_PARALLELISM:
       return model_pb2.AutotuneAlgorithm.MAX_PARALLELISM
+    if obj == cls.STAGE_BASED:
+      return model_pb2.AutotuneAlgorithm.STAGE_BASED
     raise ValueError(
-        f"Invalid `obj.` Supported values include `DEFAULT`, `HILL_CLIMB` and "
-        f"`GRADIENT_DESCENT`. Got {obj.name}.")
+        f"Invalid `obj.` Supported values include `DEFAULT`, `HILL_CLIMB`, "
+        f"`GRADIENT_DESCENT`, and `STAGE_BASED`. Got {obj.name}.")
 
   @classmethod
   def _from_proto(cls, pb):
@@ -70,8 +76,11 @@ class AutotuneAlgorithm(enum.Enum):
       return cls.GRADIENT_DESCENT
     if pb == model_pb2.AutotuneAlgorithm.MAX_PARALLELISM:
       return cls.MAX_PARALLELISM
-    raise ValueError(f"Invalid `pb.` Supported values include `DEFAULT`, "
-                     f"`HILL_CLIMB` and `GRADIENT_DESCENT`. Got {pb}.")
+    if pb == model_pb2.AutotuneAlgorithm.STAGE_BASED:
+      return cls.STAGE_BASED
+    raise ValueError(
+        f"Invalid `pb.` Supported values include `DEFAULT`, `HILL_CLIMB`, "
+        f"`GRADIENT_DESCENT` and `STAGE_BASED`. Got {pb}.")
 
 
 @tf_export("data.experimental.AutoShardPolicy")
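A quick round-trip sketch of the new enum value through these converters (import paths assumed from the TensorFlow source tree; `_to_proto`/`_from_proto` are the private helpers shown above):

from tensorflow.core.framework import model_pb2
from tensorflow.python.data.ops import options as options_lib

alg = options_lib.AutotuneAlgorithm.STAGE_BASED
pb = options_lib.AutotuneAlgorithm._to_proto(alg)
assert pb == model_pb2.AutotuneAlgorithm.STAGE_BASED
assert options_lib.AutotuneAlgorithm._from_proto(pb) is alg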
tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-autotune-algorithm.pbtxt
@@ -17,4 +17,8 @@ tf_class {
     name: "MAX_PARALLELISM"
     mtype: "<enum \'AutotuneAlgorithm\'>"
   }
+  member {
+    name: "STAGE_BASED"
+    mtype: "<enum \'AutotuneAlgorithm\'>"
+  }
 }

tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-autotune-algorithm.pbtxt
@@ -17,4 +17,8 @@ tf_class {
     name: "MAX_PARALLELISM"
     mtype: "<enum \'AutotuneAlgorithm\'>"
   }
+  member {
+    name: "STAGE_BASED"
+    mtype: "<enum \'AutotuneAlgorithm\'>"
+  }
 }