[tf.data] Implement warm_start feature for all the asynchronous operations.

PiperOrigin-RevId: 509533594
This commit is contained in:
A. Unique TensorFlower 2023-02-14 08:16:58 -08:00 committed by TensorFlower Gardener
parent dee9915773
commit d3509a44ca
24 changed files with 239 additions and 42 deletions

View File

@ -202,7 +202,15 @@ This release contains contributions from many people at Google, as well as:
`rerandomize_each_iteration=True`, the `sample_from_datasets()` `rerandomize_each_iteration=True`, the `sample_from_datasets()`
operation will use a different (deterministic) sequence of numbers every operation will use a different (deterministic) sequence of numbers every
epoch. epoch.
* Added a new field, `warm_start`, to
`tf.data.experimental.OptimizationOptions`. If it is set to `True`,
tf.data will start background threads of asynchronous
transformations upon iterator creation (as opposed to upon first call
to `GetNext`). To enable this behavior, set `warm_start=True` in
`tf.data.experimental.OptimizationOptions`. It should be noted that this
possibly improves the latency of the initial 'GetNext' call at the
expense of requiring more memory to hold prefetched elements between
the time of iterator construction and usage.
* `tf.test`: * `tf.test`:
* Added `tf.test.experimental.sync_devices`, which is useful for * Added `tf.test.experimental.sync_devices`, which is useful for

View File

@ -96,6 +96,7 @@ constexpr char kSlackOpt[] = "slack";
constexpr char kSlackPeriodOpt[] = "slack_period"; constexpr char kSlackPeriodOpt[] = "slack_period";
constexpr char kMakeDeterministicOpt[] = "make_deterministic"; constexpr char kMakeDeterministicOpt[] = "make_deterministic";
constexpr char kFilterParallelizationOpt[] = "filter_parallelization"; constexpr char kFilterParallelizationOpt[] = "filter_parallelization";
constexpr char kWarmStartOpt[] = "warm_start";
void DefaultOptimizationGraphRewrites( void DefaultOptimizationGraphRewrites(
const Options& options, absl::flat_hash_set<tstring>* optimization_enabled, const Options& options, absl::flat_hash_set<tstring>* optimization_enabled,
@ -213,6 +214,14 @@ void DefaultOptimizationGraphRewrites(
optimization_disabled->insert(kInjectPrefetchOpt); optimization_disabled->insert(kInjectPrefetchOpt);
} }
} }
if (optimization_options.optional_warm_start_case() ==
OptimizationOptions::kWarmStart) {
if (optimization_options.warm_start()) {
optimization_enabled->insert(kWarmStartOpt);
} else {
optimization_disabled->insert(kWarmStartOpt);
}
}
} }
// Returns whether an op has been allowlisted as stateless. Uses a heuristic to // Returns whether an op has been allowlisted as stateless. Uses a heuristic to

View File

