mirror of https://github.com/zebrajr/tensorflow.git (synced 2025-12-06 12:20:11 +01:00)

[tf.data] Implement warm_start feature for all the asynchronous operations.

PiperOrigin-RevId: 509533594

commit d3509a44ca (parent dee9915773)
RELEASE.md

@@ -202,7 +202,15 @@ This release contains contributions from many people at Google, as well as:
     `rerandomize_each_iteration=True`, the `sample_from_datasets()`
     operation will use a different (deterministic) sequence of numbers every
     epoch.
 
+*   Added a new field, `warm_start`, to
+    `tf.data.experimental.OptimizationOptions`. If set to `True`, tf.data
+    starts the background threads of asynchronous transformations upon
+    iterator creation (as opposed to upon the first call to `GetNext`).
+    This may improve the latency of the initial `GetNext` call, at the
+    expense of requiring more memory to hold prefetched elements between
+    iterator construction and first use. Defaults to `False`.
+
 *   `tf.test`:
 
     *   Added `tf.test.experimental.sync_devices`, which is useful for
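To make the release note concrete, here is a minimal usage sketch against the public tf.data API (the pipeline itself is illustrative; only the `warm_start` option comes from this commit):

    import tensorflow as tf

    # Build a pipeline with asynchronous transformations (parallel map + prefetch).
    dataset = tf.data.Dataset.range(1000)
    dataset = dataset.map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    # Opt in to warm starts: background threads spin up when the iterator is
    # created rather than on the first GetNext call.
    options = tf.data.Options()
    options.experimental_optimization.warm_start = True
    dataset = dataset.with_options(options)

    iterator = iter(dataset)  # background prefetching can begin here
    first = next(iterator)    # potentially lower first-element latency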
tensorflow/core/data/dataset_utils.cc

@@ -96,6 +96,7 @@ constexpr char kSlackOpt[] = "slack";
 constexpr char kSlackPeriodOpt[] = "slack_period";
 constexpr char kMakeDeterministicOpt[] = "make_deterministic";
 constexpr char kFilterParallelizationOpt[] = "filter_parallelization";
+constexpr char kWarmStartOpt[] = "warm_start";
 
 void DefaultOptimizationGraphRewrites(
     const Options& options, absl::flat_hash_set<tstring>* optimization_enabled,
@@ -213,6 +214,14 @@ void DefaultOptimizationGraphRewrites(
       optimization_disabled->insert(kInjectPrefetchOpt);
     }
   }
+  if (optimization_options.optional_warm_start_case() ==
+      OptimizationOptions::kWarmStart) {
+    if (optimization_options.warm_start()) {
+      optimization_enabled->insert(kWarmStartOpt);
+    } else {
+      optimization_disabled->insert(kWarmStartOpt);
+    }
+  }
 }
 
 // Returns whether an op has been allowlisted as stateless. Uses a heuristic to
tensorflow/core/data/dataset_utils_test.cc

@@ -617,13 +617,15 @@ GetOptimizationsTestCase GetOptimizationTestCase4() {
   options.mutable_optimization_options()->set_parallel_batch(true);
   options.mutable_optimization_options()->set_shuffle_and_repeat_fusion(true);
   options.mutable_optimization_options()->set_inject_prefetch(true);
+  options.mutable_optimization_options()->set_warm_start(true);
   options.set_slack(true);
-  return {options,
+  return {
+      options,
       /*expected_enabled=*/
       {"filter_fusion", "filter_parallelization", "make_sloppy",
        "map_and_batch_fusion", "map_and_filter_fusion", "map_fusion",
        "map_parallelization", "noop_elimination", "parallel_batch",
-       "shuffle_and_repeat_fusion", "slack", "inject_prefetch"},
+       "shuffle_and_repeat_fusion", "slack", "inject_prefetch", "warm_start"},
      /*expected_disabled=*/{},
      /*expected_default=*/{}};
 }
tensorflow/core/data/root_dataset.cc

@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/data/dataset_utils.h"
 #include "tensorflow/core/data/name_utils.h"
 #include "tensorflow/core/data/rewrite_utils.h"
+#include "tensorflow/core/framework/dataset_options.pb.h"
 #include "tensorflow/core/framework/model.pb.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/host_info.h"
@@ -47,6 +48,7 @@ constexpr char kPrivateThreadpoolSize[] = "threadpool_size";
 constexpr char kRamBudget[] = "ram_budget_megabytes";
 constexpr char kRamUsage[] = "ram_usage_megabytes";
 constexpr char kMaxBufferBytes[] = "max_buffered_megabytes";
+constexpr char kWarmStart[] = "warm_start";
 
 // If value `x` matches `y`, returns default value `z`. Otherwise, return `x`.
 inline int64_t value_or_default(int64_t x, int64_t y, int64_t z) {
@@ -83,7 +85,7 @@ void SetRootDatasetParams(const Options& options, RootDataset::Params* params) {
   }
 }
 
-void AddTraceMetadata(const RootDataset::Params& params,
+void AddTraceMetadata(const RootDataset::Params& params, const Options& options,
                       TraceMeMetadata* trace_metadata) {
   if (params.autotune) {
     trace_metadata->push_back(std::make_pair(
@@ -115,6 +117,9 @@ void AddTraceMetadata(const RootDataset::Params& params,
     trace_metadata->push_back(
         std::make_pair(kExperiments, absl::StrJoin(experiments, " ")));
   }
+  trace_metadata->push_back(std::make_pair(
+      kWarmStart,
+      options.optimization_options().warm_start() ? "true" : "false"));
 }
 }  // namespace
 
@@ -307,7 +312,7 @@ RootDataset::RootDataset(const DatasetBase* input, const Params& params)
                        name_utils::OpName(kDatasetType)})),
       input_(input),
       params_(std::move(params)) {
-  AddTraceMetadata(params_, &traceme_metadata_);
+  AddTraceMetadata(params_, input_->options(), &traceme_metadata_);
 }
 
 RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input,
@@ -317,7 +322,7 @@ RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input,
       params_(std::move(params)) {
   owned_input_ = std::move(input);
   input_ = owned_input_.get();
-  AddTraceMetadata(params_, &traceme_metadata_);
+  AddTraceMetadata(params_, input_->options(), &traceme_metadata_);
 }
 
 RootDataset::~RootDataset() = default;
tensorflow/core/framework/dataset.h

@@ -750,7 +750,8 @@ class IteratorContext {
           symbolic_checkpoint(ctx->symbolic_checkpoint()),
           thread_factory(ctx->thread_factory()),
           thread_pool(ctx->thread_pool()),
-          id_registry(ctx->id_registry()) {}
+          id_registry(ctx->id_registry()),
+          warm_start(ctx->warm_start()) {}
 
     explicit Params(OpKernelContext* ctx)
         : collective_executor(ctx->collective_executor()),
@@ -844,6 +845,11 @@ class IteratorContext {
     std::shared_ptr<MemoryCheckpoint::IdRegistry> id_registry =
         std::make_shared<MemoryCheckpoint::IdRegistry>();
 
+    // If `true`, background threads of asynchronous operations are started
+    // when the iterator is created. Otherwise, they are started upon the
+    // first `GetNext` request. Defaults to `false` to ensure backward
+    // compatibility.
+    bool warm_start = false;
   };
 
   explicit IteratorContext(IteratorContext* ctx)
@@ -922,6 +928,8 @@ class IteratorContext {
   thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
 
+  bool warm_start() { return params_.warm_start; }
+
   std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
                                                        int num_threads) {
     if (params_.thread_pool) {
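The behavioral contract of this `Params` field is easiest to see from Python; a hedged sketch mirroring the unit test added later in this commit (the sleep only gives the background threads time to run, and the exact pipeline is illustrative):

    import time
    import tensorflow as tf

    counter = tf.Variable(0)

    def mark_progress(x):
      # The side effect lets us observe whether the pipeline ran in the background.
      counter.assign_add(1)
      return x

    options = tf.data.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.warm_start = True

    dataset = tf.data.Dataset.range(10).map(mark_progress).prefetch(10)
    dataset = dataset.with_options(options)

    _ = iter(dataset)           # with warm_start=True, threads start here
    time.sleep(1.0)             # give the prefetch thread a moment
    assert counter.numpy() > 0  # elements were produced before any next() call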
tensorflow/core/framework/dataset_options.proto

@@ -82,7 +82,7 @@ message DistributeOptions {
   }
 }
 
-// next: 20
+// next: 21
 message OptimizationOptions {
   // Whether to apply default graph optimizations. If False, only graph
   // optimizations that have been explicitly enabled will be applied.
@@ -149,6 +149,11 @@ message OptimizationOptions {
   oneof optional_inject_prefetch {
     bool inject_prefetch = 19;
   }
+  // Whether to start background threads of asynchronous transformations upon
+  // iterator creation (as opposed to upon first call to `GetNext`).
+  oneof optional_warm_start {
+    bool warm_start = 20;
+  }
 }
 
 // next: 3
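The `oneof` wrapper lets readers distinguish an explicitly set `False` from an unset field; a short sketch of the presence check that both the C++ (`optional_warm_start_case`) and Python (`WhichOneof`) code in this commit rely on:

    from tensorflow.core.framework import dataset_options_pb2

    pb = dataset_options_pb2.OptimizationOptions()
    print(pb.WhichOneof("optional_warm_start"))  # None: field was never set

    pb.warm_start = False
    print(pb.WhichOneof("optional_warm_start"))  # "warm_start": explicitly False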
tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc

@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.h"
 
 #include <atomic>
+#include <functional>
 #include <utility>
 
 #include "tensorflow/core/common_runtime/function.h"
@@ -233,8 +234,12 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
       TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          &iter_ctx, this, prefix(), &input_impl_));
       ctx->MergeCheckpoint(iter_ctx.checkpoint());
-      return dataset()->captured_func_->Instantiate(
-          ctx, &instantiated_captured_func_);
+      TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
+          ctx, &instantiated_captured_func_));
+      if (ctx->warm_start() && !ctx->is_restoring()) {
+        EnsureThreadsStarted(ctx);
+      }
+      return OkStatus();
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -243,7 +248,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
       std::shared_ptr<BatchResult> result;
       {
         mutex_lock l(*mu_);
-        EnsureRunnerThreadStarted(ctx);
+        EnsureThreadsStarted(ctx);
         while (!cancelled_ && (batch_results_.empty() ||
                                batch_results_.front()->num_calls > 0)) {
           ++waiting_;
@@ -316,6 +321,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(*mu_);
+      DCHECK(!runner_thread_);
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kCallCounter), &call_counter_));
@@ -326,6 +332,9 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
       for (int i = 0; i < batch_results_size; ++i) {
         TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
       }
+      if (ctx->warm_start()) {
+        EnsureThreadsStarted(ctx);
+      }
       return OkStatus();
     }
 
@@ -510,13 +519,13 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
       }
     }
 
-    void EnsureRunnerThreadStarted(IteratorContext* ctx)
+    void EnsureThreadsStarted(IteratorContext* ctx)
         TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
-        auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
-        runner_thread_ = ctx->StartThread(
-            kTFDataMapAndBatch,
-            std::bind(&Iterator::RunnerThread, this, ctx_copy));
+        auto new_ctx = std::make_shared<IteratorContext>(*ctx);
+        runner_thread_ =
+            ctx->StartThread(kTFDataMapAndBatch,
+                             std::bind(&Iterator::RunnerThread, this, new_ctx));
       }
     }
tensorflow/core/kernels/data/iterator_ops.cc

@@ -126,6 +126,7 @@ Status IteratorResource::GetNext(OpKernelContext* ctx,
   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
   params.thread_pool = &unbounded_thread_pool_;
   params.id_registry = captured_state->id_registry();
+  params.warm_start = dataset->options().optimization_options().warm_start();
   std::function<void()> deregister_fn;
   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
       ctx->cancellation_manager(),
@@ -248,6 +249,7 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
   params.thread_pool = &unbounded_thread_pool_;
   params.id_registry = new_state->id_registry();
+  params.warm_start = dataset->options().optimization_options().warm_start();
   std::function<void()> deregister_fn;
   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
       ctx->cancellation_manager(),
tensorflow/core/kernels/data/parallel_batch_dataset_op.cc

@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/parallel_batch_dataset_op.h"
 
 #include <algorithm>
+#include <functional>
 #include <memory>
 #include <utility>
 
@@ -219,8 +220,12 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
         [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
     IteratorContext::Params params(ctx);
     params.cancellation_manager = cancellation_manager_.get();
-    return dataset()->input_->MakeIterator(IteratorContext(params), this,
-                                           prefix(), &input_impl_);
+    TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
+        IteratorContext(params), this, prefix(), &input_impl_));
+    if (ctx->warm_start() && !ctx->is_restoring()) {
+      EnsureThreadsStarted(ctx);
+    }
+    return OkStatus();
   }
 
   Status GetNextInternal(IteratorContext* ctx,
@@ -229,7 +234,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
     std::shared_ptr<BatchResult> result;
     {
       mutex_lock l(*mu_);
-      EnsureRunnerThreadStarted(ctx);
+      EnsureThreadsStarted(ctx);
       while (ShouldWait(&result)) {
        RecordStop(ctx);
        cond_var_->wait(l);
@@ -289,6 +294,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
   Status RestoreInternal(IteratorContext* ctx,
                          IteratorStateReader* reader) override {
     mutex_lock l(*mu_);
+    DCHECK(!runner_thread_);
    TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
    int64_t batch_results_size;
    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBatchResultsSize),
@@ -297,6 +303,9 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
     for (int i = 0; i < batch_results_size; ++i) {
       TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
     }
+    if (ctx->warm_start()) {
+      EnsureThreadsStarted(ctx);
+    }
     return OkStatus();
   }
 
@@ -432,13 +441,13 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
     }
   }
 
-  void EnsureRunnerThreadStarted(IteratorContext* ctx)
+  void EnsureThreadsStarted(IteratorContext* ctx)
       TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
     if (!runner_thread_) {
-      auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
-      runner_thread_ = ctx->StartThread(
-          kTFDataParallelBatch,
-          std::bind(&Iterator::RunnerThread, this, ctx_copy));
+      auto new_ctx = std::make_shared<IteratorContext>(*ctx);
+      runner_thread_ =
+          ctx->StartThread(kTFDataParallelBatch,
+                           std::bind(&Iterator::RunnerThread, this, new_ctx));
    }
  }
tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc

@@ -395,8 +395,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
   params.cancellation_manager = cancellation_manager_.get();
   TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
       IteratorContext(params), this, prefix(), &input_impl_));
-  return dataset()->captured_func_->Instantiate(
-      ctx, &instantiated_captured_func_);
+  TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
+      ctx, &instantiated_captured_func_));
+  if (ctx->warm_start() && !ctx->is_restoring()) {
+    EnsureInitialElementsCreated();
+    EnsureThreadsStarted();
+  }
+  return OkStatus();
 }
 
 Status GetNextInternal(IteratorContext* ctx,
@@ -515,7 +520,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
                          IteratorStateReader* reader) override {
     {
       mutex_lock l(*mu_);
-      DCHECK(!threads_initialized_);
+      DCHECK(!threads_started_);
       DCHECK(!initial_elements_created_);
       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
       TF_RETURN_IF_ERROR(
@@ -548,6 +553,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
            !current_elements_[last_valid_current_element_]) {
       last_valid_current_element_--;
     }
+    if (ctx->warm_start()) {
+      EnsureInitialElementsCreated();
+      EnsureThreadsStarted();
+    }
     VLOG(2) << "Parallel interleave iterator restored";
     VLOG(4) << "State after restore:\n" << DebugString();
     return OkStatus();
@@ -690,14 +699,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
   }
 
   void EnsureThreadsStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-    if (!threads_initialized_) {
+    if (!threads_started_) {
       IncrementOutstandingThreads();
       thread_pool_->Schedule([this]() { WorkerManagerThread(); });
       if (ctx_->stats_aggregator()) {
         IncrementOutstandingThreads();
         thread_pool_->Schedule([this]() { StatsThread(); });
       }
-      threads_initialized_ = true;
+      threads_started_ = true;
     }
   }
 
@@ -1554,8 +1563,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
   // Identifies whether the current_elements_ vector has been initialized.
   bool initial_elements_created_ TF_GUARDED_BY(mu_) = false;
 
-  // Identifies whether the element threads have been initialized.
-  bool threads_initialized_ TF_GUARDED_BY(mu_) = false;
+  // Identifies whether the element threads have been started.
+  bool threads_started_ TF_GUARDED_BY(mu_) = false;
 
   // Used for coordination between the main thread, the manager threads, and
   // the worker threads.
tensorflow/core/kernels/data/parallel_map_dataset_op.cc

@@ -276,8 +276,12 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
       TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
           &iter_ctx, this, prefix(), &input_impl_));
       ctx->MergeCheckpoint(iter_ctx.checkpoint());
-      return dataset()->captured_func_->Instantiate(
-          ctx, &instantiated_captured_func_);
+      TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
+          ctx, &instantiated_captured_func_));
+      if (ctx->warm_start() && !ctx->is_restoring()) {
+        EnsureThreadsStarted(ctx);
+      }
+      return OkStatus();
     }
 
     Status GetNextInternal(IteratorContext* ctx,
tensorflow/core/kernels/data/prefetch_dataset_op.cc

@@ -181,6 +181,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
     IteratorContext iter_ctx(params);
     TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
         &iter_ctx, this, prefix(), &input_impl_));
+    if (ctx->warm_start() && !ctx->is_restoring()) {
+      TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
+    }
     ctx->MergeCheckpoint(iter_ctx.checkpoint());
     return OkStatus();
   }
@@ -191,7 +194,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
     const auto& stats_aggregator = ctx->stats_aggregator();
     {
       mutex_lock l(*mu_);
-      TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
+      TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
       // Wait until the next element in the buffer has been
       // produced, or we are shutting down.
       while (buffer_.empty() && !prefetch_thread_finished_ &&
@@ -283,6 +286,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
                          IteratorStateReader* reader) override {
     mutex_lock input_l(input_mu_);
     mutex_lock l(*mu_);
+    DCHECK(!prefetch_thread_);
     DCHECK(buffer_.empty());
     TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
     size_t buffer_size;
@@ -315,6 +319,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
      }
      RecordBufferEnqueue(ctx, buffer_element.value);
    }
+    if (ctx->warm_start()) {
+      TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
+    }
    return OkStatus();
  }
 
@@ -458,7 +465,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
      return s;
    }
 
-  Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
+  Status EnsureThreadsStarted(IteratorContext* ctx)
      TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    if (!prefetch_thread_) {
      std::shared_ptr<IteratorContext> new_ctx =
tensorflow/python/data/experimental/kernel_tests/optimization/optimization_test.py

@@ -14,6 +14,7 @@
 # ==============================================================================
 """Tests for the static tf.data optimizations."""
 import functools
+import time
 
 from absl.testing import parameterized
 import numpy as np
@@ -32,6 +33,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
@@ -245,6 +247,7 @@ class OptimizationTest(test_base.DatasetTestBase, parameterized.TestCase):
     options.experimental_optimization.apply_default_optimizations = False
     options.experimental_optimization.noop_elimination = True
     options.experimental_optimization.map_and_batch_fusion = True
+    options.experimental_optimization.warm_start = False
     optimized_dataset = unoptimized_dataset.with_options(options)
     optimized_it = dataset_ops.make_initializable_iterator(optimized_dataset)
@@ -266,6 +269,37 @@ class OptimizationTest(test_base.DatasetTestBase, parameterized.TestCase):
       except errors.OutOfRangeError:
         break
 
+  @combinations.generate(
+      combinations.times(
+          test_base.eager_only_combinations(),
+          combinations.combine(warm_start=[True, False]),
+      )
+  )
+  def testOptimizationWarmStart(self, warm_start):
+    dataset = dataset_ops.Dataset.range(10)
+    counter = variables.Variable(0)
+
+    def update_counter(x):
+      counter.assign_add(1)
+      return x
+
+    options = options_lib.Options()
+    options.experimental_optimization.apply_default_optimizations = False
+    options.experimental_optimization.warm_start = warm_start
+    dataset = dataset.with_options(options)
+    dataset = dataset.map(update_counter).prefetch(10)
+    unused_iter = iter(dataset)
+
+    if warm_start:
+      for sleep_time_secs in [0.1, 0.2, 0.5, 2, 5, 10]:
+        if counter.numpy() == 0:
+          time.sleep(sleep_time_secs)
+        else:
+          break
+      self.assertGreater(counter.numpy(), 0)
+    else:
+      self.assertEqual(counter.numpy(), 0)
+
 
 if __name__ == "__main__":
   test.main()
tensorflow/python/data/experimental/kernel_tests/service/data_service_ops_test.py

@@ -563,7 +563,7 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
         job_gc_check_interval_ms=50,
         job_gc_timeout_ms=20,
         data_transfer_protocol=self._get_data_transfer_protocol())
-    num_elements = 10
+    num_elements = 1000
     it1 = iter(
         self.make_distributed_range_dataset(
             num_elements,
tensorflow/python/data/experimental/kernel_tests/service/local_workers_test.py

@@ -21,14 +21,21 @@ from tensorflow.python.data.experimental.kernel_tests.service import test_base as data_service_test_base
 from tensorflow.python.data.experimental.ops import data_service_ops
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import test_mode
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
 
 import multiprocessing
 
 
 class LocalWorkersTest(data_service_test_base.TestBase, parameterized.TestCase):
   """Tests reading from local workers if `target_workers` is `local`."""
 
+  def setUp(self):
+    super().setUp()
+    # TODO(b/268586701): Enable `warm_start` for `local_workers_test`.
+    test_mode.toggle_test_mode(False)
+
   @combinations.generate(test_base.default_test_combinations())
   def testOneLocalWorker(self):
     cluster = multi_process_cluster.MultiProcessCluster(
tensorflow/python/data/experimental/kernel_tests/service/snapshot_ft_test.py

@@ -23,6 +23,7 @@ from tensorflow.python.data.experimental.kernel_tests.service import test_base as data_service_test_base
 from tensorflow.python.data.experimental.ops import distributed_save_op
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import test_mode
 from tensorflow.python.framework import combinations
 from tensorflow.python.platform import test
 
@@ -62,6 +63,8 @@ class SnapshotFtTest(data_service_test_base.TestBase, parameterized.TestCase):
         tempfile.mkdtemp(dir=self.get_temp_dir()),
         "snapshot_ft_test",
     )
+    # TODO(b/268586560): Enable `warm_start` for `snapshot_ft_test`.
+    test_mode.toggle_test_mode(False)
 
   # This "manual" setup function is needed due to some bad interaction between
   # `setUp` and `combinations` that causes the dataset to be out-of-scope.
tensorflow/python/data/kernel_tests/BUILD

@@ -1102,6 +1102,7 @@ py_library(
         "//tensorflow/python/data/experimental/ops:lookup_ops",
         "//tensorflow/python/data/experimental/ops:random_access",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:test_mode",
        "//tensorflow/python/data/util:nest",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/ops/ragged",
tensorflow/python/data/kernel_tests/options_test.py

@@ -149,6 +149,7 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
     options.experimental_optimization.noop_elimination = True
     options.experimental_optimization.parallel_batch = True
     options.experimental_optimization.shuffle_and_repeat_fusion = True
+    options.experimental_optimization.warm_start = True
     options.experimental_slack = True
     options.threading.max_intra_op_parallelism = 30
     options.threading.private_threadpool_size = 40
@@ -177,6 +178,7 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
         dataset_options_pb2.DistributeOptions())
     expected_pb.optimization_options.CopyFrom(
         dataset_options_pb2.OptimizationOptions())
+    expected_pb.optimization_options.warm_start = True
     expected_pb.threading_options.CopyFrom(
         dataset_options_pb2.ThreadingOptions())
     self.assertProtoEquals(expected_pb, result)
tensorflow/python/data/kernel_tests/test_base.py

@@ -20,6 +20,7 @@ import re
 from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
 from tensorflow.python.data.experimental.ops import random_access
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import test_mode
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
 from tensorflow.python.eager import context
@@ -72,6 +73,10 @@ def v2_eager_only_combinations():
 class DatasetTestBase(test.TestCase):
   """Base class for dataset tests."""
 
+  def setUp(self):
+    super().setUp()
+    test_mode.toggle_test_mode(True)
+
   def assert_op_cancelled(self, op):
     with self.assertRaises(errors.CancelledError):
       self.evaluate(op)
tensorflow/python/data/ops/BUILD

@@ -23,6 +23,11 @@ py_library(
     srcs = ["debug_mode.py"],
 )
 
+py_library(
+    name = "test_mode",
+    srcs = ["test_mode.py"],
+)
+
 py_library(
     name = "batch_op",
     srcs = ["batch_op.py"],
@@ -416,6 +421,7 @@ py_library(
     srcs_version = "PY3",
     deps = [
         "//tensorflow/python:util",
+        "//tensorflow/python/data/ops:test_mode",
         "//tensorflow/python/data/util:options",
     ],
 )
tensorflow/python/data/ops/options.py

@@ -20,6 +20,7 @@ from absl import logging
 
 from tensorflow.core.framework import dataset_options_pb2
 from tensorflow.core.framework import model_pb2
+from tensorflow.python.data.ops import test_mode
 from tensorflow.python.data.util import options as options_lib
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
@@ -387,6 +388,20 @@ class OptimizationOptions(options_lib.OptionsBase):
       docstring="Whether to fuse shuffle and repeat transformations. If None, "
      "defaults to True.")
 
+  warm_start = options_lib.create_option(
+      name="warm_start",
+      ty=bool,
+      docstring=(
+          "Whether to start background threads of asynchronous transformations"
+          " upon iterator creation (as opposed to upon first call to"
+          " `GetNext`). If None, defaults to False. This may improve the"
+          " latency of the initial `GetNext` call, at the expense of requiring"
+          " more memory to hold prefetched elements between iterator"
+          " construction and use."
+      ),
+      default_factory=lambda: True if test_mode.TEST_MODE else False,
+  )
+
   def _to_proto(self):
     pb = dataset_options_pb2.OptimizationOptions()
     if self.apply_default_optimizations is not None:
@@ -411,6 +426,8 @@ class OptimizationOptions(options_lib.OptionsBase):
       pb.parallel_batch = self.parallel_batch
     if self.shuffle_and_repeat_fusion is not None:
       pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion
+    if self.warm_start is not None:
+      pb.warm_start = self.warm_start
     return pb
 
   def _from_proto(self, pb):
@@ -436,6 +453,8 @@ class OptimizationOptions(options_lib.OptionsBase):
       self.parallel_batch = pb.parallel_batch
     if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
       self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
+    if pb.WhichOneof("optional_warm_start") is not None:
+      self.warm_start = pb.warm_start
 
   def _set_mutable(self, mutable):
     """Change the mutability value to `mutable` on this options and children."""
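A sketch of how `default_factory` interacts with test mode; the `options_lib` import path is an assumption based on the file being patched here:

    from tensorflow.python.data.ops import options as options_lib  # assumed path
    from tensorflow.python.data.ops import test_mode

    # Outside of tests, warm_start defaults to False.
    assert not options_lib.OptimizationOptions().warm_start

    # Under test mode (toggled in DatasetTestBase.setUp above), it defaults to True.
    test_mode.toggle_test_mode(True)
    assert options_lib.OptimizationOptions().warm_start
    test_mode.toggle_test_mode(False)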
tensorflow/python/data/ops/test_mode.py (new file, 32 lines)

@@ -0,0 +1,32 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python test mode enabler.
+
+Enables test mode for tf.data.
+
+Test mode can be used to set up custom values for features and experiments
+as required in unit tests.
+
+For example, if the `warm_start` feature needs to be enabled exclusively for
+unit tests, the tests can enable test mode using `toggle_test_mode`, and the
+default value of `warm_start` can then follow the value of `TEST_MODE`.
+"""
+
+TEST_MODE = False
+
+
+def toggle_test_mode(test_mode):
+  global TEST_MODE
+  TEST_MODE = test_mode
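Tests that cannot tolerate warm starts opt back out in their `setUp`, as the service tests above do; a minimal hypothetical example:

    import unittest

    from tensorflow.python.data.ops import test_mode


    class MyServiceTest(unittest.TestCase):
      """Hypothetical test case that is incompatible with warm_start."""

      def setUp(self):
        super().setUp()
        # Opt out so that warm_start defaults back to False for this suite.
        test_mode.toggle_test_mode(False)

      def testModeDisabled(self):
        self.assertFalse(test_mode.TEST_MODE)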
tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optimization-options.pbtxt

@@ -47,6 +47,10 @@ tf_class {
     name: "shuffle_and_repeat_fusion"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "warm_start"
+    mtype: "<type \'property\'>"
+  }
   member_method {
     name: "__init__"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optimization-options.pbtxt

@@ -47,6 +47,10 @@ tf_class {
     name: "shuffle_and_repeat_fusion"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "warm_start"
+    mtype: "<type \'property\'>"
+  }
   member_method {
     name: "__init__"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"