mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[tf.data] Implement warm_start feature for all the asynchronous operations.
PiperOrigin-RevId: 509533594
This commit is contained in:
parent
dee9915773
commit
d3509a44ca
10
RELEASE.md
10
RELEASE.md
|
|
@ -202,7 +202,15 @@ This release contains contributions from many people at Google, as well as:
|
||||||
`rerandomize_each_iteration=True`, the `sample_from_datasets()`
|
`rerandomize_each_iteration=True`, the `sample_from_datasets()`
|
||||||
operation will use a different (deterministic) sequence of numbers every
|
operation will use a different (deterministic) sequence of numbers every
|
||||||
epoch.
|
epoch.
|
||||||
|
* Added a new field, `warm_start`, to
|
||||||
|
`tf.data.experimental.OptimizationOptions`. If it is set to `True`,
|
||||||
|
tf.data will start background threads of asynchronous
|
||||||
|
transformations upon iterator creation (as opposed to upon first call
|
||||||
|
to `GetNext`). To enable this behavior, set `warm_start=True` in
|
||||||
|
`tf.data.experimental.OptimizationOptions`. Note that this
|
||||||
|
possibly improves the latency of the initial `GetNext` call at the
|
||||||
|
expense of requiring more memory to hold prefetched elements between
|
||||||
|
the time of iterator construction and usage.
|
||||||
* `tf.test`:
|
* `tf.test`:
|
||||||
|
|
||||||
* Added `tf.test.experimental.sync_devices`, which is useful for
|
* Added `tf.test.experimental.sync_devices`, which is useful for
|
||||||
|
|
|
||||||
|
|
@ -96,6 +96,7 @@ constexpr char kSlackOpt[] = "slack";
|
||||||
constexpr char kSlackPeriodOpt[] = "slack_period";
|
constexpr char kSlackPeriodOpt[] = "slack_period";
|
||||||
constexpr char kMakeDeterministicOpt[] = "make_deterministic";
|
constexpr char kMakeDeterministicOpt[] = "make_deterministic";
|
||||||
constexpr char kFilterParallelizationOpt[] = "filter_parallelization";
|
constexpr char kFilterParallelizationOpt[] = "filter_parallelization";
|
||||||
|
constexpr char kWarmStartOpt[] = "warm_start";
|
||||||
|
|
||||||
void DefaultOptimizationGraphRewrites(
|
void DefaultOptimizationGraphRewrites(
|
||||||
const Options& options, absl::flat_hash_set<tstring>* optimization_enabled,
|
const Options& options, absl::flat_hash_set<tstring>* optimization_enabled,
|
||||||
|
|
@ -213,6 +214,14 @@ void DefaultOptimizationGraphRewrites(
|
||||||
optimization_disabled->insert(kInjectPrefetchOpt);
|
optimization_disabled->insert(kInjectPrefetchOpt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (optimization_options.optional_warm_start_case() ==
|
||||||
|
OptimizationOptions::kWarmStart) {
|
||||||
|
if (optimization_options.warm_start()) {
|
||||||
|
optimization_enabled->insert(kWarmStartOpt);
|
||||||
|
} else {
|
||||||
|
optimization_disabled->insert(kWarmStartOpt);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
|
// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
|
||||||
|
|
|
||||||
|
|
@ -617,15 +617,17 @@ GetOptimizationsTestCase GetOptimizationTestCase4() {
|
||||||
options.mutable_optimization_options()->set_parallel_batch(true);
|
options.mutable_optimization_options()->set_parallel_batch(true);
|
||||||
options.mutable_optimization_options()->set_shuffle_and_repeat_fusion(true);
|
options.mutable_optimization_options()->set_shuffle_and_repeat_fusion(true);
|
||||||
options.mutable_optimization_options()->set_inject_prefetch(true);
|
options.mutable_optimization_options()->set_inject_prefetch(true);
|
||||||
|
options.mutable_optimization_options()->set_warm_start(true);
|
||||||
options.set_slack(true);
|
options.set_slack(true);
|
||||||
return {options,
|
return {
|
||||||
/*expected_enabled=*/
|
options,
|
||||||
{"filter_fusion", "filter_parallelization", "make_sloppy",
|
/*expected_enabled=*/
|
||||||
"map_and_batch_fusion", "map_and_filter_fusion", "map_fusion",
|
{"filter_fusion", "filter_parallelization", "make_sloppy",
|
||||||
"map_parallelization", "noop_elimination", "parallel_batch",
|
"map_and_batch_fusion", "map_and_filter_fusion", "map_fusion",
|
||||||
"shuffle_and_repeat_fusion", "slack", "inject_prefetch"},
|
"map_parallelization", "noop_elimination", "parallel_batch",
|
||||||
/*expected_disabled=*/{},
|
"shuffle_and_repeat_fusion", "slack", "inject_prefetch", "warm_start"},
|
||||||
/*expected_default=*/{}};
|
/*expected_disabled=*/{},
|
||||||
|
/*expected_default=*/{}};
|
||||||
}
|
}
|
||||||
|
|
||||||
class GetOptimizationsTest
|
class GetOptimizationsTest
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/data/dataset_utils.h"
|
#include "tensorflow/core/data/dataset_utils.h"
|
||||||
#include "tensorflow/core/data/name_utils.h"
|
#include "tensorflow/core/data/name_utils.h"
|
||||||
#include "tensorflow/core/data/rewrite_utils.h"
|
#include "tensorflow/core/data/rewrite_utils.h"
|
||||||
|
#include "tensorflow/core/framework/dataset_options.pb.h"
|
||||||
#include "tensorflow/core/framework/model.pb.h"
|
#include "tensorflow/core/framework/model.pb.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/host_info.h"
|
#include "tensorflow/core/platform/host_info.h"
|
||||||
|
|
@ -47,6 +48,7 @@ constexpr char kPrivateThreadpoolSize[] = "threadpool_size";
|
||||||
constexpr char kRamBudget[] = "ram_budget_megabytes";
|
constexpr char kRamBudget[] = "ram_budget_megabytes";
|
||||||
constexpr char kRamUsage[] = "ram_usage_megabytes";
|
constexpr char kRamUsage[] = "ram_usage_megabytes";
|
||||||
constexpr char kMaxBufferBytes[] = "max_buffered_megabytes";
|
constexpr char kMaxBufferBytes[] = "max_buffered_megabytes";
|
||||||
|
constexpr char kWarmStart[] = "warm_start";
|
||||||
|
|
||||||
// If value `x` matches `y`, returns default value `z`. Otherwise, return `x`.
|
// If value `x` matches `y`, returns default value `z`. Otherwise, return `x`.
|
||||||
inline int64_t value_or_default(int64_t x, int64_t y, int64_t z) {
|
inline int64_t value_or_default(int64_t x, int64_t y, int64_t z) {
|
||||||
|
|
@ -83,7 +85,7 @@ void SetRootDatasetParams(const Options& options, RootDataset::Params* params) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddTraceMetadata(const RootDataset::Params& params,
|
void AddTraceMetadata(const RootDataset::Params& params, const Options& options,
|
||||||
TraceMeMetadata* trace_metadata) {
|
TraceMeMetadata* trace_metadata) {
|
||||||
if (params.autotune) {
|
if (params.autotune) {
|
||||||
trace_metadata->push_back(std::make_pair(
|
trace_metadata->push_back(std::make_pair(
|
||||||
|
|
@ -115,6 +117,9 @@ void AddTraceMetadata(const RootDataset::Params& params,
|
||||||
trace_metadata->push_back(
|
trace_metadata->push_back(
|
||||||
std::make_pair(kExperiments, absl::StrJoin(experiments, " ")));
|
std::make_pair(kExperiments, absl::StrJoin(experiments, " ")));
|
||||||
}
|
}
|
||||||
|
trace_metadata->push_back(std::make_pair(
|
||||||
|
kWarmStart,
|
||||||
|
options.optimization_options().warm_start() ? "true" : "false"));
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
@ -307,7 +312,7 @@ RootDataset::RootDataset(const DatasetBase* input, const Params& params)
|
||||||
name_utils::OpName(kDatasetType)})),
|
name_utils::OpName(kDatasetType)})),
|
||||||
input_(input),
|
input_(input),
|
||||||
params_(std::move(params)) {
|
params_(std::move(params)) {
|
||||||
AddTraceMetadata(params_, &traceme_metadata_);
|
AddTraceMetadata(params_, input_->options(), &traceme_metadata_);
|
||||||
}
|
}
|
||||||
|
|
||||||
RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input,
|
RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input,
|
||||||
|
|
@ -317,7 +322,7 @@ RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input,
|
||||||
params_(std::move(params)) {
|
params_(std::move(params)) {
|
||||||
owned_input_ = std::move(input);
|
owned_input_ = std::move(input);
|
||||||
input_ = owned_input_.get();
|
input_ = owned_input_.get();
|
||||||
AddTraceMetadata(params_, &traceme_metadata_);
|
AddTraceMetadata(params_, input_->options(), &traceme_metadata_);
|
||||||
}
|
}
|
||||||
|
|
||||||
RootDataset::~RootDataset() = default;
|
RootDataset::~RootDataset() = default;
|
||||||
|
|
|
||||||
|
|
@ -750,7 +750,8 @@ class IteratorContext {
|
||||||
symbolic_checkpoint(ctx->symbolic_checkpoint()),
|
symbolic_checkpoint(ctx->symbolic_checkpoint()),
|
||||||
thread_factory(ctx->thread_factory()),
|
thread_factory(ctx->thread_factory()),
|
||||||
thread_pool(ctx->thread_pool()),
|
thread_pool(ctx->thread_pool()),
|
||||||
id_registry(ctx->id_registry()) {}
|
id_registry(ctx->id_registry()),
|
||||||
|
warm_start(ctx->warm_start()) {}
|
||||||
|
|
||||||
explicit Params(OpKernelContext* ctx)
|
explicit Params(OpKernelContext* ctx)
|
||||||
: collective_executor(ctx->collective_executor()),
|
: collective_executor(ctx->collective_executor()),
|
||||||
|
|
@ -844,6 +845,11 @@ class IteratorContext {
|
||||||
|
|
||||||
std::shared_ptr<MemoryCheckpoint::IdRegistry> id_registry =
|
std::shared_ptr<MemoryCheckpoint::IdRegistry> id_registry =
|
||||||
std::make_shared<MemoryCheckpoint::IdRegistry>();
|
std::make_shared<MemoryCheckpoint::IdRegistry>();
|
||||||
|
|
||||||
|
// If `true`, background threads of asynchronous operations are started when
|
||||||
|
// the iterator is created. Otherwise, they are started upon first `GetNext`
|
||||||
|
// request. Default value is set to false to ensure backward compatibility.
|
||||||
|
bool warm_start = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
explicit IteratorContext(IteratorContext* ctx)
|
explicit IteratorContext(IteratorContext* ctx)
|
||||||
|
|
@ -922,6 +928,8 @@ class IteratorContext {
|
||||||
|
|
||||||
thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
|
thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
|
||||||
|
|
||||||
|
// Returns the `warm_start` setting stored in this context's params: whether
// background threads of asynchronous operations are started when the iterator
// is created rather than upon the first `GetNext` request.
bool warm_start() { return params_.warm_start; }
|
||||||
|
|
||||||
std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
|
std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
|
||||||
int num_threads) {
|
int num_threads) {
|
||||||
if (params_.thread_pool) {
|
if (params_.thread_pool) {
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,7 @@ message DistributeOptions {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// next: 20
|
// next: 21
|
||||||
message OptimizationOptions {
|
message OptimizationOptions {
|
||||||
// Whether to apply default graph optimizations. If False, only graph
|
// Whether to apply default graph optimizations. If False, only graph
|
||||||
// optimizations that have been explicitly enabled will be applied.
|
// optimizations that have been explicitly enabled will be applied.
|
||||||
|
|
@ -149,6 +149,11 @@ message OptimizationOptions {
|
||||||
oneof optional_inject_prefetch {
|
oneof optional_inject_prefetch {
|
||||||
bool inject_prefetch = 19;
|
bool inject_prefetch = 19;
|
||||||
}
|
}
|
||||||
|
// Whether to start background threads of asynchronous transformations upon
|
||||||
|
// iterator creation (as opposed to upon first call to `GetNext`).
|
||||||
|
oneof optional_warm_start {
|
||||||
|
bool warm_start = 20;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// next: 3
|
// next: 3
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.h"
|
#include "tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.h"
|
||||||
|
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
#include <functional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
|
|
@ -233,8 +234,12 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||||
&iter_ctx, this, prefix(), &input_impl_));
|
&iter_ctx, this, prefix(), &input_impl_));
|
||||||
ctx->MergeCheckpoint(iter_ctx.checkpoint());
|
ctx->MergeCheckpoint(iter_ctx.checkpoint());
|
||||||
return dataset()->captured_func_->Instantiate(
|
TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
|
||||||
ctx, &instantiated_captured_func_);
|
ctx, &instantiated_captured_func_));
|
||||||
|
if (ctx->warm_start() && !ctx->is_restoring()) {
|
||||||
|
EnsureThreadsStarted(ctx);
|
||||||
|
}
|
||||||
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GetNextInternal(IteratorContext* ctx,
|
Status GetNextInternal(IteratorContext* ctx,
|
||||||
|
|
@ -243,7 +248,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||||
std::shared_ptr<BatchResult> result;
|
std::shared_ptr<BatchResult> result;
|
||||||
{
|
{
|
||||||
mutex_lock l(*mu_);
|
mutex_lock l(*mu_);
|
||||||
EnsureRunnerThreadStarted(ctx);
|
EnsureThreadsStarted(ctx);
|
||||||
while (!cancelled_ && (batch_results_.empty() ||
|
while (!cancelled_ && (batch_results_.empty() ||
|
||||||
batch_results_.front()->num_calls > 0)) {
|
batch_results_.front()->num_calls > 0)) {
|
||||||
++waiting_;
|
++waiting_;
|
||||||
|
|
@ -316,6 +321,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||||
Status RestoreInternal(IteratorContext* ctx,
|
Status RestoreInternal(IteratorContext* ctx,
|
||||||
IteratorStateReader* reader) override {
|
IteratorStateReader* reader) override {
|
||||||
mutex_lock l(*mu_);
|
mutex_lock l(*mu_);
|
||||||
|
DCHECK(!runner_thread_);
|
||||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
reader->ReadScalar(full_name(kCallCounter), &call_counter_));
|
reader->ReadScalar(full_name(kCallCounter), &call_counter_));
|
||||||
|
|
@ -326,6 +332,9 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||||
for (int i = 0; i < batch_results_size; ++i) {
|
for (int i = 0; i < batch_results_size; ++i) {
|
||||||
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
|
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
|
||||||
}
|
}
|
||||||
|
if (ctx->warm_start()) {
|
||||||
|
EnsureThreadsStarted(ctx);
|
||||||
|
}
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -510,13 +519,13 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void EnsureRunnerThreadStarted(IteratorContext* ctx)
|
void EnsureThreadsStarted(IteratorContext* ctx)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||||
if (!runner_thread_) {
|
if (!runner_thread_) {
|
||||||
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
|
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
|
||||||
runner_thread_ = ctx->StartThread(
|
runner_thread_ =
|
||||||
kTFDataMapAndBatch,
|
ctx->StartThread(kTFDataMapAndBatch,
|
||||||
std::bind(&Iterator::RunnerThread, this, ctx_copy));
|
std::bind(&Iterator::RunnerThread, this, new_ctx));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -126,6 +126,7 @@ Status IteratorResource::GetNext(OpKernelContext* ctx,
|
||||||
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
|
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
|
||||||
params.thread_pool = &unbounded_thread_pool_;
|
params.thread_pool = &unbounded_thread_pool_;
|
||||||
params.id_registry = captured_state->id_registry();
|
params.id_registry = captured_state->id_registry();
|
||||||
|
params.warm_start = dataset->options().optimization_options().warm_start();
|
||||||
std::function<void()> deregister_fn;
|
std::function<void()> deregister_fn;
|
||||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||||
ctx->cancellation_manager(),
|
ctx->cancellation_manager(),
|
||||||
|
|
@ -248,6 +249,7 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
|
||||||
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
|
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
|
||||||
params.thread_pool = &unbounded_thread_pool_;
|
params.thread_pool = &unbounded_thread_pool_;
|
||||||
params.id_registry = new_state->id_registry();
|
params.id_registry = new_state->id_registry();
|
||||||
|
params.warm_start = dataset->options().optimization_options().warm_start();
|
||||||
std::function<void()> deregister_fn;
|
std::function<void()> deregister_fn;
|
||||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||||
ctx->cancellation_manager(),
|
ctx->cancellation_manager(),
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/kernels/data/parallel_batch_dataset_op.h"
|
#include "tensorflow/core/kernels/data/parallel_batch_dataset_op.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
|
@ -219,8 +220,12 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
|
||||||
[this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
|
[this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
|
||||||
IteratorContext::Params params(ctx);
|
IteratorContext::Params params(ctx);
|
||||||
params.cancellation_manager = cancellation_manager_.get();
|
params.cancellation_manager = cancellation_manager_.get();
|
||||||
return dataset()->input_->MakeIterator(IteratorContext(params), this,
|
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||||
prefix(), &input_impl_);
|
IteratorContext(params), this, prefix(), &input_impl_));
|
||||||
|
if (ctx->warm_start() && !ctx->is_restoring()) {
|
||||||
|
EnsureThreadsStarted(ctx);
|
||||||
|
}
|
||||||
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GetNextInternal(IteratorContext* ctx,
|
Status GetNextInternal(IteratorContext* ctx,
|
||||||
|
|
@ -229,7 +234,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
|
||||||
std::shared_ptr<BatchResult> result;
|
std::shared_ptr<BatchResult> result;
|
||||||
{
|
{
|
||||||
mutex_lock l(*mu_);
|
mutex_lock l(*mu_);
|
||||||
EnsureRunnerThreadStarted(ctx);
|
EnsureThreadsStarted(ctx);
|
||||||
while (ShouldWait(&result)) {
|
while (ShouldWait(&result)) {
|
||||||
RecordStop(ctx);
|
RecordStop(ctx);
|
||||||
cond_var_->wait(l);
|
cond_var_->wait(l);
|
||||||
|
|
@ -289,6 +294,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
|
||||||
Status RestoreInternal(IteratorContext* ctx,
|
Status RestoreInternal(IteratorContext* ctx,
|
||||||
IteratorStateReader* reader) override {
|
IteratorStateReader* reader) override {
|
||||||
mutex_lock l(*mu_);
|
mutex_lock l(*mu_);
|
||||||
|
DCHECK(!runner_thread_);
|
||||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||||
int64_t batch_results_size;
|
int64_t batch_results_size;
|
||||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBatchResultsSize),
|
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBatchResultsSize),
|
||||||
|
|
@ -297,6 +303,9 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
|
||||||
for (int i = 0; i < batch_results_size; ++i) {
|
for (int i = 0; i < batch_results_size; ++i) {
|
||||||
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
|
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
|
||||||
}
|
}
|
||||||
|
if (ctx->warm_start()) {
|
||||||
|
EnsureThreadsStarted(ctx);
|
||||||
|
}
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -432,13 +441,13 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void EnsureRunnerThreadStarted(IteratorContext* ctx)
|
void EnsureThreadsStarted(IteratorContext* ctx)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||||
if (!runner_thread_) {
|
if (!runner_thread_) {
|
||||||
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
|
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
|
||||||
runner_thread_ = ctx->StartThread(
|
runner_thread_ =
|
||||||
kTFDataParallelBatch,
|
ctx->StartThread(kTFDataParallelBatch,
|
||||||
std::bind(&Iterator::RunnerThread, this, ctx_copy));
|
std::bind(&Iterator::RunnerThread, this, new_ctx));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -395,8 +395,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||||
params.cancellation_manager = cancellation_manager_.get();
|
params.cancellation_manager = cancellation_manager_.get();
|
||||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||||
IteratorContext(params), this, prefix(), &input_impl_));
|
IteratorContext(params), this, prefix(), &input_impl_));
|
||||||
return dataset()->captured_func_->Instantiate(
|
TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
|
||||||
ctx, &instantiated_captured_func_);
|
ctx, &instantiated_captured_func_));
|
||||||
|
if (ctx->warm_start() && !ctx->is_restoring()) {
|
||||||
|
EnsureInitialElementsCreated();
|
||||||
|
EnsureThreadsStarted();
|
||||||
|
}
|
||||||
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GetNextInternal(IteratorContext* ctx,
|
Status GetNextInternal(IteratorContext* ctx,
|
||||||
|
|
@ -515,7 +520,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||||
IteratorStateReader* reader) override {
|
IteratorStateReader* reader) override {
|
||||||
{
|
{
|
||||||
mutex_lock l(*mu_);
|
mutex_lock l(*mu_);
|
||||||
DCHECK(!threads_initialized_);
|
DCHECK(!threads_started_);
|
||||||
DCHECK(!initial_elements_created_);
|
DCHECK(!initial_elements_created_);
|
||||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
|
|
@ -548,6 +553,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||||
!current_elements_[last_valid_current_element_]) {
|
!current_elements_[last_valid_current_element_]) {
|
||||||
last_valid_current_element_--;
|
last_valid_current_element_--;
|
||||||
}
|
}
|
||||||
|
if (ctx->warm_start()) {
|
||||||
|
EnsureInitialElementsCreated();
|
||||||
|
EnsureThreadsStarted();
|
||||||
|
}
|
||||||
VLOG(2) << "Parallel interleave iterator restored";
|
VLOG(2) << "Parallel interleave iterator restored";
|
||||||
VLOG(4) << "State after restore:\n" << DebugString();
|
VLOG(4) << "State after restore:\n" << DebugString();
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
|
|
@ -690,14 +699,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
void EnsureThreadsStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
void EnsureThreadsStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||||
if (!threads_initialized_) {
|
if (!threads_started_) {
|
||||||
IncrementOutstandingThreads();
|
IncrementOutstandingThreads();
|
||||||
thread_pool_->Schedule([this]() { WorkerManagerThread(); });
|
thread_pool_->Schedule([this]() { WorkerManagerThread(); });
|
||||||
if (ctx_->stats_aggregator()) {
|
if (ctx_->stats_aggregator()) {
|
||||||
IncrementOutstandingThreads();
|
IncrementOutstandingThreads();
|
||||||
thread_pool_->Schedule([this]() { StatsThread(); });
|
thread_pool_->Schedule([this]() { StatsThread(); });
|
||||||
}
|
}
|
||||||
threads_initialized_ = true;
|
threads_started_ = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1554,8 +1563,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||||
// Identifies whether the current_elements_ vector has been initialized.
|
// Identifies whether the current_elements_ vector has been initialized.
|
||||||
bool initial_elements_created_ TF_GUARDED_BY(mu_) = false;
|
bool initial_elements_created_ TF_GUARDED_BY(mu_) = false;
|
||||||
|
|
||||||
// Identifies whether the element threads have been initialized.
|
// Identifies whether the element threads have been started.
|
||||||
bool threads_initialized_ TF_GUARDED_BY(mu_) = false;
|
bool threads_started_ TF_GUARDED_BY(mu_) = false;
|
||||||
|
|
||||||
// Used for coordination between the main thread, the manager threads, and
|
// Used for coordination between the main thread, the manager threads, and
|
||||||
// the worker threads.
|
// the worker threads.
|
||||||
|
|
|
||||||
|
|
@ -276,8 +276,12 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||||
&iter_ctx, this, prefix(), &input_impl_));
|
&iter_ctx, this, prefix(), &input_impl_));
|
||||||
ctx->MergeCheckpoint(iter_ctx.checkpoint());
|
ctx->MergeCheckpoint(iter_ctx.checkpoint());
|
||||||
return dataset()->captured_func_->Instantiate(
|
TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
|
||||||
ctx, &instantiated_captured_func_);
|
ctx, &instantiated_captured_func_));
|
||||||
|
if (ctx->warm_start() && !ctx->is_restoring()) {
|
||||||
|
EnsureThreadsStarted(ctx);
|
||||||
|
}
|
||||||
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GetNextInternal(IteratorContext* ctx,
|
Status GetNextInternal(IteratorContext* ctx,
|
||||||
|
|
|
||||||
|
|
@ -181,6 +181,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||||
IteratorContext iter_ctx(params);
|
IteratorContext iter_ctx(params);
|
||||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||||
&iter_ctx, this, prefix(), &input_impl_));
|
&iter_ctx, this, prefix(), &input_impl_));
|
||||||
|
if (ctx->warm_start() && !ctx->is_restoring()) {
|
||||||
|
TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
|
||||||
|
}
|
||||||
ctx->MergeCheckpoint(iter_ctx.checkpoint());
|
ctx->MergeCheckpoint(iter_ctx.checkpoint());
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
@ -191,7 +194,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||||
const auto& stats_aggregator = ctx->stats_aggregator();
|
const auto& stats_aggregator = ctx->stats_aggregator();
|
||||||
{
|
{
|
||||||
mutex_lock l(*mu_);
|
mutex_lock l(*mu_);
|
||||||
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
|
TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
|
||||||
// Wait until the next element in the buffer has been
|
// Wait until the next element in the buffer has been
|
||||||
// produced, or we are shutting down.
|
// produced, or we are shutting down.
|
||||||
while (buffer_.empty() && !prefetch_thread_finished_ &&
|
while (buffer_.empty() && !prefetch_thread_finished_ &&
|
||||||
|
|
@ -283,6 +286,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||||
IteratorStateReader* reader) override {
|
IteratorStateReader* reader) override {
|
||||||
mutex_lock input_l(input_mu_);
|
mutex_lock input_l(input_mu_);
|
||||||
mutex_lock l(*mu_);
|
mutex_lock l(*mu_);
|
||||||
|
DCHECK(!prefetch_thread_);
|
||||||
DCHECK(buffer_.empty());
|
DCHECK(buffer_.empty());
|
||||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||||
size_t buffer_size;
|
size_t buffer_size;
|
||||||
|
|
@ -315,6 +319,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||||
}
|
}
|
||||||
RecordBufferEnqueue(ctx, buffer_element.value);
|
RecordBufferEnqueue(ctx, buffer_element.value);
|
||||||
}
|
}
|
||||||
|
if (ctx->warm_start()) {
|
||||||
|
TF_RETURN_IF_ERROR(EnsureThreadsStarted(ctx));
|
||||||
|
}
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -458,7 +465,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
|
Status EnsureThreadsStarted(IteratorContext* ctx)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||||
if (!prefetch_thread_) {
|
if (!prefetch_thread_) {
|
||||||
std::shared_ptr<IteratorContext> new_ctx =
|
std::shared_ptr<IteratorContext> new_ctx =
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests for the static tf.data optimizations."""
|
"""Tests for the static tf.data optimizations."""
|
||||||
import functools
|
import functools
|
||||||
|
import time
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -32,6 +33,7 @@ from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -245,6 +247,7 @@ class OptimizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
options.experimental_optimization.apply_default_optimizations = False
|
options.experimental_optimization.apply_default_optimizations = False
|
||||||
options.experimental_optimization.noop_elimination = True
|
options.experimental_optimization.noop_elimination = True
|
||||||
options.experimental_optimization.map_and_batch_fusion = True
|
options.experimental_optimization.map_and_batch_fusion = True
|
||||||
|
options.experimental_optimization.warm_start = False
|
||||||
optimized_dataset = unoptimized_dataset.with_options(options)
|
optimized_dataset = unoptimized_dataset.with_options(options)
|
||||||
optimized_it = dataset_ops.make_initializable_iterator(optimized_dataset)
|
optimized_it = dataset_ops.make_initializable_iterator(optimized_dataset)
|
||||||
|
|
||||||
|
|
@ -266,6 +269,40 @@ class OptimizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
except errors.OutOfRangeError:
|
except errors.OutOfRangeError:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@combinations.generate(
    combinations.times(
        test_base.eager_only_combinations(),
        combinations.combine(warm_start=[True, False]),
    )
)
def testOptimizationWarmStart(self, warm_start):
  """Checks whether the pipeline runs before any element is requested.

  An iterator is created over `range(10).map(...).prefetch(10)` but never
  advanced. The mapped function bumps a counter, so the counter tells us
  whether any element was processed just by creating the iterator.

  Args:
    warm_start: Value assigned to `experimental_optimization.warm_start`.
  """
  invocations = variables.Variable(0)

  def count_and_forward(element):
    # Side effect used to observe that the map function was executed.
    invocations.assign_add(1)
    return element

  opts = options_lib.Options()
  opts.experimental_optimization.apply_default_optimizations = False
  opts.experimental_optimization.warm_start = warm_start

  pipeline = dataset_ops.Dataset.range(10).with_options(opts)
  pipeline = pipeline.map(count_and_forward).prefetch(10)
  # Only the iterator is created; no element is ever requested from it.
  unused_iter = iter(pipeline)

  if not warm_start:
    # Without warm start nothing may have run at this point.
    self.assertEqual(invocations.numpy(), 0)
    return

  # With warm start the work happens in the background, so poll with growing
  # sleeps until the counter moves instead of asserting immediately.
  for backoff_secs in (0.1, 0.2, 0.5, 2, 5, 10):
    if invocations.numpy() != 0:
      break
    time.sleep(backoff_secs)
  self.assertGreater(invocations.numpy(), 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
||||||
|
|
@ -563,7 +563,7 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||||
job_gc_check_interval_ms=50,
|
job_gc_check_interval_ms=50,
|
||||||
job_gc_timeout_ms=20,
|
job_gc_timeout_ms=20,
|
||||||
data_transfer_protocol=self._get_data_transfer_protocol())
|
data_transfer_protocol=self._get_data_transfer_protocol())
|
||||||
num_elements = 10
|
num_elements = 1000
|
||||||
it1 = iter(
|
it1 = iter(
|
||||||
self.make_distributed_range_dataset(
|
self.make_distributed_range_dataset(
|
||||||
num_elements,
|
num_elements,
|
||||||
|
|
|
||||||
|
|
@ -21,14 +21,21 @@ from tensorflow.python.data.experimental.kernel_tests.service import test_base a
|
||||||
from tensorflow.python.data.experimental.ops import data_service_ops
|
from tensorflow.python.data.experimental.ops import data_service_ops
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.data.ops import test_mode
|
||||||
from tensorflow.python.framework import combinations
|
from tensorflow.python.framework import combinations
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
|
|
||||||
class LocalWorkersTest(data_service_test_base.TestBase, parameterized.TestCase):
|
class LocalWorkersTest(data_service_test_base.TestBase, parameterized.TestCase):
|
||||||
"""Tests reading from local workers if `target_workers` is `local`."""
|
"""Tests reading from local workers if `target_workers` is `local`."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
# TODO(b/268586701): Enable `warm_start` for `local_workers_test`.
|
||||||
|
test_mode.toggle_test_mode(False)
|
||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testOneLocalWorker(self):
|
def testOneLocalWorker(self):
|
||||||
cluster = multi_process_cluster.MultiProcessCluster(
|
cluster = multi_process_cluster.MultiProcessCluster(
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ from tensorflow.python.data.experimental.kernel_tests.service import test_base a
|
||||||
from tensorflow.python.data.experimental.ops import distributed_save_op
|
from tensorflow.python.data.experimental.ops import distributed_save_op
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.data.ops import test_mode
|
||||||
from tensorflow.python.framework import combinations
|
from tensorflow.python.framework import combinations
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
@ -62,6 +63,8 @@ class SnapshotFtTest(data_service_test_base.TestBase, parameterized.TestCase):
|
||||||
tempfile.mkdtemp(dir=self.get_temp_dir()),
|
tempfile.mkdtemp(dir=self.get_temp_dir()),
|
||||||
"snapshot_ft_test",
|
"snapshot_ft_test",
|
||||||
)
|
)
|
||||||
|
# TODO(b/268586560): Enable `warm_start` for `snapshot_ft_test`.
|
||||||
|
test_mode.toggle_test_mode(False)
|
||||||
|
|
||||||
# This "manual" setup function is needed due to some bad interaction between
|
# This "manual" setup function is needed due to some bad interaction between
|
||||||
# `setUp` and `combinations` that causes the dataset to be out-of-scope.
|
# `setUp` and `combinations` that causes the dataset to be out-of-scope.
|
||||||
|
|
|
||||||
|
|
@ -1102,6 +1102,7 @@ py_library(
|
||||||
"//tensorflow/python/data/experimental/ops:lookup_ops",
|
"//tensorflow/python/data/experimental/ops:lookup_ops",
|
||||||
"//tensorflow/python/data/experimental/ops:random_access",
|
"//tensorflow/python/data/experimental/ops:random_access",
|
||||||
"//tensorflow/python/data/ops:dataset_ops",
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
|
"//tensorflow/python/data/ops:test_mode",
|
||||||
"//tensorflow/python/data/util:nest",
|
"//tensorflow/python/data/util:nest",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/ops/ragged",
|
"//tensorflow/python/ops/ragged",
|
||||||
|
|
|
||||||
|
|
@ -149,6 +149,7 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
options.experimental_optimization.noop_elimination = True
|
options.experimental_optimization.noop_elimination = True
|
||||||
options.experimental_optimization.parallel_batch = True
|
options.experimental_optimization.parallel_batch = True
|
||||||
options.experimental_optimization.shuffle_and_repeat_fusion = True
|
options.experimental_optimization.shuffle_and_repeat_fusion = True
|
||||||
|
options.experimental_optimization.warm_start = True
|
||||||
options.experimental_slack = True
|
options.experimental_slack = True
|
||||||
options.threading.max_intra_op_parallelism = 30
|
options.threading.max_intra_op_parallelism = 30
|
||||||
options.threading.private_threadpool_size = 40
|
options.threading.private_threadpool_size = 40
|
||||||
|
|
@ -177,6 +178,7 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
dataset_options_pb2.DistributeOptions())
|
dataset_options_pb2.DistributeOptions())
|
||||||
expected_pb.optimization_options.CopyFrom(
|
expected_pb.optimization_options.CopyFrom(
|
||||||
dataset_options_pb2.OptimizationOptions())
|
dataset_options_pb2.OptimizationOptions())
|
||||||
|
expected_pb.optimization_options.warm_start = True
|
||||||
expected_pb.threading_options.CopyFrom(
|
expected_pb.threading_options.CopyFrom(
|
||||||
dataset_options_pb2.ThreadingOptions())
|
dataset_options_pb2.ThreadingOptions())
|
||||||
self.assertProtoEquals(expected_pb, result)
|
self.assertProtoEquals(expected_pb, result)
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import re
|
||||||
from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
|
from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
|
||||||
from tensorflow.python.data.experimental.ops import random_access
|
from tensorflow.python.data.experimental.ops import random_access
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.data.ops import test_mode
|
||||||
from tensorflow.python.data.util import nest
|
from tensorflow.python.data.util import nest
|
||||||
from tensorflow.python.data.util import structure
|
from tensorflow.python.data.util import structure
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
|
@ -72,6 +73,10 @@ def v2_eager_only_combinations():
|
||||||
class DatasetTestBase(test.TestCase):
|
class DatasetTestBase(test.TestCase):
|
||||||
"""Base class for dataset tests."""
|
"""Base class for dataset tests."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
test_mode.toggle_test_mode(True)
|
||||||
|
|
||||||
def assert_op_cancelled(self, op):
|
def assert_op_cancelled(self, op):
|
||||||
with self.assertRaises(errors.CancelledError):
|
with self.assertRaises(errors.CancelledError):
|
||||||
self.evaluate(op)
|
self.evaluate(op)
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,11 @@ py_library(
|
||||||
srcs = ["debug_mode.py"],
|
srcs = ["debug_mode.py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "test_mode",
|
||||||
|
srcs = ["test_mode.py"],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "batch_op",
|
name = "batch_op",
|
||||||
srcs = ["batch_op.py"],
|
srcs = ["batch_op.py"],
|
||||||
|
|
@ -416,6 +421,7 @@ py_library(
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
|
"//tensorflow/python/data/ops:test_mode",
|
||||||
"//tensorflow/python/data/util:options",
|
"//tensorflow/python/data/util:options",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ from absl import logging
|
||||||
|
|
||||||
from tensorflow.core.framework import dataset_options_pb2
|
from tensorflow.core.framework import dataset_options_pb2
|
||||||
from tensorflow.core.framework import model_pb2
|
from tensorflow.core.framework import model_pb2
|
||||||
|
from tensorflow.python.data.ops import test_mode
|
||||||
from tensorflow.python.data.util import options as options_lib
|
from tensorflow.python.data.util import options as options_lib
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
@ -387,6 +388,20 @@ class OptimizationOptions(options_lib.OptionsBase):
|
||||||
docstring="Whether to fuse shuffle and repeat transformations. If None, "
|
docstring="Whether to fuse shuffle and repeat transformations. If None, "
|
||||||
"defaults to True.")
|
"defaults to True.")
|
||||||
|
|
||||||
|
warm_start = options_lib.create_option(
|
||||||
|
name="warm_start",
|
||||||
|
ty=bool,
|
||||||
|
docstring=(
|
||||||
|
"Whether to start background threads of asynchronous transformations"
|
||||||
|
" upon iterator creation (as opposed to upon first call to"
|
||||||
|
" `GetNext`). If None, defaults to False. It should be noted that"
|
||||||
|
" this possibly improves the latency of the initial 'GetNext' call at"
|
||||||
|
" the expense of requiring more memory to hold prefetched elements"
|
||||||
|
" between the time of iterator construction and usage."
|
||||||
|
),
|
||||||
|
default_factory=lambda: True if test_mode.TEST_MODE else False,
|
||||||
|
)
|
||||||
|
|
||||||
def _to_proto(self):
|
def _to_proto(self):
|
||||||
pb = dataset_options_pb2.OptimizationOptions()
|
pb = dataset_options_pb2.OptimizationOptions()
|
||||||
if self.apply_default_optimizations is not None:
|
if self.apply_default_optimizations is not None:
|
||||||
|
|
@ -411,6 +426,8 @@ class OptimizationOptions(options_lib.OptionsBase):
|
||||||
pb.parallel_batch = self.parallel_batch
|
pb.parallel_batch = self.parallel_batch
|
||||||
if self.shuffle_and_repeat_fusion is not None:
|
if self.shuffle_and_repeat_fusion is not None:
|
||||||
pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion
|
pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion
|
||||||
|
if self.warm_start is not None:
|
||||||
|
pb.warm_start = self.warm_start
|
||||||
return pb
|
return pb
|
||||||
|
|
||||||
def _from_proto(self, pb):
|
def _from_proto(self, pb):
|
||||||
|
|
@ -436,6 +453,8 @@ class OptimizationOptions(options_lib.OptionsBase):
|
||||||
self.parallel_batch = pb.parallel_batch
|
self.parallel_batch = pb.parallel_batch
|
||||||
if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
|
if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
|
||||||
self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
|
self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
|
||||||
|
if pb.WhichOneof("optional_warm_start") is not None:
|
||||||
|
self.warm_start = pb.warm_start
|
||||||
|
|
||||||
def _set_mutable(self, mutable):
|
def _set_mutable(self, mutable):
|
||||||
"""Change the mutability value to `mutable` on this options and children."""
|
"""Change the mutability value to `mutable` on this options and children."""
|
||||||
|
|
|
||||||
32
tensorflow/python/data/ops/test_mode.py
Normal file
32
tensorflow/python/data/ops/test_mode.py
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Python test mode enabler.
|
||||||
|
|
||||||
|
Enables test mode for tf.data.
|
||||||
|
|
||||||
|
The test mode can be used to set up custom values for features and
|
||||||
|
experiments as required in the unit tests.
|
||||||
|
|
||||||
|
For example, if `warm_start` feature needs to be enabled exclusively for the
|
||||||
|
unit tests, the tests can enable the test mode using `toggle_test_mode` and
|
||||||
|
the default value of `warm_start` can be set as per the value of `TEST_MODE`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
TEST_MODE = False
|
||||||
|
|
||||||
|
|
||||||
|
def toggle_test_mode(test_mode):
|
||||||
|
global TEST_MODE
|
||||||
|
TEST_MODE = test_mode
|
||||||
|
|
@ -47,6 +47,10 @@ tf_class {
|
||||||
name: "shuffle_and_repeat_fusion"
|
name: "shuffle_and_repeat_fusion"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "warm_start"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,10 @@ tf_class {
|
||||||
name: "shuffle_and_repeat_fusion"
|
name: "shuffle_and_repeat_fusion"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "warm_start"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user