@ -617,15 +617,17 @@ GetOptimizationsTestCase GetOptimizationTestCase4() {
options.mutable_optimization_options()->set_parallel_batch(true); options.mutable_optimization_options()->set_parallel_batch(true);
options.mutable_optimization_options()->set_shuffle_and_repeat_fusion(true); options.mutable_optimization_options()->set_shuffle_and_repeat_fusion(true);
options.mutable_optimization_options()->set_inject_prefetch(true); options.mutable_optimization_options()->set_inject_prefetch(true);
options.mutable_optimization_options()->set_warm_start(true);
options.set_slack(true); options.set_slack(true);
return {options, return {
/*expected_enabled=*/ options,
{"filter_fusion", "filter_parallelization", "make_sloppy", /*expected_enabled=*/
"map_and_batch_fusion", "map_and_filter_fusion", "map_fusion", {"filter_fusion", "filter_parallelization", "make_sloppy",
"map_parallelization", "noop_elimination", "parallel_batch", "map_and_batch_fusion", "map_and_filter_fusion", "map_fusion",
"shuffle_and_repeat_fusion", "slack", "inject_prefetch"}, "map_parallelization", "noop_elimination", "parallel_batch",
/*expected_disabled=*/{}, "shuffle_and_repeat_fusion", "slack", "inject_prefetch", "warm_start"},
/*expected_default=*/{}}; /*expected_disabled=*/{},
/*expected_default=*/{}};
} }
class GetOptimizationsTest class GetOptimizationsTest

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/rewrite_utils.h" #include "tensorflow/core/data/rewrite_utils.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/framework/model.pb.h" #include "tensorflow/core/framework/model.pb.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/host_info.h"
@ -47,6 +48,7 @@ constexpr char kPrivateThreadpoolSize[] = "threadpool_size";
constexpr char kRamBudget[] = "ram_budget_megabytes"; constexpr char kRamBudget[] = "ram_budget_megabytes";
constexpr char kRamUsage[] = "ram_usage_megabytes"; constexpr char kRamUsage[] = "ram_usage_megabytes";
constexpr char kMaxBufferBytes[] = "max_buffered_megabytes"; constexpr char kMaxBufferBytes[] = "max_buffered_megabytes";
constexpr char kWarmStart[] = "warm_start";
// If value `x` matches `y`, returns default value `z`. Otherwise, return `x`. // If value `x` matches `y`, returns default value `z`. Otherwise, return `x`.
inline int64_t value_or_default(int64_t x, int64_t y, int64_t z) { inline int64_t value_or_default(int64_t x, int64_t y, int64_t z) {
@ -83,7 +85,7 @@ void SetRootDatasetParams(const Options& options, RootDataset::Params* params) {
} }
} }
void AddTraceMetadata(const RootDataset::Params& params, void AddTraceMetadata(const RootDataset::Params& params, const Options& options,
TraceMeMetadata* trace_metadata) { TraceMeMetadata* trace_metadata) {
if (params.autotune) { if (params.autotune) {
trace_metadata->push_back(std::make_pair( trace_metadata->push_back(std::make_pair(
@ -115,6 +117,9 @@ void AddTraceMetadata(const RootDataset::Params& params,
trace_metadata->push_back( trace_metadata->push_back(
std::make_pair(kExperiments, absl::StrJoin(experiments, " "))); std::make_pair(kExperiments, absl::StrJoin(experiments, " ")));
} }
trace_metadata->push_back(std::make_pair(
kWarmStart,
options.optimization_options().warm_start() ? "true" : "false"));
} }
} // namespace } // namespace
@ -307,7 +312,7 @@ RootDataset::RootDataset(const DatasetBase* input, const Params& params)
name_utils::OpName(kDatasetType)})), name_utils::OpName(kDatasetType)})),
input_(input), input_(input),
params_(std::move(params)) { params_(std::move(params)) {
AddTraceMetadata(params_, &traceme_metadata_); AddTraceMetadata(params_, input_->options(), &traceme_metadata_);
} }
RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input, RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input,
@ -317,7 +322,7 @@ RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input,
params_(std::move(params)) { params_(std::move(params)) {
owned_input_ = std::move(input); owned_input_ = std::move(input);
input_ = owned_input_.get(); input_ = owned_input_.get();
AddTraceMetadata(params_, &traceme_metadata_); AddTraceMetadata(params_, input_->options(), &traceme_metadata_);
} }
RootDataset::~RootDataset() = default; RootDataset::~RootDataset() = default;

View File

@ -750,7 +750,8 @@ class IteratorContext {
symbolic_checkpoint(ctx->symbolic_checkpoint()), symbolic_checkpoint(ctx->symbolic_checkpoint()),
thread_factory(ctx->thread_factory()), thread_factory(ctx->thread_factory()),
thread_pool(ctx->thread_pool()), thread_pool(ctx->thread_pool()),
id_registry(ctx->id_registry()) {} id_registry(ctx->id_registry()),
warm_start(ctx->warm_start()) {}
explicit Params(OpKernelContext* ctx) explicit Params(OpKernelContext* ctx)
: collective_executor(ctx->collective_executor()), : collective_executor(ctx->collective_executor()),
@ -844,6 +845,11 @@ class IteratorContext {
std::shared_ptr<MemoryCheckpoint::IdRegistry> id_registry = std::shared_ptr<MemoryCheckpoint::IdRegistry> id_registry =
std::make_shared<MemoryCheckpoint::IdRegistry>(); std::make_shared<MemoryCheckpoint::IdRegistry>();
// If `true` background threads of asynchronous operations are started when
// the iterator is created. Otherwise, they are started upon first `GetNext`
// request. Default value is set to false to ensure backward compatibility.
bool warm_start = false;
}; };
explicit IteratorContext(IteratorContext* ctx) explicit IteratorContext(IteratorContext* ctx)
@ -922,6 +928,8 @@ class IteratorContext {
thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; } thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
bool warm_start() { return params_.warm_start; }
std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name, std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
int num_threads) { int num_threads) {
if (params_.thread_pool) { if (params_.thread_pool) {

View File

@ -82,7 +82,7 @@ message DistributeOptions {
} }
} }
// next: 20 // next: 21
message OptimizationOptions { message OptimizationOptions {
// Whether to apply default graph optimizations. If False, only graph // Whether to apply default graph optimizations. If False, only graph
// optimizations that have been explicitly enabled will be applied. // optimizations that have been explicitly enabled will be applied.
@ -149,6 +149,11 @@ message OptimizationOptions {
oneof optional_inject_prefetch { oneof optional_inject_prefetch {
bool inject_prefetch = 19; bool inject_prefetch = 19;
} }
// Whether to start background threads of asynchronous transformations upon
// iterator creation (as opposed to upon first call to `GetNext`).
oneof optional_warm_start {
bool warm_start = 20;
}
} }
// next: 3 // next: 3

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.h" #include "tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.h"
#include <atomic> #include <atomic>
#include <functional>
#include <utility> #include <utility>
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
@ -233,8 +234,12 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
&iter_ctx, this, prefix(), &input_impl_)); &iter_ctx, this, prefix(), &input_impl_));
ctx->MergeCheckpoint(iter_ctx.checkpoint()); ctx->MergeCheckpoint(iter_ctx.checkpoint());
return dataset()->captured_func_->Instantiate( TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_); ctx, &instantiated_captured_func_));
if (ctx->warm_start() && !ctx->is_restoring()) {
EnsureThreadsStarted(ctx);
}
return OkStatus();
} }
Status GetNextInternal(IteratorContext* ctx, Status GetNextInternal(IteratorContext* ctx,
@ -243,7 +248,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
std::shared_ptr<BatchResult> result; std::shared_ptr<BatchResult> result;
{ {
mutex_lock l(*mu_); mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx); EnsureThreadsStarted(ctx);
while (!cancelled_ && (batch_results_.empty() || while (!cancelled_ && (batch_results_.empty() ||
batch_results_.front()->num_calls > 0)) { batch_results_.front()->num_calls > 0)) {
++waiting_; ++waiting_;
@ -316,6 +321,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
Status RestoreInternal(IteratorContext* ctx, Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override { IteratorStateReader* reader) override {
mutex_lock l(*mu_); mutex_lock l(*mu_);
DCHECK(!runner_thread_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name(kCallCounter), &call_counter_)); reader->ReadScalar(full_name(kCallCounter), &call_counter_));
@ -326,6 +332,9 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
for (int i = 0; i < batch_results_size; ++i) { for (int i = 0; i < batch_results_size; ++i) {
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i)); TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
} }
if (ctx->warm_start()) {
EnsureThreadsStarted(ctx);
}
return OkStatus(); return OkStatus();
} }
@ -510,13 +519,13 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
} }
} }
void EnsureRunnerThreadStarted(IteratorContext* ctx) void EnsureThreadsStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) { if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx); auto new_ctx = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread( runner_thread_ =
kTFDataMapAndBatch, ctx->StartThread(kTFDataMapAndBatch,
std::bind(&Iterator::RunnerThread, this, ctx_copy)); std::bind(&Iterator::RunnerThread, this, new_ctx));
} }
} }

View File

@ -126,6 +126,7 @@ Status IteratorResource::GetNext(OpKernelContext* ctx,
params.thread_factory = unbounded_thread_pool_.get_thread_factory(); params.thread_factory = unbounded_thread_pool_.get_thread_factory();
params.thread_pool = &unbounded_thread_pool_; params.thread_pool = &unbounded_thread_pool_;
params.id_registry = captured_state->id_registry(); params.id_registry = captured_state->id_registry();
params.warm_start = dataset->options().optimization_options().warm_start();
std::function<void()> deregister_fn; std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(RegisterCancellationCallback( TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(), ctx->cancellation_manager(),
@ -248,6 +249,7 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
params.thread_factory = unbounded_thread_pool_.get_thread_factory(); params.thread_factory = unbounded_thread_pool_.get_thread_factory();
params.thread_pool = &unbounded_thread_pool_; params.thread_pool = &unbounded_thread_pool_;
params.id_registry = new_state->id_registry(); params.id_registry = new_state->id_registry();
params.warm_start = dataset->options().optimization_options().warm_start();
std::function<void()> deregister_fn; std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(RegisterCancellationCallback( TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(), ctx->cancellation_manager(),

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/parallel_batch_dataset_op.h" #include "tensorflow/core/kernels/data/parallel_batch_dataset_op.h"
#include <algorithm> #include <algorithm>
#include <functional>
#include <memory> #include <memory>
#include <utility> #include <utility>
@ -219,8 +220,12 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
[this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_)); [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
IteratorContext::Params params(ctx); IteratorContext::Params params(ctx);
params.cancellation_manager = cancellation_manager_.get(); params.cancellation_manager = cancellation_manager_.get();
return dataset()->input_->MakeIterator(IteratorContext(params), this, TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
prefix(), &input_impl_); IteratorContext(params), this, prefix(), &input_impl_));
if (ctx->warm_start() && !ctx->is_restoring()) {
EnsureThreadsStarted(ctx);
}
return OkStatus();
} }
Status GetNextInternal(IteratorContext* ctx, Status GetNextInternal(IteratorContext* ctx,
@ -229,7 +234,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
std::shared_ptr<BatchResult> result; std::shared_ptr<BatchResult> result;
{ {
mutex_lock l(*mu_); mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx); EnsureThreadsStarted(ctx);
while (ShouldWait(&result)) { while (ShouldWait(&result)) {
RecordStop(ctx); RecordStop(ctx);
cond_var_->wait(l); cond_var_->wait(l);
@ -289,6 +294,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
Status RestoreInternal(IteratorContext* ctx, Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override { IteratorStateReader* reader) override {
mutex_lock l(*mu_); mutex_lock l(*mu_);
DCHECK(!runner_thread_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64_t batch_results_size; int64_t batch_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBatchResultsSize), TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBatchResultsSize),
@ -297,6 +303,9 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
for (int i = 0; i < batch_results_size; ++i) { for (int i = 0; i < batch_results_size; ++i) {
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i)); TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
} }
if (ctx->warm_start()) {
EnsureThreadsStarted(ctx);
}
return OkStatus(); return OkStatus();
} }
@ -432,13 +441,13 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
} }
} }
void EnsureRunnerThreadStarted(IteratorContext* ctx) void EnsureThreadsStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) { if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx); auto new_ctx = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread( runner_thread_ =
kTFDataParallelBatch, ctx->StartThread(kTFDataParallelBatch,
std::bind(&Iterator::RunnerThread, this, ctx_copy)); std::bind(&Iterator::RunnerThread, this, new_ctx));
} }
} }

View File

@ -395,8 +395,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
params.cancellation_manager = cancellation_manager_.get(); params.cancellation_manager = cancellation_manager_.get();
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
IteratorContext(params), this, prefix(), &input_impl_)); IteratorContext(params), this, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate( TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_); ctx, &instantiated_captured_func_));
if (ctx->warm_start() && !ctx->is_restoring()) {
EnsureInitialElementsCreated();
EnsureThreadsStarted();
}
return OkStatus();
} }
Status GetNextInternal(IteratorContext* ctx, Status GetNextInternal(IteratorContext* ctx,
@ -515,7 +520,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
IteratorStateReader* reader) override { IteratorStateReader* reader) override {
{ {
mutex_lock l(*mu_); mutex_lock l(*mu_);
DCHECK(!threads_initialized_); DCHECK(!threads_started_);
DCHECK(!initial_elements_created_); DCHECK(!initial_elements_created_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
@ -548,6 +553,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
!current_elements_[last_valid_current_element_]) { !current_elements_[last_valid_current_element_]) {
last_valid_current_element_--; last_valid_current_element_--;
} }
if (ctx->warm_start()) {
EnsureInitialElementsCreated();
EnsureThreadsStarted();
}
VLOG(2) << "Parallel interleave iterator restored"; VLOG(2) << "Parallel interleave iterator restored";
VLOG(4) << "State after restore:\n" << DebugString(); VLOG(4) << "State after restore:\n" << DebugString();
return OkStatus(); return OkStatus();
@ -690,14 +699,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
} }
void EnsureThreadsStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { void EnsureThreadsStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!threads_initialized_) { if (!threads_started_) {
IncrementOutstandingThreads(); IncrementOutstandingThreads();
thread_pool_->Schedule([this]() { WorkerManagerThread(); }); thread_pool_->Schedule([this]() { WorkerManagerThread(); });
if (ctx_->stats_aggregator()) { if (ctx_->stats_aggregator()) {
IncrementOutstandingThreads(); IncrementOutstandingThreads();
thread_pool_->Schedule([this]() { StatsThread(); }); thread_pool_->Schedule([this]() { StatsThread(); });
} }
threads_initialized_ = true; threads_started_ = true;
} }
} }
@ -1554,8 +1563,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// Identifies whether the current_elements_ vector has been initialized. // Identifies whether the current_elements_ vector has been initialized.
bool initial_elements_created_ TF_GUARDED_BY(mu_) = false; bool initial_elements_created_ TF_GUARDED_BY(mu_) = false;
// Identifies whether the element threads have been initialized. // Identifies whether the element threads have been started.
bool threads_initialized_ TF_GUARDED_BY(mu_) = false; bool threads_started_ TF_GUARDED_BY(mu_) = false;
// Used for coordination between the main thread, the manager threads, and // Used for coordination between the main thread, the manager threads, and
// the worker threads. // the worker threads.

View File

@ -276,8 +276,12 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
&iter_ctx, this, prefix(), &input_impl_)); &iter_ctx, this, prefix(), &input_impl_));
ctx->MergeCheckpoint(iter_ctx.checkpoint()); ctx->MergeCheckpoint(iter_ctx.checkpoint());
return dataset()->captured_func_->Instantiate( TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
ctx, &instantiated_captured_func_); ctx, &instantiated_captured_func_));
if (ctx->warm_start() && !ctx->is_restoring()) {
EnsureThreadsStarted(ctx);
}
return OkStatus();
} }
Status GetNextInternal(IteratorContext* ctx, Status GetNextInternal(IteratorContext* ctx,

View File

@ -181,6 +181,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
IteratorContext iter_ctx(params); IteratorContext iter_ctx(params);
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
&iter_ctx, this, prefix(), &input_impl_)); &iter_ctx, this, prefix(), &input_impl_));
if (ctx->warm_start() && !ctx->is_restoring()) {
TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
}
ctx->MergeCheckpoint(iter_ctx.checkpoint()); ctx->MergeCheckpoint(iter_ctx.checkpoint());
return OkStatus(); return OkStatus();
} }
@ -191,7 +194,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
const auto& stats_aggregator = ctx->stats_aggregator(); const auto& stats_aggregator = ctx->stats_aggregator();
{ {
mutex_lock l(*mu_); mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx)); TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
// Wait until the next element in the buffer has been // Wait until the next element in the buffer has been
// produced, or we are shutting down. // produced, or we are shutting down.
while (buffer_.empty() && !prefetch_thread_finished_ && while (buffer_.empty() && !prefetch_thread_finished_ &&
@ -283,6 +286,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
IteratorStateReader* reader) override { IteratorStateReader* reader) override {
mutex_lock input_l(input_mu_); mutex_lock input_l(input_mu_);
mutex_lock l(*mu_); mutex_lock l(*mu_);
DCHECK(!prefetch_thread_);
DCHECK(buffer_.empty()); DCHECK(buffer_.empty());
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
size_t buffer_size; size_t buffer_size;
@ -315,6 +319,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
} }
RecordBufferEnqueue(ctx, buffer_element.value); RecordBufferEnqueue(ctx, buffer_element.value);
} }
if (ctx->warm_start()) {
TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
}
return OkStatus(); return OkStatus();
} }
@ -458,7 +465,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
return s; return s;
} }
Status EnsurePrefetchThreadStarted(IteratorContext* ctx) Status EnsureThreadsStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!prefetch_thread_) { if (!prefetch_thread_) {
std::shared_ptr<IteratorContext> new_ctx = std::shared_ptr<IteratorContext> new_ctx =

View File

@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
"""Tests for the static tf.data optimizations.""" """Tests for the static tf.data optimizations."""
import functools import functools
import time
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
@ -32,6 +33,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -245,6 +247,7 @@ class OptimizationTest(test_base.DatasetTestBase, parameterized.TestCase):
options.experimental_optimization.apply_default_optimizations = False options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.noop_elimination = True options.experimental_optimization.noop_elimination = True
options.experimental_optimization.map_and_batch_fusion = True options.experimental_optimization.map_and_batch_fusion = True
options.experimental_optimization.warm_start = False
optimized_dataset = unoptimized_dataset.with_options(options) optimized_dataset = unoptimized_dataset.with_options(options)
optimized_it = dataset_ops.make_initializable_iterator(optimized_dataset) optimized_it = dataset_ops.make_initializable_iterator(optimized_dataset)
@ -266,6 +269,40 @@ class OptimizationTest(test_base.DatasetTestBase, parameterized.TestCase):
except errors.OutOfRangeError: except errors.OutOfRangeError:
break break
@combinations.generate(
combinations.times(
test_base.eager_only_combinations(),
combinations.combine(warm_start=[True, False]),
)
)
def testOptimizationWarmStart(self, warm_start):
dataset = dataset_ops.Dataset.range(10)
counter = variables.Variable(0)
def update_counter(x):
counter.assign_add(1)
return x
options = options_lib.Options()
options.experimental_optimization.apply_default_optimizations = False
if warm_start:
options.experimental_optimization.warm_start = True
else:
options.experimental_optimization.warm_start = False
dataset = dataset.with_options(options)
dataset = dataset.map(update_counter).prefetch(10)
unused_iter = iter(dataset)
if warm_start:
for sleep_time_secs in [0.1, 0.2, 0.5, 2, 5, 10]:
if counter.numpy() == 0:
time.sleep(sleep_time_secs)
else:
break
self.assertGreater(counter.numpy(), 0)
else:
self.assertEqual(counter.numpy(), 0)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -563,7 +563,7 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
job_gc_check_interval_ms=50, job_gc_check_interval_ms=50,
job_gc_timeout_ms=20, job_gc_timeout_ms=20,
data_transfer_protocol=self._get_data_transfer_protocol()) data_transfer_protocol=self._get_data_transfer_protocol())
num_elements = 10 num_elements = 1000
it1 = iter( it1 = iter(
self.make_distributed_range_dataset( self.make_distributed_range_dataset(
num_elements, num_elements,

View File

@ -21,14 +21,21 @@ from tensorflow.python.data.experimental.kernel_tests.service import test_base a
from tensorflow.python.data.experimental.ops import data_service_ops from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import test_mode
from tensorflow.python.framework import combinations from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
import multiprocessing import multiprocessing
class LocalWorkersTest(data_service_test_base.TestBase, parameterized.TestCase): class LocalWorkersTest(data_service_test_base.TestBase, parameterized.TestCase):
"""Tests reading from local workers if `target_workers` is `local`.""" """Tests reading from local workers if `target_workers` is `local`."""
def setUp(self):
super().setUp()
# TODO(b/268586701): Enable `warm_start` for `local_workers_test`.
test_mode.toggle_test_mode(False)
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
def testOneLocalWorker(self): def testOneLocalWorker(self):
cluster = multi_process_cluster.MultiProcessCluster( cluster = multi_process_cluster.MultiProcessCluster(

View File

@ -23,6 +23,7 @@ from tensorflow.python.data.experimental.kernel_tests.service import test_base a
from tensorflow.python.data.experimental.ops import distributed_save_op from tensorflow.python.data.experimental.ops import distributed_save_op
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import test_mode
from tensorflow.python.framework import combinations from tensorflow.python.framework import combinations
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -62,6 +63,8 @@ class SnapshotFtTest(data_service_test_base.TestBase, parameterized.TestCase):
tempfile.mkdtemp(dir=self.get_temp_dir()), tempfile.mkdtemp(dir=self.get_temp_dir()),
"snapshot_ft_test", "snapshot_ft_test",
) )
# TODO(b/268586560): Enable `warm_start` for `snapshot_ft_test`.
test_mode.toggle_test_mode(False)
# This "manual" setup function is needed due to some bad interaction between # This "manual" setup function is needed due to some bad interaction between
# `setUp` and `combinations` that causes the dataset to be out-of-scope. # `setUp` and `combinations` that causes the dataset to be out-of-scope.

View File

@ -1102,6 +1102,7 @@ py_library(
"//tensorflow/python/data/experimental/ops:lookup_ops", "//tensorflow/python/data/experimental/ops:lookup_ops",
"//tensorflow/python/data/experimental/ops:random_access", "//tensorflow/python/data/experimental/ops:random_access",
"//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:test_mode",
"//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:nest",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/ops/ragged", "//tensorflow/python/ops/ragged",

View File

@ -149,6 +149,7 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
options.experimental_optimization.noop_elimination = True options.experimental_optimization.noop_elimination = True
options.experimental_optimization.parallel_batch = True options.experimental_optimization.parallel_batch = True
options.experimental_optimization.shuffle_and_repeat_fusion = True options.experimental_optimization.shuffle_and_repeat_fusion = True
options.experimental_optimization.warm_start = True
options.experimental_slack = True options.experimental_slack = True
options.threading.max_intra_op_parallelism = 30 options.threading.max_intra_op_parallelism = 30
options.threading.private_threadpool_size = 40 options.threading.private_threadpool_size = 40
@ -177,6 +178,7 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset_options_pb2.DistributeOptions()) dataset_options_pb2.DistributeOptions())
expected_pb.optimization_options.CopyFrom( expected_pb.optimization_options.CopyFrom(
dataset_options_pb2.OptimizationOptions()) dataset_options_pb2.OptimizationOptions())
expected_pb.optimization_options.warm_start = True
expected_pb.threading_options.CopyFrom( expected_pb.threading_options.CopyFrom(
dataset_options_pb2.ThreadingOptions()) dataset_options_pb2.ThreadingOptions())
self.assertProtoEquals(expected_pb, result) self.assertProtoEquals(expected_pb, result)

View File

@ -20,6 +20,7 @@ import re
from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
from tensorflow.python.data.experimental.ops import random_access from tensorflow.python.data.experimental.ops import random_access
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import test_mode
from tensorflow.python.data.util import nest from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure from tensorflow.python.data.util import structure
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -72,6 +73,10 @@ def v2_eager_only_combinations():
class DatasetTestBase(test.TestCase): class DatasetTestBase(test.TestCase):
"""Base class for dataset tests.""" """Base class for dataset tests."""
def setUp(self):
super().setUp()
test_mode.toggle_test_mode(True)
def assert_op_cancelled(self, op): def assert_op_cancelled(self, op):
with self.assertRaises(errors.CancelledError): with self.assertRaises(errors.CancelledError):
self.evaluate(op) self.evaluate(op)

View File

@ -23,6 +23,11 @@ py_library(
srcs = ["debug_mode.py"], srcs = ["debug_mode.py"],
) )
py_library(
name = "test_mode",
srcs = ["test_mode.py"],
)
py_library( py_library(
name = "batch_op", name = "batch_op",
srcs = ["batch_op.py"], srcs = ["batch_op.py"],
@ -416,6 +421,7 @@ py_library(
srcs_version = "PY3", srcs_version = "PY3",
deps = [ deps = [
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python/data/ops:test_mode",
"//tensorflow/python/data/util:options", "//tensorflow/python/data/util:options",
], ],
) )

View File

@ -20,6 +20,7 @@ from absl import logging
from tensorflow.core.framework import dataset_options_pb2 from tensorflow.core.framework import dataset_options_pb2
from tensorflow.core.framework import model_pb2 from tensorflow.core.framework import model_pb2
from tensorflow.python.data.ops import test_mode
from tensorflow.python.data.util import options as options_lib from tensorflow.python.data.util import options as options_lib
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -387,6 +388,20 @@ class OptimizationOptions(options_lib.OptionsBase):
docstring="Whether to fuse shuffle and repeat transformations. If None, " docstring="Whether to fuse shuffle and repeat transformations. If None, "
"defaults to True.") "defaults to True.")
warm_start = options_lib.create_option(
name="warm_start",
ty=bool,
docstring=(
"Whether to start background threads of asynchronous transformations"
" upon iterator creation (as opposed to upon first call to"
" `GetNext`). If None, defaults to False. It should be noted that"
" this possibly improves the latency of the initial 'GetNext' call at"
" the expense of requiring more memory to hold prefetched elements"
" between the time of iterator construction and usage."
),
default_factory=lambda: True if test_mode.TEST_MODE else False,
)
def _to_proto(self): def _to_proto(self):
pb = dataset_options_pb2.OptimizationOptions() pb = dataset_options_pb2.OptimizationOptions()
if self.apply_default_optimizations is not None: if self.apply_default_optimizations is not None:
@ -411,6 +426,8 @@ class OptimizationOptions(options_lib.OptionsBase):
pb.parallel_batch = self.parallel_batch pb.parallel_batch = self.parallel_batch
if self.shuffle_and_repeat_fusion is not None: if self.shuffle_and_repeat_fusion is not None:
pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion
if self.warm_start is not None:
pb.warm_start = self.warm_start
return pb return pb
def _from_proto(self, pb): def _from_proto(self, pb):
@ -436,6 +453,8 @@ class OptimizationOptions(options_lib.OptionsBase):
self.parallel_batch = pb.parallel_batch self.parallel_batch = pb.parallel_batch
if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None: if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
if pb.WhichOneof("optional_warm_start") is not None:
self.warm_start = pb.warm_start
def _set_mutable(self, mutable): def _set_mutable(self, mutable):
"""Change the mutability value to `mutable` on this options and children.""" """Change the mutability value to `mutable` on this options and children."""

View File

@ -0,0 +1,32 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python test mode enabler.
Enables test mode for tf.data.
The test mode can be used to set up custom values for features and
experiments as required in the unit tests.
For example, if `warm_start` feature needs to be enabled exclusively for the
unit tests, the tests can enable the test mode using `toggle_test_mode` and
the default value of `warm_start` can be set as per the value of `TEST_MODE`.
"""
TEST_MODE = False
def toggle_test_mode(test_mode):
global TEST_MODE
TEST_MODE = test_mode

View File

@ -47,6 +47,10 @@ tf_class {
name: "shuffle_and_repeat_fusion" name: "shuffle_and_repeat_fusion"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "warm_start"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

View File

@ -47,6 +47,10 @@ tf_class {
name: "shuffle_and_repeat_fusion" name: "shuffle_and_repeat_fusion"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "warm_start"
mtype: "<type \'property\'>"
}
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"