[tf.data] Implementation of an alternative checkpoint protocol, which makes it possible to represent the state of an input pipeline without having to store the contents of internal buffers. The alternative checkpointing logic can be enabled through the `experimental_symbolic_checkpoint` option of `tf.data.Options()`. Note that only a subset of tf.data operations supports the new checkpointing protocol.
PiperOrigin-RevId: 488559450
parent bb398ab732
commit b0b081efd6
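For context, the user-facing switch named in the commit message is the `experimental_symbolic_checkpoint` field of `tf.data.Options()`. A minimal usage sketch (the pipeline and checkpoint path are illustrative, and this assumes a TF build that includes this change):

```python
import tensorflow as tf

# A pipeline built from ops that support the new protocol
# (e.g. filter/batch; see the per-op changes in the diff below).
dataset = tf.data.Dataset.range(100).filter(lambda x: x % 2 == 0).batch(8)

options = tf.data.Options()
options.experimental_symbolic_checkpoint = True  # enable the new protocol
dataset = dataset.with_options(options)

# The iterator can now be checkpointed without serializing the
# contents of internal buffers.
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(iterator=iterator)
path = ckpt.save("/tmp/tf_data_ckpt")  # illustrative path
```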
```diff
@@ -48,6 +48,13 @@
 * Coordination service now works with `dtensor.initialize_accelerator_system`,
   and enabled by default.
 
+* `tf.data`:
+    * Added support for alternative checkpointing protocol which makes it
+      possible to checkpoint the state of the input pipeline without having
+      to store the contents of internal buffers. The new functionality can
+      be enabled through the `experimental_symbolic_checkpoint` option of
+      `tf.data.Options()`.
+
 # Bug Fixes and Other Changes
 
 * <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
```
```diff
@@ -161,9 +161,14 @@ class RootDataset::Iterator : public DatasetIterator<RootDataset> {
 
   ~Iterator() override { cancellation_manager_->StartCancel(); }
 
+  bool SymbolicCheckpointCompatible() const override { return true; }
+
   Status Initialize(IteratorContext* ctx) override {
-    return dataset()->input_->MakeIterator(IteratorContext(CreateParams(ctx)),
-                                           this, prefix(), &input_impl_);
+    IteratorContext iter_ctx(CreateParams(ctx));
+    TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(&iter_ctx, this,
+                                                       prefix(), &input_impl_));
+    ctx->MergeCheckpoint(iter_ctx.checkpoint());
+    return OkStatus();
   }
 
   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
@@ -177,8 +182,10 @@ class RootDataset::Iterator : public DatasetIterator<RootDataset> {
     if (dataset()->params_.autotune) {
       TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
     }
-    TF_RETURN_IF_ERROR(input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
-                                            out_tensors, end_of_sequence));
+    IteratorContext iter_ctx(CreateParams(ctx));
+    TF_RETURN_IF_ERROR(
+        input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence));
+    ctx->MergeCheckpoint(iter_ctx.checkpoint());
     {
       mutex_lock l(mu_);
       end_time_usec_ = std::max(ctx->env()->NowMicros(), end_time_usec_);
@@ -200,8 +207,9 @@ class RootDataset::Iterator : public DatasetIterator<RootDataset> {
 
   Status RestoreInternal(IteratorContext* ctx,
                          IteratorStateReader* reader) override {
-    TF_RETURN_IF_ERROR(
-        RestoreInput(IteratorContext(CreateParams(ctx)), reader, input_impl_));
+    IteratorContext iter_ctx(CreateParams(ctx));
+    TF_RETURN_IF_ERROR(RestoreInput(&iter_ctx, reader, input_impl_));
+    ctx->MergeCheckpoint(iter_ctx.checkpoint());
     return OkStatus();
   }
 
@@ -249,7 +257,6 @@ class RootDataset::Iterator : public DatasetIterator<RootDataset> {
       params.runner =
          RunnerWithMaxParallelism(params.runner, max_intra_op_parallelism_);
     }
-    params.options = &dataset()->options();
     return params;
   }
 
```
```diff
@@ -725,6 +725,7 @@ Status DatasetBase::MakeIterator(
   Status s = (*iterator)->InitializeBase(ctx, parent);
   if (s.ok()) {
     s.Update((*iterator)->Initialize(ctx));
+    ctx->SaveCheckpoint(iterator->get());
   }
   if (!s.ok()) {
     // Reset the iterator to avoid returning an uninitialized iterator.
@@ -934,6 +935,13 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
   }
   out_tensors->clear();
   Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+  ctx->SaveCheckpoint(this);
+  if (!SymbolicCheckpointCompatible()) {
+    ctx->UpdateCheckpointStatus([this]() {
+      return errors::Unimplemented(dataset()->type_string(),
+                                   " does not support symbolic checkpointing.");
+    });
+  }
   if (TF_PREDICT_TRUE(s.ok())) {
     if (TF_PREDICT_TRUE(!*end_of_sequence)) {
       DCHECK_EQ(out_tensors->size(), dataset()->output_dtypes().size());
```
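Note what the `UpdateCheckpointStatus` call above buys: `GetNext()` itself still succeeds for an op that lacks symbolic-checkpoint support, and the `Unimplemented` error is merely recorded in the in-memory checkpoint, surfacing only when the checkpoint is saved. A hedged Python illustration of that deferral (exact exception type and message may vary by TF version):

```python
import tensorflow as tf

options = tf.data.Options()
options.experimental_symbolic_checkpoint = True

# shuffle() reorders elements, so it does not support symbolic checkpointing.
dataset = tf.data.Dataset.range(10).shuffle(10).with_options(options)

iterator = iter(dataset)
next(iterator)  # iteration itself succeeds

ckpt = tf.train.Checkpoint(iterator=iterator)
try:
    ckpt.save("/tmp/sym_ckpt")  # saving surfaces the deferred error
except Exception as e:  # exact exception type may vary
    print(e)  # expected to mention "does not support symbolic checkpointing"
```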
```diff
@@ -19,7 +19,9 @@ limitations under the License.
 #include <memory>
 #include <unordered_map>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/cancellation.h"
@@ -46,10 +48,12 @@ limitations under the License.
 #include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/tsl/platform/errors.h"
+#include "tensorflow/tsl/platform/thread_annotations.h"
 
 // Polymorphic datasets should support all primitive TensorFlow
 // types. Use this macro to expand `m(T)` once for each primitive type
@@ -94,6 +98,7 @@ constexpr char kMetadata[] = "metadata";
 constexpr char kCardinalityAttrForRewrite[] = "_cardinality";
 
 class DatasetBase;
+class IteratorContext;
 class SerializationContext;
 
 inline bool IsTFDataFunction(const FunctionDef& func) {
@@ -167,6 +172,17 @@ class IteratorStateWriter {
 // iterator checkpoints should go through this function.
 std::string FullName(const std::string& prefix, const std::string& name);
 
+// Interface for objects that can be checkpointed.
+class Checkpointable {
+ public:
+  Checkpointable() = default;
+  virtual ~Checkpointable() = default;
+
+  virtual Status Save(SerializationContext* ctx,
+                      IteratorStateWriter* writer) = 0;
+  virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) = 0;
+};
+
 // Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
 class GraphDefBuilderWrapper {
  public:
```
```diff
@@ -381,6 +397,207 @@ class SplitProvider {
 // Returns the runner threadpool size from an OpKernelContext.
 int32_t GetRunnerThreadpoolSizeFromOpKernelContext(OpKernelContext* ctx);
 
+// In-memory representation of a checkpoint. The checkpoint is represented as a
+// collection of key-value pairs and are expected to be written using the
+// `IteratorStateWriter` interface.
+//
+// The implementation is not thread-safe.
+class MemoryCheckpoint : public IteratorStateWriter {
+ public:
+  MemoryCheckpoint() = default;
+
+  // BEGIN implementation of `IteratorStateWriter` interface
+  Status WriteScalar(StringPiece key, int64_t val) override {
+    int_values_[key] = val;
+    return OkStatus();
+  }
+  Status WriteScalar(StringPiece name, StringPiece key, int64_t val) override {
+    return WriteScalar(FullName(string(name), string(key)), val);
+  }
+  Status WriteScalar(StringPiece key, const tstring& val) override {
+    str_values_[key] = val;
+    return OkStatus();
+  }
+  Status WriteScalar(StringPiece name, StringPiece key,
+                     const tstring& val) override {
+    return WriteScalar(FullName(string(name), string(key)), val);
+  }
+  Status WriteTensor(StringPiece key, const Tensor& val) override {
+    tensor_values_[key] = val;
+    return OkStatus();
+  }
+  Status WriteTensor(StringPiece name, StringPiece key,
+                     const Tensor& val) override {
+    return WriteTensor(FullName(string(name), string(key)), val);
+  }
+  // END implementation of `IteratorStateWriter` interface
+
+  // String representation for the in-memory checkpoint suitable for debugging.
+  std::string DebugString() const {
+    std::string result = absl::StrCat("status=", status_.ToString(), "\n");
+    absl::StrAppend(&result, "number of integers: ", int_values_.size(), "\n");
+    for (const auto& pair : int_values_) {
+      absl::StrAppend(&result, "  ", pair.first, " ", pair.second, "\n");
+    }
+    absl::StrAppend(&result, "number of strings: ", str_values_.size(), "\n");
+    for (const auto& pair : str_values_) {
+      absl::StrAppend(&result, "  ", pair.first, " ", pair.second, "\n");
+    }
+    absl::StrAppend(&result, "number of tensors: ", tensor_values_.size(),
+                    "\n");
+    return result;
+  }
+
+  // Returns the status of the in-memory checkpoint.
+  Status GetStatus() const { return status_; }
+
+  // Merges key-value pairs of another checkpoint with this checkpoint. If a
+  // key exists in both checkpoints, then the key-value pair from the `other`
+  // argument is used.
+  void Merge(const MemoryCheckpoint& other) {
+    if (!status_.ok()) {
+      return;
+    }
+
+    if (!other.status_.ok()) {
+      status_ = other.status_;
+      int_values_.clear();
+      str_values_.clear();
+      tensor_values_.clear();
+    }
+
+    for (const auto& pair : other.int_values_) {
+      int_values_[pair.first] = pair.second;
+    }
+    for (const auto& pair : other.str_values_) {
+      str_values_[pair.first] = pair.second;
+    }
+    for (const auto& pair : other.tensor_values_) {
+      tensor_values_[pair.first] = pair.second;
+    }
+  }
+
+  // Stores the in-memory checkpoint to the given writer.
+  Status Save(IteratorStateWriter* writer) const {
+    for (const auto& pair : int_values_) {
+      const auto& key = pair.first;
+      const auto& value = pair.second;
+      TF_RETURN_IF_ERROR(writer->WriteScalar(key, value));
+    }
+    for (const auto& pair : str_values_) {
+      const auto& key = pair.first;
+      const auto& value = pair.second;
+      TF_RETURN_IF_ERROR(writer->WriteScalar(key, value));
+    }
+    for (const auto& pair : tensor_values_) {
+      const auto& key = pair.first;
+      const auto& value = pair.second;
+      TF_RETURN_IF_ERROR(writer->WriteTensor(key, value));
+    }
+    return OkStatus();
+  }
+
+  // Updates the status of the in-memory checkpoint with the given status.
+  void UpdateStatus(Status status) { status_.Update(status); }
+
+ private:
+  Status status_ = OkStatus();
+  absl::flat_hash_map<std::string, int64_t> int_values_;
+  absl::flat_hash_map<std::string, std::string> str_values_;
+  absl::flat_hash_map<std::string, Tensor> tensor_values_;
+};
+
+// Aggregates runtime support needed for dataset and iterator serialization.
+class SerializationContext {
+ public:
+  // Handles the external state according to the external state policy.
+  Status HandleCheckExternalStateStatus(Status s) {
+    if (s.ok()) {
+      return s;
+    }
+    switch (params_.external_state_policy) {
+      case ExternalStatePolicy::POLICY_WARN:
+        LOG(WARNING) << s.ToString();
+        return OkStatus();
+      case ExternalStatePolicy::POLICY_IGNORE:
+        VLOG(2) << "Ignoring error status: " << s.ToString();
+        return OkStatus();
+      case ExternalStatePolicy::POLICY_FAIL:
+        return s;
+      default:
+        return errors::InvalidArgument("Unexpected value of external policy: ",
+                                       params_.external_state_policy);
+    }
+  }
+
+  struct Params {
+    explicit Params() = default;
+
+    explicit Params(OpKernelContext* ctx)
+        : resource_mgr(ctx->resource_manager()),
+          device_name(ctx->device()->attributes().name()) {}
+
+    std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.
+
+    // Indicates what to do if the dataset depends on external state.
+    ExternalStatePolicy external_state_policy =
+        ExternalStatePolicy::POLICY_WARN;
+
+    // Indicates whether the serialization is for rewrites.
+    //
+    // If true:
+    //   * A dataset that doesn't implement serialization is replaced with a
+    //     placeholder returned in `input_list`.
+    //   * Data tensors are replaced with a placeholder returned in
+    //     `input_list`.
+    //   * Datasets that use random seeds should not serialize the random seeds.
+    //     This doesn't affect datasets that use fixed seeds; fixed seeds will
+    //     always be preserved.
+    //   * Cardinality is serialized as an unregistered attribute
+    //     `_cardinality`.
+    // If false:
+    //   * A dataset that doesn't implement serialization should result in an
+    //     error.
+    //   * Data tensors (potentially large) should be serialized.
+    //   * Datasets that use random seeds should serialize the random seeds.
+    bool is_graph_rewrite = false;
+
+    // A resource manager for looking up resources during serialization.
+    ResourceMgr* resource_mgr;
+
+    // The name of the device doing the serialization.
+    std::string device_name;
+
+    // Determines whether checkpointing should represent input pipeline state
+    // symbolically, using cursors into source iterators, or explicitly, by
+    // storing internal state of each iterator.
+    bool symbolic_checkpoint = false;
+  };
+
+  explicit SerializationContext(Params params) : params_(params) {}
+
+  std::vector<std::pair<string, Tensor>>* input_list() {
+    return params_.input_list;
+  }
+
+  ExternalStatePolicy external_state_policy() const {
+    return params_.external_state_policy;
+  }
+
+  bool is_graph_rewrite() const { return params_.is_graph_rewrite; }
+
+  const ResourceMgr* resource_mgr() const { return params_.resource_mgr; }
+
+  const std::string& device_name() const { return params_.device_name; }
+
+  bool symbolic_checkpoint() const { return params_.symbolic_checkpoint; }
+
+ private:
+  Params params_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
+};
+
 // A cut-down version of `OpKernelContext` for running computations in
 // iterators. Note that we cannot simply use `OpKernelContext` here because we
 // might run computation in an iterator whose lifetime is not nested within the
```
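To make the `Merge` rules above concrete: keys from the `other` checkpoint win on conflict, and a failed child checkpoint poisons the merged state while its own key-value pairs still overwrite the (now cleared) maps. A toy Python model of those rules, illustrative only and not TensorFlow API:

```python
class MemoryCheckpointModel:
    """Toy model of MemoryCheckpoint::Merge; one dict stands in for the
    three typed maps (int/string/tensor values)."""

    def __init__(self):
        self.values = {}
        self.status_ok = True  # models `status_`

    def merge(self, other):
        if not self.status_ok:
            return  # this checkpoint already failed; keep the error
        if not other.status_ok:
            # Adopt the child's error and drop our own partial state.
            self.status_ok = False
            self.values.clear()
        # Key-value pairs from `other` win on conflict.
        self.values.update(other.values)
```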
```diff
@@ -405,12 +622,12 @@ class IteratorContext {
           interleave_depth(ctx->interleave_depth()),
           is_restoring(ctx->is_restoring()),
           model(ctx->model()),
+          options(ctx->options()),
           resource_mgr(ctx->resource_mgr()),
           runner(*(ctx->runner())),
           runner_threadpool_size(ctx->runner_threadpool_size()),
           split_providers(ctx->split_providers()),
           stats_aggregator(ctx->stats_aggregator()),
+          symbolic_checkpoint(ctx->symbolic_checkpoint()),
           thread_factory(ctx->thread_factory()),
           thread_pool(ctx->thread_pool()) {}
 
@@ -495,6 +712,9 @@ class IteratorContext {
     // using C++ based implementation for tf.data options (on 4/12/2021).
     std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
 
+    // Indicates whether to use symbolic checkpointing.
+    bool symbolic_checkpoint = false;
+
     // A factory for creating threads to perform blocking work.
     std::shared_ptr<ThreadFactory> thread_factory = nullptr;
 
@@ -532,14 +752,14 @@ class IteratorContext {
     return params_.function_handle_cache;
   }
 
+  const MemoryCheckpoint& checkpoint() const { return checkpoint_; }
+
   int64 interleave_depth() { return params_.interleave_depth; }
 
   bool is_restoring() { return params_.is_restoring; }
 
   const std::shared_ptr<model::Model>& model() { return params_.model; }
 
   const Options* options() { return params_.options; }
 
   ResourceMgr* resource_mgr() { return params_.resource_mgr; }
 
   std::function<void(std::function<void()>)>* runner() {
@@ -556,6 +776,8 @@ class IteratorContext {
     return params_.stats_aggregator;
   }
 
+  bool symbolic_checkpoint() { return params_.symbolic_checkpoint; }
+
   const std::shared_ptr<ThreadFactory>& thread_factory() {
     return params_.thread_factory;
   }
@@ -577,6 +799,40 @@ class IteratorContext {
     }
   }
 
+  // Merges the given checkpoint with the checkpoint of this context.
+  //
+  // The intended use for this API is that methods, such as
+  // `IteratorBase::Initialize`, `IteratorBase::GetNextInternal`, or
+  // `IteratorBase::RestoreInternal` that store data in the in-memory
+  // checkpoint, use a separate instance of `IteratorContext` for a nested
+  // call, then the checkpoint collected by the `IteratorContext` instance
+  // passed into the callee should be merged into the `IteratorContext` of the
+  // caller:
+  //
+  // ```
+  // Status GetNextInternal(IteratorContext* ctx, ...) {
+  //   ...
+  //   IteratorContext nested_ctx(...);
+  //   TF_RETURN_IF_ERROR(input_impl_->GetNext(&nested_ctx, ...));
+  //   ctx->MergeCheckpoint(nested_ctx.checkpoint());
+  //   ...
+  // }
+  // ```
+  void MergeCheckpoint(const MemoryCheckpoint& checkpoint) {
+    if (symbolic_checkpoint()) {
+      checkpoint_.Merge(checkpoint);
+    }
+  }
+
+  // Saves the state of the given iterator into the checkpoint.
+  void SaveCheckpoint(Checkpointable* iterator) {
+    if (symbolic_checkpoint()) {
+      SerializationContext::Params params;
+      params.symbolic_checkpoint = true;
+      SerializationContext ctx(std::move(params));
+      checkpoint_.UpdateStatus(iterator->Save(&ctx, &checkpoint_));
+    }
+  }
+
   std::unique_ptr<Thread> StartThread(const string& name,
                                       std::function<void()> fn) {
     if (params_.thread_factory) {
@@ -587,98 +843,22 @@ class IteratorContext {
     }
   }
 
+  // Updates the status of the checkpoint with the given status.
+  void UpdateCheckpointStatus(std::function<Status()> status_fn) {
+    if (symbolic_checkpoint()) {
+      checkpoint_.UpdateStatus(status_fn());
+    }
+  }
+
  private:
   Params params_;
-};
-
-// Aggregates runtime support needed for dataset and iterator serialization.
-class SerializationContext {
- public:
-  // Handles the external state according to the external state policy.
-  Status HandleCheckExternalStateStatus(Status s) {
-    if (s.ok()) {
-      return s;
-    }
-    switch (params_.external_state_policy) {
-      case ExternalStatePolicy::POLICY_WARN:
-        LOG(WARNING) << s.ToString();
-        return OkStatus();
-      case ExternalStatePolicy::POLICY_IGNORE:
-        VLOG(2) << "Ignoring error status: " << s.ToString();
-        return OkStatus();
-      case ExternalStatePolicy::POLICY_FAIL:
-        return s;
-      default:
-        return errors::InvalidArgument("Unexpected value of external policy: ",
-                                       params_.external_state_policy);
-    }
-  }
-
-  struct Params {
-    explicit Params() {}
-
-    explicit Params(OpKernelContext* ctx)
-        : resource_mgr(ctx->resource_manager()),
-          device_name(ctx->device()->attributes().name()) {}
-
-    std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.
-
-    // Indicates what to do if the dataset depends on external state.
-    ExternalStatePolicy external_state_policy =
-        ExternalStatePolicy::POLICY_WARN;
-
-    // Indicates whether the serialization is for rewrites.
-    //
-    // If true:
-    //   * A dataset that doesn't implement serialization is replaced with a
-    //     placeholder returned in `input_list`.
-    //   * Data tensors are replaced with a placeholder returned in
-    //     `input_list`.
-    //   * Datasets that use random seeds should not serialize the random seeds.
-    //     This doesn't affect datasets that use fixed seeds; fixed seeds will
-    //     always be preserved.
-    //   * Cardinality is serialized as an unregistered attribute
-    //     `_cardinality`.
-    // If false:
-    //   * A dataset that doesn't implement serialization should result in an
-    //     error.
-    //   * Data tensors (potentially large) should be serialized.
-    //   * Datasets that use random seeds should serialize the random seeds.
-    bool is_graph_rewrite = false;
-
-    // A resource manager for looking up resources during serialization.
-    ResourceMgr* resource_mgr;
-
-    // The name of the device doing the serialization.
-    std::string device_name;
-  };
-
-  explicit SerializationContext(Params params) : params_(params) {}
-
-  std::vector<std::pair<string, Tensor>>* input_list() {
-    return params_.input_list;
-  }
-
-  ExternalStatePolicy external_state_policy() const {
-    return params_.external_state_policy;
-  }
-
-  bool is_graph_rewrite() const { return params_.is_graph_rewrite; }
-
-  const ResourceMgr* resource_mgr() const { return params_.resource_mgr; }
-
-  const std::string& device_name() const { return params_.device_name; }
-
-private:
-  Params params_;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
+  MemoryCheckpoint checkpoint_;
 };
 
 // Represents the current position in a range of outputs, where the
 // range of outputs is typically represented by an `DatasetBase`,
 // defined below.
-class IteratorBase {
+class IteratorBase : public Checkpointable {
  public:
   virtual ~IteratorBase() {
     for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
@@ -750,6 +930,9 @@ class IteratorBase {
   // this iterator.
   virtual const string& prefix() const = 0;
 
+  // Indicates whether the iterator is compatible with symbolic checkpointing.
+  virtual bool SymbolicCheckpointCompatible() const { return false; }
+
   // Performs initialization that needs to happen outside of a constructor to
   // properly propagate errors.
   virtual Status Initialize(IteratorContext* ctx) { return OkStatus(); }
@@ -758,7 +941,7 @@ class IteratorBase {
   Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent);
 
   // Saves the state of this iterator.
-  virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
+  Status Save(SerializationContext* ctx, IteratorStateWriter* writer) override {
     int64_t start_us = EnvTime::NowMicros();
     TF_RETURN_IF_ERROR(SaveInternal(ctx, writer));
     VLOG(1) << "Saved " << prefix() << " in "
@@ -766,24 +949,28 @@ class IteratorBase {
     return OkStatus();
   }
 
+  // Restores the state of this iterator.
+  Status Restore(IteratorContext* ctx, IteratorStateReader* reader) override {
+    int64_t start_us = EnvTime::NowMicros();
+    TF_RETURN_IF_ERROR(RestoreInternal(ctx, reader));
+    ctx->SaveCheckpoint(this);
+    VLOG(1) << "Restored " << prefix() << " in "
+            << (EnvTime::NowMicros() - start_us) << "us";
+    return OkStatus();
+  }
+
  protected:
   // Returns a node that models this iterator.
   virtual std::shared_ptr<model::Node> CreateNode(
       IteratorContext* ctx, model::Node::Args args) const = 0;
 
-  // Restores the state of this iterator.
-  virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
-    int64_t start_us = EnvTime::NowMicros();
-    TF_RETURN_IF_ERROR(RestoreInternal(ctx, reader));
-    VLOG(1) << "Restored " << prefix() << " in "
-            << (EnvTime::NowMicros() - start_us) << "us";
-    return OkStatus();
-  }
-
   // This is needed so that sub-classes of IteratorBase can call
   // `SaveInternal` on their input iterators.
   Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer,
                    const std::unique_ptr<IteratorBase>& input) {
+    if (ctx->symbolic_checkpoint()) {
+      return OkStatus();
+    }
     return input->Save(ctx, writer);
   }
 
@@ -947,6 +1134,7 @@ class DatasetBase : public core::RefCounted {
     TF_RETURN_IF_ERROR(MakeIterator(&restore_ctx,
                                     /*parent=*/nullptr, output_prefix, &it));
     TF_RETURN_IF_ERROR(it->Restore(&restore_ctx, reader));
+    ctx->MergeCheckpoint(restore_ctx.checkpoint());
     *iterator = std::move(it);
     return OkStatus();
   }
```
```diff
@@ -173,7 +173,7 @@ enum ExternalStatePolicy {
 // Message stored with Dataset objects to control how datasets are processed and
 // optimized.
 //
-// next: 8
+// next: 9
 message Options {
   // Whether the outputs need to be produced in deterministic order.
   oneof optional_deterministic {
@@ -202,4 +202,18 @@ message Options {
   oneof optional_external_state_policy {
     ExternalStatePolicy external_state_policy = 6;
   }
+  // This option indicates whether to checkpoint input pipeline state
+  // "explicitly", by storing the internal state of iterators for each
+  // tf.data operation, (the default), or "symbolically", by storing metadata
+  // that captures the state of each tf.data operation at the time it processed
+  // the last data seen by tf.data consumer.
+  //
+  // Symbolic checkpoints are expected to be much smaller but not all tf.data
+  // operations are compatible with symbolic checkpointing. In particular,
+  // symbolic checkpointing requires that data is processed in-order and
+  // operations that reorder elements, such as `shuffle()` or non-deterministic
+  // `map()`, are not compatible with symbolic checkpointing.
+  oneof optional_symbolic_checkpoint {
+    bool symbolic_checkpoint = 8;
+  }
 }
```
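A quick way to sanity-check the "cursor" semantics described in the proto comment above is to save mid-stream and verify that iteration resumes where it left off. This sketch assumes an in-order pipeline built from compatible ops (the path and values are illustrative):

```python
import tensorflow as tf

options = tf.data.Options()
options.experimental_symbolic_checkpoint = True
dataset = tf.data.Dataset.range(10).with_options(options)

iterator = iter(dataset)
assert next(iterator).numpy() == 0
assert next(iterator).numpy() == 1

ckpt = tf.train.Checkpoint(iterator=iterator)
path = ckpt.save("/tmp/sym_ckpt")   # state: "2 elements consumed"

assert next(iterator).numpy() == 2  # keep going past the save point
ckpt.restore(path)                  # rewind to the saved cursor
assert next(iterator).numpy() == 2  # iteration resumes at element 2
```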
```diff
@@ -434,6 +434,7 @@ tf_kernel_library(
         "//tensorflow/core/kernels:ops_util",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/profiler/lib:traceme_encode",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/time",
     ],
```
```diff
@@ -180,6 +180,8 @@ class BatchDatasetOp::Dataset : public DatasetBase {
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params) {}
 
+    bool SymbolicCheckpointCompatible() const override { return true; }
+
     Status Initialize(IteratorContext* ctx) override {
       return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
     }
@@ -246,9 +248,9 @@ class BatchDatasetOp::Dataset : public DatasetBase {
     Status SaveInternal(SerializationContext* ctx,
                         IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
-      if (!input_impl_) {
-        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
-      } else {
+      TF_RETURN_IF_ERROR(writer->WriteScalar(
+          full_name(kInputImplEmpty), static_cast<int64_t>(!input_impl_)));
+      if (input_impl_) {
         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
       }
       return OkStatus();
@@ -257,7 +259,10 @@ class BatchDatasetOp::Dataset : public DatasetBase {
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
-      if (!reader->Contains(full_name(kInputImplEmpty))) {
+      int64_t input_empty;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(full_name(kInputImplEmpty), &input_empty));
+      if (!static_cast<bool>(input_empty)) {
         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
       } else {
         input_impl_.reset();
```
```diff
@@ -158,12 +158,16 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase {
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params), i_(0) {}
 
+    bool SymbolicCheckpointCompatible() const override { return true; }
+
     Status Initialize(IteratorContext* ctx) override {
       TF_ASSIGN_OR_RETURN(input_contexts_,
                           CreateInputIteratorContexts(ctx, dataset()));
-      return dataset()->input_->MakeIterator(&input_contexts_[0], this,
-                                             strings::StrCat(prefix(), "[0]"),
-                                             &input_impl_);
+      TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
+          &input_contexts_[0], this, strings::StrCat(prefix(), "[0]"),
+          &input_impl_));
+      ctx->MergeCheckpoint(input_contexts_[0].checkpoint());
+      return OkStatus();
     }
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -177,6 +181,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase {
       while (i_ < 2) {
         TF_RETURN_IF_ERROR(input_impl_->GetNext(&input_contexts_[i_],
                                                 out_tensors, end_of_sequence));
+        ctx->MergeCheckpoint(input_contexts_[i_].checkpoint());
         if (!*end_of_sequence) {
           return OkStatus();
         }
@@ -202,11 +207,11 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase {
                            IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), i_));
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(full_name(kInputImplUninitialized),
+                              static_cast<int64_t>(!input_impl_)));
       if (input_impl_) {
         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
-      } else {
-        TF_RETURN_IF_ERROR(
-            writer->WriteScalar(full_name(kInputImplUninitialized), ""));
       }
       return OkStatus();
     }
@@ -215,7 +220,10 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase {
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &i_));
-      if (reader->Contains(full_name(kInputImplUninitialized))) {
+      int64_t input_uninitialized;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kInputImplUninitialized),
+                                            &input_uninitialized));
+      if (static_cast<bool>(input_uninitialized)) {
         input_impl_.reset();
         return OkStatus();
       }
```
```diff
@@ -96,6 +96,8 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase {
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params), num_elements_(0) {}
 
+    bool SymbolicCheckpointCompatible() const override { return true; }
+
     Status Initialize(IteratorContext* ctx) override {
       return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
     }
```
```diff
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 namespace data {
@@ -43,6 +44,8 @@ namespace experimental {
     DirectedInterleaveDatasetOp::kNumInputDatasets;
 
 constexpr char kCycleLength[] = "cycle_length";
+constexpr char kDataInputImplEmpty[] = "data_input_impl_empty";
+constexpr char kSelectorInputImplEmpty[] = "selector_input_impl_empty";
 
 class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
  public:
@@ -158,18 +161,22 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
         : DatasetIterator<Dataset>(params),
           num_active_inputs_(params.dataset->data_inputs_.size()) {}
 
+    bool SymbolicCheckpointCompatible() const override { return true; }
+
     Status Initialize(IteratorContext* ctx) override {
       mutex_lock l(mu_);
       TF_ASSIGN_OR_RETURN(input_contexts_,
                           CreateInputIteratorContexts(ctx, dataset()));
       TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
           &input_contexts_[0], this, prefix(), &selector_input_impl_));
+      ctx->MergeCheckpoint(input_contexts_[0].checkpoint());
       data_input_impls_.resize(dataset()->data_inputs_.size());
       for (size_t i = 0; i < data_input_impls_.size(); ++i) {
         const DatasetBase* data_input = dataset()->data_inputs_[i];
         TF_RETURN_IF_ERROR(data_input->MakeIterator(
             &input_contexts_[i + 1], this,
             strings::StrCat(prefix(), "[", i, "]"), &data_input_impls_[i]));
+        ctx->MergeCheckpoint(input_contexts_[i + 1].checkpoint());
       }
       return OkStatus();
     }
@@ -188,6 +195,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
         *end_of_sequence = false;
         TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
             &input_contexts_[0], &selector_result, end_of_sequence));
+        ctx->MergeCheckpoint(input_contexts_[0].checkpoint());
         if (*end_of_sequence) {
           ResetInputs();
           return OkStatus();
@@ -205,7 +213,8 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
           TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
               &input_contexts_[selected_input + 1], out_tensors,
               &end_of_selected_input));
-
+          ctx->MergeCheckpoint(
+              input_contexts_[selected_input + 1].checkpoint());
           if (!end_of_selected_input) {
             return OkStatus();
           }
@@ -242,20 +251,19 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
     Status SaveInternal(SerializationContext* ctx,
                         IteratorStateWriter* writer) override {
       mutex_lock l(mu_);
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(full_name(kSelectorInputImplEmpty),
+                              static_cast<int64_t>(!selector_input_impl_)));
       if (selector_input_impl_) {
         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_));
-      } else {
-        TF_RETURN_IF_ERROR(
-            writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
       }
       for (size_t i = 0; i < data_input_impls_.size(); ++i) {
         const auto& data_input_impl = data_input_impls_[i];
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            full_name(strings::StrCat(kDataInputImplEmpty, "[", i, "]")),
+            static_cast<int64_t>(!data_input_impl)));
         if (data_input_impl) {
           TF_RETURN_IF_ERROR(SaveInput(ctx, writer, data_input_impl));
-        } else {
-          TF_RETURN_IF_ERROR(writer->WriteScalar(
-              full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
-              ""));
         }
       }
       return OkStatus();
@@ -264,14 +272,19 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
-      if (!reader->Contains(full_name("selector_input_impl_empty"))) {
+      int64_t input_empty;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(full_name(kSelectorInputImplEmpty), &input_empty));
+      if (!static_cast<bool>(input_empty)) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
       } else {
         selector_input_impl_.reset();
       }
       for (size_t i = 0; i < data_input_impls_.size(); ++i) {
-        if (!reader->Contains(
-                full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) {
+        TF_RETURN_IF_ERROR(reader->ReadScalar(
+            full_name(strings::StrCat(kDataInputImplEmpty, "[", i, "]")),
+            &input_empty));
+        if (!static_cast<bool>(input_empty)) {
           TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
         } else {
           data_input_impls_[i].reset();
```
```diff
@@ -130,6 +130,8 @@ class ListDatasetOp::Dataset : public DatasetBase {
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params) {}
 
+    bool SymbolicCheckpointCompatible() const override { return true; }
+
     Status Initialize(IteratorContext* ctx) override {
       if (ctx->split_providers().empty()) {
         split_provider_ =
```
```diff
@@ -106,6 +106,8 @@ class RandomDatasetOp::Dataset : public DatasetBase {
           parent_generator_(seeds_.first, seeds_.second),
           generator_(&parent_generator_) {}
 
+    bool SymbolicCheckpointCompatible() const override { return true; }
+
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
```
```diff
@@ -29,6 +29,8 @@ namespace data {
 namespace experimental {
 namespace {
 
+constexpr char kInputImplEmpty[] = "input_impl_empty";
+
 class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
  public:
  explicit TakeWhileDatasetOp(OpKernelConstruction* ctx)
@@ -125,6 +127,8 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params) {}
 
+    bool SymbolicCheckpointCompatible() const override { return true; }
+
     Status Initialize(IteratorContext* ctx) override {
       TF_RETURN_IF_ERROR(
           dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
@@ -179,11 +183,10 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
       TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
           dataset()->captured_func_->CheckExternalState()));
       mutex_lock l(mu_);
+      TF_RETURN_IF_ERROR(writer->WriteScalar(
+          full_name(kInputImplEmpty), static_cast<int64_t>(!input_impl_)));
       if (input_impl_) {
         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
-      } else {
-        TF_RETURN_IF_ERROR(
-            writer->WriteScalar(full_name("input_impls_empty"), ""));
       }
       return OkStatus();
     }
@@ -191,7 +194,10 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
-      if (reader->Contains(full_name("input_impls_empty"))) {
+      int64_t input_empty;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(full_name(kInputImplEmpty), &input_empty));
+      if (static_cast<bool>(input_empty)) {
         input_impl_.reset();
       } else {
         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
```
```diff
@@ -41,7 +41,7 @@ namespace data {
 /* static */ constexpr const char* const FilterDatasetOp::kOutputTypes;
 /* static */ constexpr const char* const FilterDatasetOp::kOutputShapes;
 
-constexpr char kInputImplsEmpty[] = "input_impls_empty";
+constexpr char kInputImplEmpty[] = "input_impl_empty";
 constexpr char kFilteredElements[] = "filtered_elements";
 constexpr char kDroppedElements[] = "dropped_elements";
 
@@ -114,6 +114,8 @@ class FilterDatasetOp::Dataset : public DatasetBase {
           filtered_elements_(0),
           dropped_elements_(0) {}
 
+    bool SymbolicCheckpointCompatible() const override { return true; }
+
     Status Initialize(IteratorContext* ctx) override {
       TF_RETURN_IF_ERROR(
           dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
@@ -121,13 +123,12 @@ class FilterDatasetOp::Dataset : public DatasetBase {
           ctx, &instantiated_captured_func_);
     }
 
+    // NOTE(mrry): This method is thread-safe as long as `input_impl_` and `f`
+    // are thread-safe. However, if multiple threads enter this method,
+    // outputs may be observed in a non-deterministic order.
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
-      // NOTE(mrry): This method is thread-safe as long as
-      // `input_impl_` and `f` are thread-safe. However, if multiple
-      // threads enter this method, outputs may be observed in a
-      // non-deterministic order.
       auto stats_aggregator = ctx->stats_aggregator();
       bool matched;
       do {
@@ -204,11 +205,11 @@ class FilterDatasetOp::Dataset : public DatasetBase {
       TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
           dataset()->captured_func_->CheckExternalState()));
       mutex_lock l(mu_);
-      if (input_impl_)
+      TF_RETURN_IF_ERROR(writer->WriteScalar(
+          full_name(kInputImplEmpty), static_cast<int64_t>(!input_impl_)));
+      if (input_impl_) {
         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
-      else
-        TF_RETURN_IF_ERROR(
-            writer->WriteScalar(full_name(kInputImplsEmpty), ""));
+      }
       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kFilteredElements),
                                              filtered_elements_));
       TF_RETURN_IF_ERROR(
@@ -219,10 +220,14 @@ class FilterDatasetOp::Dataset : public DatasetBase {
     Status RestoreInternal(IteratorContext* ctx,
                            IteratorStateReader* reader) override {
       mutex_lock l(mu_);
-      if (reader->Contains(full_name(kInputImplsEmpty)))
+      int64_t input_empty;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(full_name(kInputImplEmpty), &input_empty));
+      if (static_cast<bool>(input_empty)) {
         input_impl_.reset();
-      else
+      } else {
         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+      }
       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kFilteredElements),
                                             &filtered_elements_));
       TF_RETURN_IF_ERROR(
```
```diff
@@ -125,6 +125,8 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params) {}
 
+    bool SymbolicCheckpointCompatible() const override { return true; }
+
     Status Initialize(IteratorContext* ctx) override {
       TF_RETURN_IF_ERROR(
           dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
@@ -145,8 +147,10 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
         // We are currently processing a mapped element, so try to get the
         // next subelement.
         bool end_of_element;
+        auto nested_ctx = MakeNestedIteratorContext(ctx);
         TF_RETURN_IF_ERROR(current_element_iterator_->GetNext(
-            MakeNestedIteratorContext(ctx), out_tensors, &end_of_element));
+            &nested_ctx, out_tensors, &end_of_element));
+        ctx->MergeCheckpoint(nested_ctx.checkpoint());
         if (!end_of_element) {
           // Produce the subelement as output.
           *end_of_sequence = false;
@@ -223,10 +227,15 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
       TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
       mutex_lock l(mu_);
+      TF_RETURN_IF_ERROR(writer->WriteScalar(
+          full_name(kExhausted), static_cast<int64_t>(!input_impl_)));
       if (input_impl_) {
         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
         TF_RETURN_IF_ERROR(
             writer->WriteScalar(full_name(kElementIndex), element_index_));
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            full_name(kCurrentElementIteratorUninitialized),
+            static_cast<int64_t>(!current_element_iterator_)));
         if (current_element_iterator_) {
           TF_RETURN_IF_ERROR(
               writer->WriteScalar(full_name(kInputsSize), inputs_.size()));
@@ -235,12 +244,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
               full_name(strings::StrCat(kInputs, "[", i, "]")), inputs_[i]));
         }
          TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_element_iterator_));
-        } else {
-          TF_RETURN_IF_ERROR(writer->WriteScalar(
-              full_name(kCurrentElementIteratorUninitialized), ""));
         }
-      } else {
-        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kExhausted), ""));
       }
       return OkStatus();
     }
@@ -252,7 +256,10 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
       element_index_ = 0;
       current_element_iterator_.reset();
       inputs_.clear();
-      if (!reader->Contains(full_name(kExhausted))) {
+      int64_t input_exhausted;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(full_name(kExhausted), &input_exhausted));
+      if (!static_cast<bool>(input_exhausted)) {
         TF_RETURN_IF_ERROR(
             dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
@@ -262,8 +269,11 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
               reader->ReadScalar(full_name(kElementIndex), &temp));
           element_index_ = temp;
         }
-        if (!reader->Contains(
-                full_name(kCurrentElementIteratorUninitialized))) {
+        int64_t current_element_iterator_uninitialized;
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(full_name(kCurrentElementIteratorUninitialized),
+                               &current_element_iterator_uninitialized));
+        if (!static_cast<bool>(current_element_iterator_uninitialized)) {
           size_t inputs_size;
           {
             int64_t temp;
@@ -293,17 +303,11 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
     Status BuildCurrentElementIteratorLocked(IteratorContext* ctx,
                                              bool is_get_next)
         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-      if (is_get_next) {
-        return MakeIteratorFromInputElement(
-            ctx, this, inputs_, element_index_++, *instantiated_captured_func_,
-            prefix(), &current_element_iterator_, model_node());
-      } else {
-        return MakeIteratorFromInputElement(
-            ctx, this, inputs_, element_index_++, *instantiated_captured_func_,
-            prefix(), &current_element_iterator_,
-            /*node=*/nullptr);
-      }
+      // NOTE: We intentionally ignore resource modeling outside GetNext().
+      std::shared_ptr<model::Node> node = is_get_next ? model_node() : nullptr;
+      return MakeIteratorFromInputElement(
+          ctx, this, inputs_, element_index_++, *instantiated_captured_func_,
+          prefix(), &current_element_iterator_, node);
     }
 
     mutex mu_;
```
```diff
@@ -47,6 +47,7 @@ constexpr char kEndOfInput[] = "end_of_input";
 constexpr char kNumOpen[] = "num_open";
 constexpr char kArgsSize[] = "args_size";
 constexpr char kArgsList[] = "args_list_";
+constexpr char kCurrentElementsUnitialized[] = "current_elements_uninitialized";
 
 class InterleaveDatasetOp::Dataset : public DatasetBase {
  public:
@@ -131,6 +132,8 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
           current_elements_(params.dataset->cycle_length_),
           args_list_(params.dataset->cycle_length_) {}
 
+    bool SymbolicCheckpointCompatible() const override { return true; }
+
     Status Initialize(IteratorContext* ctx) override {
       TF_RETURN_IF_ERROR(
           dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
@@ -255,9 +258,8 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
           writer->WriteScalar(full_name(kCycleIndex), cycle_index_));
       TF_RETURN_IF_ERROR(
           writer->WriteScalar(full_name(kBlockIndex), block_index_));
-      if (end_of_input_) {
-        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kEndOfInput), ""));
-      }
+      TF_RETURN_IF_ERROR(writer->WriteScalar(
+          full_name(kEndOfInput), static_cast<int64_t>(end_of_input_)));
       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumOpen), num_open_));
       TF_RETURN_IF_ERROR(SaveCurrentElements(ctx, writer));
       return OkStatus();
@@ -273,7 +275,10 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
       cycle_index_ = size_t(cycle_index);
       TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kBlockIndex), &block_index_));
-      if (reader->Contains(full_name(kEndOfInput))) end_of_input_ = true;
+      int64_t end_of_input;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(full_name(kEndOfInput), &end_of_input));
+      end_of_input_ = static_cast<bool>(end_of_input);
       int64_t num_open;
       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumOpen), &num_open));
       num_open_ = size_t(num_open);
@@ -290,6 +295,10 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
                                IteratorStateWriter* writer)
         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       for (int idx = 0; idx < current_elements_.size(); idx++) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            full_name(
+                strings::StrCat(kCurrentElementsUnitialized, "[", idx, "]")),
+            !current_elements_[idx]));
         if (current_elements_[idx]) {
           TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_elements_[idx]));
           TF_RETURN_IF_ERROR(writer->WriteScalar(
@@ -309,8 +318,12 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
                                   IteratorStateReader* reader)
         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       for (int idx = 0; idx < current_elements_.size(); idx++) {
-        if (reader->Contains(
-                full_name(strings::StrCat(kArgsSize, "[", idx, "]")))) {
+        int64_t current_element_uninitialized;
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(full_name(strings::StrCat(
+                                   kCurrentElementsUnitialized, "[", idx, "]")),
+                               &current_element_uninitialized));
+        if (!current_element_uninitialized) {
           int64_t args_size;
           TF_RETURN_IF_ERROR(reader->ReadScalar(
               full_name(strings::StrCat(kArgsSize, "[", idx, "]")),
```
```diff
@@ -65,6 +65,12 @@ const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
 const char kOutputShapes[] = "output_shapes";
 const char kOutputTypes[] = "output_types";
 
+bool SymbolicCheckpointEnabled(const Options& options) {
+  return options.optional_symbolic_checkpoint_case() ==
+             Options::kSymbolicCheckpoint &&
+         options.symbolic_checkpoint();
+}
+
 }  // namespace
 
 /* static */ constexpr const char* const
```
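Note that the helper above checks the proto oneof case before reading the value, so an option that was never set is distinguished from one explicitly set to false, even though both leave the feature disabled. A toy Python model of that tri-state check (illustrative only, not TF API):

```python
def symbolic_checkpoint_enabled(symbolic_checkpoint):
    """Models SymbolicCheckpointEnabled(): `None` stands for an unset
    proto oneof; only an explicit True enables the feature."""
    return symbolic_checkpoint is not None and symbolic_checkpoint

assert symbolic_checkpoint_enabled(True) is True
assert symbolic_checkpoint_enabled(False) is False
assert symbolic_checkpoint_enabled(None) is False
```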
```diff
@@ -100,49 +106,65 @@ Status IteratorResource::GetNext(OpKernelContext* ctx,
     tf_shared_lock l(mu_);
     captured_state = iterator_state_;
   }
-  if (!captured_state->iterator()) {
+  auto iterator = captured_state->iterator();
+  if (!iterator) {
     return errors::FailedPrecondition(
         "GetNext() failed because the iterator has not been initialized. "
         "Ensure that you have run the initializer operation for this iterator "
         "before getting the next element.");
   }
+  auto* dataset = captured_state->dataset();
   IteratorContext::Params params(ctx);
+  params.cancellation_manager = captured_state->cancellation_manager();
   params.flr = captured_state->flr();
   params.function_handle_cache = captured_state->function_handle_cache();
   params.resource_mgr = captured_state->resource_mgr();
+  params.symbolic_checkpoint = SymbolicCheckpointEnabled(dataset->options());
   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
   params.thread_pool = &unbounded_thread_pool_;
-  params.cancellation_manager = captured_state->cancellation_manager();
   std::function<void()> deregister_fn;
   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
       ctx->cancellation_manager(),
       [cm = params.cancellation_manager]() { cm->StartCancel(); },
       &deregister_fn));
   auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
 
+  IteratorContext iter_ctx(std::move(params));
   const absl::Time start_time = metrics_collector_.RecordStart();
-  auto iterator_ = captured_state->iterator();
-  auto status = iterator_->GetNext(IteratorContext(std::move(params)),
-                                   out_tensors, end_of_sequence);
+  auto status = iterator->GetNext(&iter_ctx, out_tensors, end_of_sequence);
   metrics_collector_.RecordStop(start_time, *out_tensors);
+  captured_state->MergeCheckpoint(iter_ctx.checkpoint());
   return status;
 }
 
-Status IteratorResource::Save(SerializationContext* ctx,
+Status IteratorResource::Save(OpKernelContext* ctx,
+                              ExternalStatePolicy external_state_policy,
                               IteratorStateWriter* writer) {
   std::shared_ptr<State> captured_state;
   {
     tf_shared_lock l(mu_);
     captured_state = iterator_state_;
   }
-  auto iterator_ = captured_state->iterator();
-  if (iterator_) {
-    return iterator_->Save(ctx, writer);
-  }
+  auto iterator = captured_state->iterator();
+  if (!iterator) {
     return errors::FailedPrecondition(
         "Save() failed because the iterator has not been initialized. Ensure "
         "that you have run the initializer operation for this iterator before "
         "saving it.");
+  }
+  auto* dataset = captured_state->dataset();
+  if (SymbolicCheckpointEnabled(dataset->options())) {
+    const auto& checkpoint = captured_state->checkpoint();
+    if (!checkpoint.GetStatus().ok()) {
+      return checkpoint.GetStatus();
+    }
+    TF_RETURN_IF_ERROR(checkpoint.Save(writer));
+    return OkStatus();
+  }
+  SerializationContext::Params params(ctx);
+  params.external_state_policy = external_state_policy;
+  params.symbolic_checkpoint = SymbolicCheckpointEnabled(dataset->options());
+  SerializationContext serialization_ctx(params);
+  return iterator->Save(&serialization_ctx, writer);
 }
 
 Status IteratorResource::Restore(OpKernelContext* ctx,
@@ -152,14 +174,14 @@ Status IteratorResource::Restore(OpKernelContext* ctx,
   const DatasetBase* input_dataset;
   {
     tf_shared_lock l(mu_);
-    if (!iterator_state_->iterator()) {
+    auto iterator = iterator_state_->iterator();
+    if (!iterator) {
       return errors::FailedPrecondition(
           "Restore() failed because the iterator has not been initialized. "
           "Ensure that you have run the initializer operation for this "
          "iterator before restoring it.");
     }
-    auto iterator_ = iterator_state_->iterator();
-    dataset = iterator_->dataset();
+    dataset = iterator->dataset();
     // Hang onto a reference until we've created the new iterator, which will
     // then hold its own reference to keep the dataset alive.
     dataset->Ref();
@@ -171,24 +193,27 @@ Status IteratorResource::Restore(OpKernelContext* ctx,
   }
   core::ScopedUnref scoped_unref(dataset);
   IteratorContext::Params params(ctx);
+  params.cancellation_manager = new_state->cancellation_manager();
   params.flr = new_state->flr();
   params.function_handle_cache = new_state->function_handle_cache();
   params.resource_mgr = new_state->resource_mgr();
+  params.symbolic_checkpoint =
+      SymbolicCheckpointEnabled(input_dataset->options());
   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
   params.thread_pool = &unbounded_thread_pool_;
-  params.cancellation_manager = new_state->cancellation_manager();
   std::function<void()> deregister_fn;
   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
       ctx->cancellation_manager(),
       [cm = params.cancellation_manager]() { cm->StartCancel(); },
       &deregister_fn));
   auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
+  IteratorContext iter_ctx(IteratorContext(std::move(params)));
   std::unique_ptr<IteratorBase> iterator_base;
   TF_RETURN_IF_ERROR(dataset->MakeIteratorFromCheckpoint(
-      IteratorContext(std::move(params)), "Iterator", reader, &iterator_base));
+      &iter_ctx, "Iterator", reader, &iterator_base));
   new_state->DowncastAndSetIteratorAndDataset(std::move(iterator_base),
                                               input_dataset);
-
+  new_state->MergeCheckpoint(iter_ctx.checkpoint());
   mutex_lock l(mu_);
   std::swap(iterator_state_, new_state);
   return OkStatus();
@@ -207,28 +232,29 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
 
   // Create new iterator.
   IteratorContext::Params params(ctx);
+  params.cancellation_manager = new_state->cancellation_manager();
   params.flr = new_state->flr();
   params.function_handle_cache = new_state->function_handle_cache();
   params.resource_mgr = new_state->resource_mgr();
+  params.symbolic_checkpoint = SymbolicCheckpointEnabled(dataset->options());
   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
   params.thread_pool = &unbounded_thread_pool_;
-  params.cancellation_manager = new_state->cancellation_manager();
   std::function<void()> deregister_fn;
   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
       ctx->cancellation_manager(),
       [cm = params.cancellation_manager]() { cm->StartCancel(); },
       &deregister_fn));
   auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
 
+  IteratorContext iter_ctx(IteratorContext(std::move(params)));
   std::unique_ptr<IteratorBase> iterator;
   if (ctx->function_library()->device()->device_type() == DEVICE_CPU) {
     DatasetBase* finalized_dataset;
     TF_ASSIGN_OR_RETURN(finalized_dataset, GetFinalizedDataset(ctx, dataset));
-    TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
-        IteratorContext(std::move(params)),
-        /*parent=*/nullptr, "Iterator", &iterator));
+    TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(&iter_ctx,
+                                                       /*parent=*/nullptr,
+                                                       "Iterator", &iterator));
   } else {
-    TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
+    TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx,
                                              /*parent=*/nullptr, "Iterator",
                                              &iterator));
   }
@@ -236,14 +262,28 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
       VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
   TF_RETURN_IF_ERROR(
       VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
 
   new_state->DowncastAndSetIteratorAndDataset(std::move(iterator), dataset);
-
+  new_state->MergeCheckpoint(iter_ctx.checkpoint());
   mutex_lock l(mu_);
   std::swap(iterator_state_, new_state);
   return OkStatus();
 }
 
 void IteratorResource::State::DowncastAndSetIteratorAndDataset(
     std::unique_ptr<IteratorBase> it, const DatasetBase* dataset) {
   iterator_.reset(static_cast<DatasetBaseIterator*>(it.release()));
   if (dataset) {
     dataset->Ref();
     dataset_.reset(const_cast<DatasetBase*>(dataset));
   }
 }
 
+void IteratorResource::State::MergeCheckpoint(const MemoryCheckpoint& other) {
+  if (SymbolicCheckpointEnabled(dataset_->options())) {
+    checkpoint_.Merge(other);
+  }
+}
+
 namespace {
 
 // A helper class that uses a list of IteratorStateVariant objects to represent
@@ -266,10 +306,12 @@ class IteratorVariantSerializer {
 
   // Calls `Save` on the iterator_resource to build up the list of
   // IteratorStateVariant objects.
-  Status InitializeFromIterator(SerializationContext* serialization_ctx,
+  Status InitializeFromIterator(OpKernelContext* ctx,
+                                ExternalStatePolicy external_state_policy,
                                 IteratorResource* iterator_resource) {
     VariantTensorDataWriter writer;
-    TF_RETURN_IF_ERROR(iterator_resource->Save(serialization_ctx, &writer));
+    TF_RETURN_IF_ERROR(
+        iterator_resource->Save(ctx, external_state_policy, &writer));
     std::vector<std::unique_ptr<VariantTensorData>> data;
     writer.ReleaseData(&data);
     variants_.clear();
@@ -989,11 +1031,8 @@ void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
       ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
   core::ScopedUnref unref_iterator(iterator_resource);
   IteratorVariantSerializer serializer;
```
||||
SerializationContext::Params params(ctx);
|
||||
params.external_state_policy = external_state_policy_;
|
||||
SerializationContext serialization_ctx(params);
|
||||
OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
|
||||
iterator_resource));
|
||||
OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(
|
||||
ctx, external_state_policy_, iterator_resource));
|
||||
Tensor* serialized_t;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
|
||||
|
|
|
|||
|
|
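Editor's note: before the per-kernel changes below, a toy model may help. The C++ above wires each iterator's state updates into a MemoryCheckpoint that is merged upward through the iterator context; Save() can then write that merged state instead of serializing internal buffers. The following plain-Python sketch is illustrative only; every name in it is hypothetical, not a TensorFlow internal.

# Toy model of the symbolic checkpoint protocol. All names here are
# illustrative; they are not the TensorFlow C++ classes above.
class Checkpoint:
  """Accumulates {state_key: value} updates observed while iterating."""

  def __init__(self):
    self.state = {}

  def merge(self, other):
    # Later updates win: a downstream iterator folds in whatever state its
    # input reported while producing the element that was just consumed.
    self.state.update(other.state)


class RangeIterator:
  """A source iterator whose entire state is its current position."""

  def __init__(self, n):
    self.n = n
    self.next_index = 0

  def get_next(self, ckpt):
    value = self.next_index
    self.next_index += 1
    update = Checkpoint()
    update.state["range.next_index"] = self.next_index
    ckpt.merge(update)  # Record the position, not any buffered elements.
    return value


ckpt = Checkpoint()
it = RangeIterator(10)
for _ in range(3):
  it.get_next(ckpt)
assert ckpt.state == {"range.next_index": 3}  # Enough to resume at element 3.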
@@ -55,7 +55,8 @@ class IteratorResource : public ResourceBase {
                 bool* end_of_sequence);

  // Saves a checkpoint of the state of the iterator through the given `writer`.
  Status Save(SerializationContext* ctx, IteratorStateWriter* writer);
  Status Save(OpKernelContext* ctx, ExternalStatePolicy external_state_policy,
              IteratorStateWriter* writer);

  // Restores the state of the iterator from a checkpoint created by `Save`.
  Status Restore(OpKernelContext* ctx, IteratorStateReader* reader);

@@ -91,17 +92,6 @@ class IteratorResource : public ResourceBase {

    ~State() { cancellation_manager_.StartCancel(); }

    // Downcasts the given `IteratorBase` to a `DatasetBaseIterator`, and uses
    // it to set the `iterator` and the `dataset` field.
    void DowncastAndSetIteratorAndDataset(std::unique_ptr<IteratorBase> it,
                                          const DatasetBase* dataset) {
      iterator_.reset(static_cast<DatasetBaseIterator*>(it.release()));
      if (dataset) {
        dataset->Ref();
        dataset_.reset(const_cast<DatasetBase*>(dataset));
      }
    }

    std::shared_ptr<FunctionLibraryDefinition> flib_def() { return flib_def_; }

    FunctionLibraryRuntime* flr() { return flr_; }

@@ -120,8 +110,18 @@ class IteratorResource : public ResourceBase {

    DatasetBaseIterator* iterator() { return iterator_.get(); }

    const MemoryCheckpoint& checkpoint() const { return checkpoint_; }

    DatasetBase* dataset() { return dataset_.get(); }

    // Downcasts the given `IteratorBase` to a `DatasetBaseIterator`, and uses
    // it to set the `iterator` and the `dataset` field.
    void DowncastAndSetIteratorAndDataset(std::unique_ptr<IteratorBase> it,
                                          const DatasetBase* dataset);

    // Merges the given checkpoint with the checkpoint of this state.
    void MergeCheckpoint(const MemoryCheckpoint& other);

   private:
    std::shared_ptr<FunctionLibraryDefinition> flib_def_;
    FunctionLibraryRuntime* flr_ = nullptr;  // not owned

@@ -131,6 +131,7 @@ class IteratorResource : public ResourceBase {
    CancellationManager cancellation_manager_;
    std::unique_ptr<DatasetBaseIterator> iterator_;
    core::RefCountPtr<DatasetBase> dataset_;
    MemoryCheckpoint checkpoint_;
  };

  IteratorMetricsCollector metrics_collector_;
@@ -159,6 +159,8 @@ class MapDatasetOp::Dataset : public DatasetBase {
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status Initialize(IteratorContext* ctx) override {
      TF_RETURN_IF_ERROR(
          dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));

@@ -166,14 +168,12 @@ class MapDatasetOp::Dataset : public DatasetBase {
          ctx, &instantiated_captured_func_);
    }

    // NOTE(mrry): This method is thread-safe as long as `input_impl_` and `f`
    // are thread-safe. However, if multiple threads enter this method,
    // outputs may be observed in a non-deterministic order.
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // NOTE(mrry): This method is thread-safe as long as
      // `input_impl_` and `f` are thread-safe. However, if multiple
      // threads enter this method, outputs may be observed in a
      // non-deterministic order.

      std::vector<Tensor> args;
      TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence));
      if (*end_of_sequence) {
@@ -187,6 +187,8 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase {
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

@@ -245,17 +247,21 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase {
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (input_impl_)
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(kExhausted), static_cast<int64_t>(!input_impl_)));
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      else
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kExhausted), ""));
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (reader->Contains(full_name(kExhausted))) {
      int64_t input_exhausted;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kExhausted), &input_exhausted));
      if (static_cast<bool>(input_exhausted)) {
        input_impl_.reset();
      } else {
        TF_RETURN_IF_ERROR(
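A recurring refactor visible here and in several kernels below: save paths stop encoding "input exhausted" as the mere presence of a key and instead always write an explicit 0/1 scalar, so restore reads a value rather than probing with Contains(). A minimal sketch of the idea, using a hypothetical dict-backed store in place of the TF IteratorStateWriter/Reader:

# Hypothetical dict-backed state store standing in for the TF writer/reader.
state = {}


def save_presence_style(input_exhausted):
  # Old style: key presence alone carried the meaning; restore had to
  # check whether the key existed.
  if input_exhausted:
    state["exhausted"] = ""


def save_explicit_style(input_exhausted):
  # New style: an explicit boolean-valued scalar is always written.
  state["exhausted"] = int(input_exhausted)


def restore_explicit_style():
  return bool(state["exhausted"])  # Always readable; no presence probe.


save_explicit_style(False)
assert restore_explicit_style() is False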
@@ -251,6 +251,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
      if (deregister_fn_) deregister_fn_();
    }

    bool SymbolicCheckpointCompatible() const override {
      return deterministic_;
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      interleave_depth_ = ctx->interleave_depth();

@@ -264,8 +268,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
          [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
      IteratorContext::Params params(ctx);
      params.cancellation_manager = cancellation_manager_.get();
      IteratorContext iter_ctx(std::move(params));
      TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          IteratorContext(params), this, prefix(), &input_impl_));
          &iter_ctx, this, prefix(), &input_impl_));
      ctx->MergeCheckpoint(iter_ctx.checkpoint());
      return dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_);
    }

@@ -310,6 +316,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
      if (ctx->symbolic_checkpoint()) {
        return writer->WriteScalar(
            full_name(absl::StrCat(kInvocationResults, "::", kSize)), 0);
      }
      mutex_lock l(*mu_);
      // Wait for all in-flight calls to complete.
      while (num_calls_ > 0) {

@@ -320,9 +330,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
            "Unexpected outstanding calls encountered.");
      }
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(absl::StrCat(prefix(), "::", kInvocationResults),
                              kSize, invocation_results_.size()));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(absl::StrCat(kInvocationResults, "::", kSize)),
          invocation_results_.size()));
      for (size_t i = 0; i < invocation_results_.size(); i++) {
        const auto& result = *(invocation_results_[i]);
        std::string element_prefix =

@@ -336,10 +346,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
              element_prefix, absl::StrCat(kComponent, "[", j, "]"),
              result.return_values[j]));
        }
        if (result.end_of_input) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(element_prefix, kEndOfInput, ""));
        }
            writer->WriteScalar(element_prefix, kEndOfInput,
                                static_cast<int64_t>(result.end_of_input)));
      }
      return OkStatus();
    }

@@ -349,9 +358,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      int64_t invocation_results_size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(absl::StrCat(prefix(), "::", kInvocationResults),
                             kSize, &invocation_results_size));
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          full_name(absl::StrCat(kInvocationResults, "::", kSize)),
          &invocation_results_size));
      DCHECK(invocation_results_.empty());
      for (size_t i = 0; i < invocation_results_size; i++) {
        invocation_results_.push_back(std::make_shared<InvocationResult>());

@@ -378,7 +387,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
              ctx->flr(), element_prefix, absl::StrCat(kComponent, "[", j, "]"),
              &result.return_values.back()));
        }
        result.end_of_input = reader->Contains(element_prefix, kEndOfInput);
        int64_t end_of_input;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(element_prefix, kEndOfInput, &end_of_input));
        result.end_of_input = static_cast<bool>(end_of_input);
        RecordBufferEnqueue(ctx, result.return_values);
        result.notification.Notify();
      }

@@ -417,6 +429,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
      Status status;
      std::vector<Tensor> return_values;
      bool end_of_input = false;
      MemoryCheckpoint checkpoint;
      const int64_t uid;
    };

@@ -466,6 +479,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
      std::vector<Tensor> input_element;
      result->status = input_impl_->GetNext(ctx.get(), &input_element,
                                            &result->end_of_input);
      result->checkpoint = ctx->checkpoint();
      if (result->end_of_input || !result->status.ok()) {
        CallCompleted(ctx, result);
        return;

@@ -515,6 +529,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
        const std::shared_ptr<InvocationResult>& result,
        std::vector<Tensor>* out_tensors,
        bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) {
      ctx->MergeCheckpoint(result->checkpoint);
      if (!result->end_of_input && result->status.ok()) {
        *out_tensors = std::move(result->return_values);
        RecordBufferDequeue(ctx, *out_tensors);
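The parallel map change attaches a checkpoint to every in-flight InvocationResult: the producer captures the input state observed while creating the element, and the consumer merges that state into its own context only when the element is handed out, so a checkpoint taken in between never claims elements the caller has not yet received. A single-threaded Python sketch of that bookkeeping, reusing the toy Checkpoint model from the earlier note (hypothetical names throughout):

import collections

# Hypothetical per-element bookkeeping; the real counterpart is the C++
# InvocationResult with its MemoryCheckpoint field above.
Result = collections.namedtuple("Result", ["value", "checkpoint"])


def produce(iterator, buffer, make_checkpoint):
  ckpt = make_checkpoint()
  value = iterator.get_next(ckpt)  # get_next records the input position.
  buffer.append(Result(value, ckpt))


def consume(buffer, ctx_checkpoint):
  result = buffer.popleft()
  # Commit producer-side state only now, so a checkpoint taken between
  # produce and consume does not skip elements still sitting in the buffer.
  ctx_checkpoint.merge(result.checkpoint)
  return result.value


# Usage with the toy model from the earlier note:
#   buffer = collections.deque()
#   produce(RangeIterator(10), buffer, Checkpoint)
#   value = consume(buffer, ctx_checkpoint=Checkpoint())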
@@ -163,6 +163,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
      if (deregister_fn_) deregister_fn_();
    }

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      interleave_depth_ = ctx->interleave_depth();

@@ -176,8 +178,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
          &deregister_fn_));
      IteratorContext::Params params(ctx);
      params.cancellation_manager = cancellation_manager_.get();
      return dataset()->input_->MakeIterator(IteratorContext(params), this,
                                             prefix(), &input_impl_);
      IteratorContext iter_ctx(params);
      TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          &iter_ctx, this, prefix(), &input_impl_));
      ctx->MergeCheckpoint(iter_ctx.checkpoint());
      return OkStatus();
    }

    Status GetNextInternal(IteratorContext* ctx,

@@ -246,13 +251,17 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      if (ctx->symbolic_checkpoint()) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kBufferSize), 0));
        return OkStatus();
      }
      // Acquire both locks to ensure that the prefetch thread and
      // all GetNext threads are blocked.
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kBufferSize, buffer_.size()));
          writer->WriteScalar(full_name(kBufferSize), buffer_.size()));
      for (size_t i = 0; i < buffer_.size(); i++) {
        auto& buffer_element = buffer_[i];
        TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status));

@@ -279,7 +288,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
      size_t buffer_size;
      {
        int64_t temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kBufferSize, &temp));
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBufferSize), &temp));
        buffer_size = static_cast<size_t>(temp);
      }
      for (size_t i = 0; i < buffer_size; i++) {

@@ -364,6 +373,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
      std::vector<Tensor> value;
      int64_t created_us;
      const uint64 uid;
      MemoryCheckpoint checkpoint;
    };

    int64_t buffer_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {

@@ -422,6 +432,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
          VLOG(2) << "Setting slack_us_: " << slack_us_;
        }
        *out_tensors = std::move(buffer_.front().value);
        ctx->MergeCheckpoint(buffer_.front().checkpoint);
        RecordBufferDequeue(ctx, *out_tensors);
      } else {
        // If status not ok, we still record the dequeue event to make sure each

@@ -506,6 +517,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
                              profiler::kInfo);
        buffer_element.status = input_impl_->GetNext(
            ctx.get(), &buffer_element.value, &end_of_sequence);
        buffer_element.checkpoint = ctx->checkpoint();
      }
      if (buffer_element.status.ok() && end_of_sequence) {
        mutex_lock l(*mu_);
@@ -255,6 +255,8 @@ class RangeDatasetOp::Dataset : public DatasetBase {
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status Initialize(IteratorContext* ctx) override {
      if (ctx->split_providers().empty() || dataset()->replicate_on_split_) {
        counter_ = std::make_unique<RangeCounter>(
@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>

#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"

@@ -140,6 +141,9 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
   public:
    explicit EmptyIterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {

@@ -169,6 +173,8 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
    explicit FiniteIterator(const Params& params)
        : DatasetIterator<Dataset>(params), i_(0) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

@@ -210,9 +216,9 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIteration), i_));
      if (!input_impl_) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
      } else {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(kInputImplEmpty), static_cast<int64_t>(!input_impl_)));
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      }
      return OkStatus();

@@ -222,7 +228,10 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIteration), &i_));
      if (!reader->Contains(full_name(kInputImplEmpty))) {
      int64_t input_empty;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kInputImplEmpty), &input_empty));
      if (static_cast<bool>(!input_empty)) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      } else {
        input_impl_.reset();

@@ -243,6 +252,8 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
          input_impl_(nullptr),
          first_call_(true) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);

@@ -262,9 +273,8 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
      DCHECK(!*end_of_sequence || out_tensors->empty());
      if (first_call_ && *end_of_sequence && ctx->split_providers().empty()) {
        // If the first call to GetNext() fails because the end of sequence
        // has been reached, we terminate the iteration immediately.
        // Otherwise, this iterator would loop infinitely and never produce a
        // value.
        // has been reached, we return EOF. Otherwise, this iterator could
        // loop infinitely and never produce a value.
        input_impl_.reset();
        return OkStatus();
      }

@@ -290,17 +300,21 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (!first_call_)
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(kInputImplEmpty), static_cast<int64_t>(!input_impl_)));
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      else
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kUninitialized), ""));
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (reader->Contains(full_name(kUninitialized))) {
      int64_t input_empty;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kInputImplEmpty), &input_empty));
      if (static_cast<bool>(input_empty)) {
        input_impl_.reset();
        first_call_ = true;
      } else {
@@ -135,6 +135,8 @@ class ShardDatasetOp::Dataset : public DatasetBase {
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params), next_index_(0) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status Initialize(IteratorContext* ctx) override {
      if (dataset()->num_shards_ == kShardHint) {
        return errors::FailedPrecondition(

@@ -221,9 +223,9 @@ class ShardDatasetOp::Dataset : public DatasetBase {
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (!input_impl_) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
      } else {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(kInputImplEmpty), static_cast<int64_t>(!input_impl_)));
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kNextIndex), next_index_));

@@ -234,7 +236,10 @@ class ShardDatasetOp::Dataset : public DatasetBase {
    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (!reader->Contains(full_name(kInputImplEmpty))) {
      int64_t input_empty;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kInputImplEmpty), &input_empty));
      if (!static_cast<bool>(input_empty)) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name(kNextIndex), &next_index_));
@@ -115,6 +115,9 @@ class SkipDatasetOp::Dataset : public DatasetBase {
   public:
    explicit EmptyIterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {

@@ -145,6 +148,8 @@ class SkipDatasetOp::Dataset : public DatasetBase {
    explicit FiniteIterator(const Params& params)
        : DatasetIterator<Dataset>(params), i_(0) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

@@ -191,10 +196,10 @@ class SkipDatasetOp::Dataset : public DatasetBase {
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(kInputImplEmpty), static_cast<int64_t>(!input_impl_)));
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      } else {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
      }
      return OkStatus();
    }

@@ -203,7 +208,10 @@ class SkipDatasetOp::Dataset : public DatasetBase {
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &i_));
      if (!reader->Contains(full_name(kInputImplEmpty))) {
      int64_t input_empty;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kInputImplEmpty), &input_empty));
      if (!static_cast<bool>(input_empty)) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      } else {
        input_impl_.reset();
@@ -108,6 +108,9 @@ class TakeDataset::EmptyIterator : public DatasetIterator<TakeDataset> {
   public:
    explicit EmptyIterator(const Params& params)
        : DatasetIterator<TakeDataset>(params) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      *end_of_sequence = true;

@@ -137,6 +140,8 @@ class TakeDataset::FiniteIterator : public DatasetIterator<TakeDataset> {
    explicit FiniteIterator(const Params& params)
        : DatasetIterator<TakeDataset>(params), i_(0) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

@@ -173,10 +178,10 @@ class TakeDataset::FiniteIterator : public DatasetIterator<TakeDataset> {
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty),
                                             static_cast<int64_t>(!input_impl_)));
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      } else {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
      }
      return OkStatus();
    }

@@ -185,7 +190,10 @@ class TakeDataset::FiniteIterator : public DatasetIterator<TakeDataset> {
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &i_));
      if (!reader->Contains(full_name(kInputImplEmpty))) {
      int64_t input_empty;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kInputImplEmpty), &input_empty));
      if (!static_cast<bool>(input_empty)) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      } else {
        input_impl_.reset();
@@ -121,6 +121,8 @@ class TensorDatasetOp::Dataset : public DatasetBase {
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params), produced_(false) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status Initialize(IteratorContext* ctx) override {
      if (!ctx->split_providers().empty()) {
        TF_ASSIGN_OR_RETURN(split_provider_,

@@ -161,15 +163,17 @@ class TensorDatasetOp::Dataset : public DatasetBase {
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (produced_)
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kProduced), ""));
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kProduced),
                                             static_cast<int64_t>(produced_)));
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      produced_ = reader->Contains(full_name(kProduced));
      int64_t produced;
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kProduced), &produced));
      produced_ = static_cast<bool>(produced);
      return OkStatus();
    }
@@ -151,6 +151,8 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status Initialize(IteratorContext* ctx) override {
      if (ctx->split_providers().empty() || dataset()->replicate_on_split_) {
        split_provider_ = std::make_shared<IndexSplitProvider>(
@@ -164,6 +164,8 @@ class ZipDatasetOp::Dataset : public DatasetBase {
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    bool SymbolicCheckpointCompatible() const override { return true; }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      TF_ASSIGN_OR_RETURN(input_contexts_,

@@ -173,6 +175,7 @@ class ZipDatasetOp::Dataset : public DatasetBase {
        TF_RETURN_IF_ERROR(dataset()->inputs_[i]->MakeIterator(
            &input_contexts_[i], this, strings::StrCat(prefix(), "[", i, "]"),
            &input_impls_[i]));
        ctx->MergeCheckpoint(input_contexts_[i].checkpoint());
      }
      return OkStatus();
    }

@@ -195,6 +198,7 @@ class ZipDatasetOp::Dataset : public DatasetBase {
        bool component_end_of_sequence = false;
        status.Update(input_impl->GetNext(&input_contexts_[i], &input_tensors,
                                          &component_end_of_sequence));
        ctx->MergeCheckpoint(input_contexts_[i].checkpoint());
        *end_of_sequence |= component_end_of_sequence;
        // Even if an error is encountered for one of the components,
        // we need to make sure to advance all components, to keep them in sync.

@@ -209,6 +213,7 @@ class ZipDatasetOp::Dataset : public DatasetBase {
            Status s =
                input_impls_[j]->GetNext(&input_contexts_[j], &input_tensors,
                                         &component_end_of_sequence);
            ctx->MergeCheckpoint(input_contexts_[j].checkpoint());
          }
          break;
        }

@@ -236,11 +241,10 @@ class ZipDatasetOp::Dataset : public DatasetBase {
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (input_impls_.empty()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kInputImplsEmpty), ""));
      } else {
        for (auto& input_impl : input_impls_)
          writer->WriteScalar(full_name(kInputImplsEmpty),
                              static_cast<int64_t>(input_impls_.empty())));
      for (auto& input_impl : input_impls_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl));
      }
      return OkStatus();

@@ -249,7 +253,10 @@ class ZipDatasetOp::Dataset : public DatasetBase {
    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (reader->Contains(full_name(kInputImplsEmpty))) {
      int64_t inputs_empty;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kInputImplsEmpty), &inputs_empty));
      if (static_cast<bool>(inputs_empty)) {
        input_impls_.clear();
      } else {
        DCHECK_EQ(input_impls_.size(), dataset()->inputs_.size());
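The remaining changes update BUILD files and Python tests to exercise the new option. For reference, a hedged sketch of how a user opts in; the APIs are the public tf.data and tf.train ones, and the checkpoint path is illustrative:

import tensorflow as tf

options = tf.data.Options()
options.experimental_symbolic_checkpoint = True

# range, map, and prefetch are among the ops adapted above, so the
# prefetch buffer is not materialized into the checkpoint; only input
# positions are recorded.
dataset = tf.data.Dataset.range(1000).map(lambda x: x * 2).prefetch(10)
dataset = dataset.with_options(options)

iterator = iter(dataset)
ckpt = tf.train.Checkpoint(iterator=iterator)
path = ckpt.save("/tmp/tf_data_ckpt")  # Illustrative path.
ckpt.restore(path)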
@@ -12,6 +12,7 @@ tf_py_test(
        "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
        "@absl_py//absl/testing:parameterized",
    ],
)

@@ -162,9 +163,11 @@ tf_py_test(
        "//tensorflow/python:errors",
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:random_seed",
        "//tensorflow/python:stateless_random_ops",
        "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
        "//tensorflow/python/framework:combinations",
        "//tensorflow/python/framework:constant_op",
        "//tensorflow/python/framework:dtypes",

@@ -189,6 +192,7 @@ tf_py_test(
        "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
        "//tensorflow/python/ops/ragged",
        "//tensorflow/python/ops/ragged:ragged_factory_ops",
        "//tensorflow/python/ops/ragged:ragged_tensor",
@@ -19,6 +19,7 @@ from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.platform import test

@@ -84,16 +85,22 @@ class AssertCardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
class AssertCardinalityCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                      parameterized.TestCase):

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):

    def build_dataset(num_elements):
      return dataset_ops.Dataset.range(num_elements).apply(
  def build_dataset(self, num_elements, options=None):
    dataset = dataset_ops.Dataset.range(num_elements).apply(
        cardinality.assert_cardinality(num_elements))
    if options:
      dataset = dataset.with_options(options)
    return dataset

    verify_fn(self, lambda: build_dataset(200), num_outputs=200)
  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(symbolic_checkpoint=[False, True])))
  def test(self, verify_fn, symbolic_checkpoint):
    options = options_lib.Options()
    options.experimental_symbolic_checkpoint = symbolic_checkpoint
    verify_fn(self, lambda: self.build_dataset(200, options), num_outputs=200)


if __name__ == "__main__":
@@ -19,12 +19,14 @@ import numpy as np
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.platform import test


@@ -330,24 +332,67 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase,
      dataset_ops.Dataset.choose_from_datasets(datasets, choice_dataset=None)


class ChooseFromDatasetsCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                       parameterized.TestCase):

  def _build_dataset(self,
                     num_datasets,
                     num_elements_per_dataset,
                     options=None):
    datasets = [
        dataset_ops.Dataset.range(num_elements_per_dataset)
        for _ in range(num_datasets)
    ]
    indices = []
    for i in range(num_datasets):
      indices = indices + ([i] * num_elements_per_dataset)
    shuffled_indices = stateless_random_ops.stateless_shuffle(
        np.int64(indices), seed=[1, 2])
    choice_dataset = dataset_ops.Dataset.from_tensor_slices(shuffled_indices)
    dataset = dataset_ops.Dataset.choose_from_datasets(datasets, choice_dataset)
    if options:
      dataset = dataset.with_options(options)
    return dataset

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(symbolic_checkpoint=[False, True])))
  def test(self, verify_fn, symbolic_checkpoint):
    options = options_lib.Options()
    options.experimental_symbolic_checkpoint = symbolic_checkpoint
    verify_fn(
        self, lambda: self._build_dataset(5, 20, options), num_outputs=100)


class SampleFromDatasetsCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                       parameterized.TestCase):

  def _build_dataset(self, probs, num_samples):
  def _build_dataset(self, probs, num_samples, options=None):
    datasets = [
        dataset_ops.Dataset.from_tensors(i).repeat(None)
        for i in range(len(probs))
    ]
    dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, probs, seed=1813)
    return dataset.take(num_samples)
    dataset = dataset.take(num_samples)
    if options:
      dataset = dataset.with_options(options)
    return dataset

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(symbolic_checkpoint=[False, True])))
  def test(self, verify_fn, symbolic_checkpoint):
    options = options_lib.Options()
    options.experimental_symbolic_checkpoint = symbolic_checkpoint
    verify_fn(
        self, lambda: self._build_dataset([0.5, 0.5], 100), num_outputs=100)
        self,
        lambda: self._build_dataset([0.5, 0.5], 100, options),
        num_outputs=100)


if __name__ == "__main__":
@@ -22,6 +22,7 @@ from tensorflow.python.data.experimental.ops import random_access
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors

@@ -182,20 +183,30 @@ class FromListRandomAccessTest(test_base.DatasetTestBase,
class FromListCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                             parameterized.TestCase):

  def _build_list_dataset(self, elements):
    return from_list.from_list(elements)
  def _build_list_dataset(self, elements, options=None):
    dataset = from_list.from_list(elements)
    if options:
      dataset = dataset.with_options(options)
    return dataset

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(symbolic_checkpoint=[False, True])))
  def test(self, verify_fn, symbolic_checkpoint):
    # Equal length elements
    elements = [
        np.tile(np.array([[1], [2], [3], [4]]), 20),
        np.tile(np.array([[12], [13], [14], [15]]), 22),
        np.array([37, 38, 39, 40])
    ]
    verify_fn(self, lambda: self._build_list_dataset(elements), num_outputs=3)
    options = options_lib.Options()
    options.experimental_symbolic_checkpoint = symbolic_checkpoint
    verify_fn(
        self,
        lambda: self._build_list_dataset(elements, options),
        num_outputs=3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),

@@ -214,5 +225,6 @@ class FromListCheckpointTest(checkpoint_test_base.CheckpointTestBase,
    verify_fn(
        self, lambda: self._build_list_dataset(dict_elements), num_outputs=3)


if __name__ == "__main__":
  test.main()
@@ -41,6 +41,7 @@ tf_py_test(
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python/data/experimental/ops:batching",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
        "//tensorflow/python/ops/ragged",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",

@@ -159,6 +160,7 @@ tf_py_test(
        "//tensorflow/python:errors",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
        "//tensorflow/python/data/util:nest",
        "//third_party/py/numpy",
    ],

@@ -169,9 +171,12 @@ tf_py_test(
    size = "small",
    srcs = ["counter_test.py"],
    deps = [
        ":checkpoint_test_base",
        ":test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
    ],
)

@@ -216,6 +221,7 @@ tf_py_test(
    size = "small",
    srcs = ["enumerate_test.py"],
    deps = [
        ":checkpoint_test_base",
        ":test_base",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",

@@ -223,6 +229,7 @@ tf_py_test(
        "//tensorflow/python:errors",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
    ],
)

@@ -295,6 +302,7 @@ tf_py_test(
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
        "//tensorflow/python/data/ops:readers",
        "//tensorflow/python/data/util:nest",
        "//tensorflow/python/ops/ragged",

@@ -360,6 +368,7 @@ tf_py_test(
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python/data/experimental/ops:random_access",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
        "//tensorflow/python/ops/ragged",
        "//tensorflow/python/ops/ragged:ragged_factory_ops",
        "//tensorflow/python/ops/ragged:ragged_tensor",

@@ -713,6 +722,7 @@ tf_py_test(
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:util",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],

@@ -750,6 +760,8 @@ tf_py_test(
        "//tensorflow/python:dataset_ops_gen",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
        "@absl_py//absl/testing:parameterized",
    ],
)

@@ -813,6 +825,7 @@ tf_py_test(
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
    ],
)

@@ -861,6 +874,7 @@ tf_py_test(
    name = "repeat_test",
    size = "medium",
    srcs = ["repeat_test.py"],
    shard_count = 2,
    deps = [
        ":checkpoint_test_base",
        ":test_base",

@@ -908,6 +922,7 @@ tf_py_test(
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:errors",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
    ],
)

@@ -947,6 +962,7 @@ tf_py_test(
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
        "//third_party/py/numpy",
    ],
)

@@ -990,6 +1006,7 @@ tf_py_test(
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:options",
        "//third_party/py/numpy",
    ],
)

@@ -1010,7 +1027,7 @@ tf_py_test(
        "//tensorflow/python/data/kernel_tests:checkpoint_test_base",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/data/ops:options",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
@@ -302,23 +302,35 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
class BatchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                          parameterized.TestCase):

  def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
  def _build_dataset(self,
                     multiplier=15.0,
                     tensor_slice_len=2,
                     batch_size=2,
                     options=None):
    components = (np.arange(tensor_slice_len), np.array([[1, 2, 3]]) *
                  np.arange(tensor_slice_len)[:, np.newaxis],
                  np.array(multiplier) * np.arange(tensor_slice_len))

    return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size)
    dataset = dataset_ops.Dataset.from_tensor_slices(components).batch(
        batch_size)
    if options:
      dataset = dataset.with_options(options)
    return dataset

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(symbolic_checkpoint=[False, True])))
  def test(self, verify_fn, symbolic_checkpoint):
    tensor_slice_len = 8
    batch_size = 2
    options = options_lib.Options()
    options.experimental_symbolic_checkpoint = symbolic_checkpoint
    num_outputs = tensor_slice_len // batch_size
    verify_fn(self,
              lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
              num_outputs)
    verify_fn(
        self, lambda: self._build_dataset(15.0, tensor_slice_len, batch_size,
                                          options), num_outputs)

  def _sparse(self, i):
    return sparse_tensor.SparseTensorValue(
@@ -19,6 +19,7 @@ from tensorflow.python.data.experimental.ops import random_access
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors

@@ -162,22 +163,31 @@ class ConcatenateTest(test_base.DatasetTestBase, parameterized.TestCase):
class ConcatenateCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                parameterized.TestCase):

  def _build_concatenate_dataset(self, var_array):
  def _build_concatenate_dataset(self, var_array, options=None):
    input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20),
                        np.tile(np.array([[12], [13], [14], [15]]), 4))
    to_concatenate_components = (np.tile(
        np.array([[5], [6], [7], [8], [9]]), 20), var_array)

    return dataset_ops.Dataset.from_tensor_slices(input_components).concatenate(
    dataset = dataset_ops.Dataset.from_tensor_slices(
        input_components).concatenate(
            dataset_ops.Dataset.from_tensor_slices(to_concatenate_components))
    if options:
      dataset = dataset.with_options(options)
    return dataset

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(symbolic_checkpoint=[False, True])))
  def test(self, verify_fn, symbolic_checkpoint):
    num_outputs = 9
    array = np.tile(np.array([[16], [17], [18], [19], [20]]), 15)
    verify_fn(self, lambda: self._build_concatenate_dataset(array), num_outputs)
    options = options_lib.Options()
    options.experimental_symbolic_checkpoint = symbolic_checkpoint
    verify_fn(self, lambda: self._build_concatenate_dataset(array, options),
              num_outputs)


class ConcatenateRandomAccessTest(test_base.DatasetTestBase,
@@ -15,8 +15,10 @@
"""Tests for `tf.data.Dataset.counter`."""
from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test

@@ -40,5 +42,31 @@ class CounterTest(test_base.DatasetTestBase, parameterized.TestCase):
      self.assertEqual(expected, self.evaluate(get_next()))


class CounterCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                            parameterized.TestCase):

  def _build_counter_dataset(self, start, step, num_outputs, options=None):
    counter_dataset = dataset_ops.Dataset.counter(start, step)
    range_dataset = dataset_ops.Dataset.range(num_outputs)
    dataset = dataset_ops.Dataset.zip((counter_dataset, range_dataset))
    if options:
      dataset = dataset.with_options(options)
    return dataset

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(symbolic_checkpoint=[False, True])))
  def test(self, verify_fn, symbolic_checkpoint):
    num_outputs = 10
    options = options_lib.Options()
    options.experimental_symbolic_checkpoint = symbolic_checkpoint
    verify_fn(
        self, lambda: self._build_counter_dataset(
            start=2, step=10, num_outputs=num_outputs, options=options),
        num_outputs)


if __name__ == "__main__":
  test.main()
@@ -15,8 +15,10 @@
"""Tests for `tf.data.Dataset.enumerate()`."""
from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes

@@ -45,5 +47,29 @@ class EnumerateTest(test_base.DatasetTestBase, parameterized.TestCase):
                                    (21, (b"b", 2, 38.0))])


class EnumerateCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                              parameterized.TestCase):

  def _build_enumerate_dataset(self, start, stop, options=None):
    dataset = dataset_ops.Dataset.range(start, stop).enumerate()
    if options:
      dataset = dataset.with_options(options)
    return dataset

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(symbolic_checkpoint=[False, True])))
  def test(self, verify_fn, symbolic_checkpoint):
    start = 2
    stop = 10
    options = options_lib.Options()
    options.experimental_symbolic_checkpoint = symbolic_checkpoint
    verify_fn(
        self, lambda: self._build_enumerate_dataset(
            start=start, stop=stop, options=options), stop - start)


if __name__ == "__main__":
  test.main()
@@ -19,6 +19,7 @@ import numpy as np
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor

@@ -165,19 +166,27 @@ class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
class FilterCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):

  def _build_filter_range_graph(self, div):
    return dataset_ops.Dataset.range(100).filter(
  def _build_filter_range_dataset(self, div, options=None):
    dataset = dataset_ops.Dataset.range(100).filter(
        lambda x: math_ops.not_equal(math_ops.mod(x, div), 2))
    if options:
      dataset = dataset.with_options(options)
    return dataset

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(symbolic_checkpoint=[False, True])))
  def test(self, verify_fn, symbolic_checkpoint):
    div = 3
    options = options_lib.Options()
    options.experimental_symbolic_checkpoint = symbolic_checkpoint
    num_outputs = sum(x % 3 != 2 for x in range(100))
    verify_fn(self, lambda: self._build_filter_range_graph(div), num_outputs)
    verify_fn(self, lambda: self._build_filter_range_dataset(div, options),
              num_outputs)

  def _build_filter_dict_graph(self):
  def _build_filter_dict_dataset(self):
    return dataset_ops.Dataset.range(10).map(lambda x: {
        "foo": x * 2,
        "bar": x**2

@@ -189,9 +198,9 @@ class FilterCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                         checkpoint_test_base.default_test_combinations()))
  def testDict(self, verify_fn):
    num_outputs = sum((x**2) % 2 == 0 for x in range(10))
    verify_fn(self, self._build_filter_dict_graph, num_outputs)
    verify_fn(self, self._build_filter_dict_dataset, num_outputs)

  def _build_sparse_filter(self):
  def _build_sparse_filter_dataset(self):

    def _map_fn(i):
      return sparse_tensor.SparseTensor(

@@ -207,7 +216,7 @@ class FilterCheckpointTest(checkpoint_test_base.CheckpointTestBase,
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testSparse(self, verify_fn):
    verify_fn(self, self._build_sparse_filter, num_outputs=5)
    verify_fn(self, self._build_sparse_filter_dataset, num_outputs=5)


if __name__ == "__main__":
@@ -22,6 +22,7 @@ from tensorflow.python.client import session
 from tensorflow.python.data.kernel_tests import checkpoint_test_base
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import options as options_lib
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes

@@ -183,16 +184,22 @@ class FlatMapCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                             parameterized.TestCase):
 
   @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         checkpoint_test_base.default_test_combinations()))
-  def test(self, verify_fn):
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(symbolic_checkpoint=[False, True])))
+  def test(self, verify_fn, symbolic_checkpoint):
     # Complicated way of saying range(start, start+25).
     def build_ds(start):
 
       def map_fn(x):
         return dataset_ops.Dataset.range(x, x + 5)
 
-      return dataset_ops.Dataset.range(start, start + 5 * 5, 5).flat_map(map_fn)
+      dataset = dataset_ops.Dataset.range(start, start + 5 * 5, 5)
+      dataset = dataset.flat_map(map_fn)
+      options = options_lib.Options()
+      options.experimental_symbolic_checkpoint = symbolic_checkpoint
+      return dataset.with_options(options)
 
     verify_fn(self, lambda: build_ds(0), num_outputs=25)

@@ -22,6 +22,7 @@ from tensorflow.python.data.experimental.ops import random_access
 from tensorflow.python.data.kernel_tests import checkpoint_test_base
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import options as options_lib
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors

@@ -368,21 +369,27 @@ class FromTensorSlicesRandomAccessTest(test_base.DatasetTestBase,
 class FromTensorSlicesCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                      parameterized.TestCase):
 
-  def _build_tensor_slices_dataset(self, components):
-    return dataset_ops.Dataset.from_tensor_slices(components)
+  def _build_tensor_slices_dataset(self, components, options=None):
+    dataset = dataset_ops.Dataset.from_tensor_slices(components)
+    if options:
+      dataset = dataset.with_options(options)
+    return dataset
 
   @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         checkpoint_test_base.default_test_combinations()))
-  def test(self, verify_fn):
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(symbolic_checkpoint=[False, True])))
+  def test(self, verify_fn, symbolic_checkpoint):
     # Equal length components
     components = (np.tile(np.array([[1], [2], [3], [4]]),
                           20), np.tile(np.array([[12], [13], [14], [15]]),
                                        22), np.array([37.0, 38.0, 39.0, 40.0]))
 
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
     verify_fn(
         self,
-        lambda: self._build_tensor_slices_dataset(components),
+        lambda: self._build_tensor_slices_dataset(components, options),
         num_outputs=4)
 
   @combinations.generate(

@@ -302,17 +302,24 @@ class FromTensorsTest(test_base.DatasetTestBase, parameterized.TestCase):
 class FromTensorsCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                 parameterized.TestCase):
 
-  def _build_tensor_dataset(self, variable_array):
+  def _build_tensor_dataset(self, variable_array, options=None):
     components = (variable_array, np.array([1, 2, 3]), np.array(37.0))
 
-    return dataset_ops.Dataset.from_tensors(components)
+    dataset = dataset_ops.Dataset.from_tensors(components)
+    if options:
+      dataset = dataset.with_options(options)
+    return dataset
 
   @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         checkpoint_test_base.default_test_combinations()))
-  def test(self, verify_fn):
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(symbolic_checkpoint=[False, True])))
+  def test(self, verify_fn, symbolic_checkpoint):
     arr = np.array(1)
-    verify_fn(self, lambda: self._build_tensor_dataset(arr), num_outputs=1)
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
+    verify_fn(
+        self, lambda: self._build_tensor_dataset(arr, options), num_outputs=1)
 
 
 class FromTensorsRandomAccessTest(test_base.DatasetTestBase,

@@ -393,19 +393,27 @@ class InterleaveDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
           test_base.default_test_combinations(),
           checkpoint_test_base.default_test_combinations(),
           combinations.combine(
+              symbolic_checkpoint=[False, True],
               cycle_length=2,
               block_length=[1, 3],
               num_parallel_calls=[None, 1, 2])))
-  def test(self, verify_fn, cycle_length, block_length, num_parallel_calls):
+  def test(self, verify_fn, symbolic_checkpoint, cycle_length, block_length,
+           num_parallel_calls):
 
     num_repeats = 2
     input_values = np.array([2, 3], dtype=np.int64)
 
     def _build_dataset():
-      return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
-          num_repeats).interleave(
-              lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
-              cycle_length, block_length, num_parallel_calls)
+      dataset = dataset_ops.Dataset.from_tensor_slices(input_values)
+      dataset = dataset.repeat(num_repeats)
+      dataset = dataset.interleave(
+          lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), cycle_length,
+          block_length, num_parallel_calls)
+      if num_parallel_calls is None:
+        options = options_lib.Options()
+        options.experimental_symbolic_checkpoint = symbolic_checkpoint
+        dataset = dataset.with_options(options)
+      return dataset
 
     num_outputs = np.sum(input_values) * num_repeats
     verify_fn(self, _build_dataset, num_outputs)

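One detail worth noting in the interleave hunk: `_build_dataset` only attaches the options when `num_parallel_calls is None`. The commit message says only a subset of operations supports the new protocol, so presumably parallel interleave, which buffers elements from several input iterators at once, is not symbolically checkpointable and is left on the default path here.
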
@@ -1347,10 +1347,12 @@ class MapCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                         parameterized.TestCase):
 
   @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
+      combinations.times(
+          test_base.default_test_combinations(),
           checkpoint_test_base.default_test_combinations(),
-          combinations.combine(num_parallel_calls=[None, 2])))
-  def testCore(self, verify_fn, num_parallel_calls):
+          combinations.combine(
+              num_parallel_calls=[None, 2], symbolic_checkpoint=[False, True])))
+  def testCore(self, verify_fn, num_parallel_calls, symbolic_checkpoint):
 
     tensor_slice_len = 7
     num_epochs = 2

@@ -1365,8 +1367,11 @@ class MapCheckpointTest(checkpoint_test_base.CheckpointTestBase,
       def _map_fn(x, y, z):
         return math_ops.square(x), math_ops.square(y), math_ops.square(z)
 
-      return (dataset_ops.Dataset.from_tensor_slices(components).map(
-          _map_fn, num_parallel_calls=num_parallel_calls).repeat(num_epochs))
+      dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
+          _map_fn, num_parallel_calls=num_parallel_calls).repeat(num_epochs)
+      options = options_lib.Options()
+      options.experimental_symbolic_checkpoint = symbolic_checkpoint
+      return dataset.with_options(options)
 
     verify_fn(self, _build_ds, tensor_slice_len * num_epochs)

@@ -20,6 +20,7 @@ import numpy as np
 from tensorflow.python.data.kernel_tests import checkpoint_test_base
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import options as options_lib
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes

@@ -349,14 +350,20 @@ class PaddedBatchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                 parameterized.TestCase):
 
   @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         checkpoint_test_base.default_test_combinations()))
-  def test(self, verify_fn):
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(symbolic_checkpoint=[False, True])))
+  def test(self, verify_fn, symbolic_checkpoint):
 
     def build_dataset(seq_lens):
-      return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
+      dataset = dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
           lambda x: array_ops.fill([x], x)).padded_batch(
               batch_size=4, padded_shapes=[-1])
+      options = options_lib.Options()
+      options.experimental_symbolic_checkpoint = symbolic_checkpoint
+      dataset = dataset.with_options(options)
+      return dataset
 
     seq_lens = np.random.randint(1, 20, size=(32,)).astype(np.int32)
     verify_fn(self, lambda: build_dataset(seq_lens), num_outputs=8)

@@ -22,6 +22,7 @@ from tensorflow.python.data.experimental.ops import random_access
 from tensorflow.python.data.kernel_tests import checkpoint_test_base
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import options as options_lib
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
 from tensorflow.python.platform import test

@@ -82,15 +83,21 @@ class PrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
 class PrefetchCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                              parameterized.TestCase):
 
-  def build_dataset(self, seed=10):
-    return dataset_ops.Dataset.range(100).prefetch(10).shuffle(
-        buffer_size=10, seed=seed, reshuffle_each_iteration=False)
+  def build_dataset(self, options=None):
+    dataset = dataset_ops.Dataset.range(100).prefetch(10)
+    if options:
+      dataset = dataset.with_options(options)
+    return dataset
 
   @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         checkpoint_test_base.default_test_combinations()))
-  def test(self, verify_fn):
-    verify_fn(self, self.build_dataset, num_outputs=100)
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(symbolic_checkpoint=[False, True])))
+  def test(self, verify_fn, symbolic_checkpoint):
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
+    verify_fn(self, lambda: self.build_dataset(options), num_outputs=100)
 
 
 class PrefetchRandomAccessTest(test_base.DatasetTestBase,

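Also notable: the rewritten `build_dataset` drops the `shuffle` stage that the old prefetch checkpoint test chained after `prefetch(10)`. That is consistent with the docstring added to `tf.data.Options` later in this change, which states that symbolic checkpointing is not supported for transformations that can reorder elements.
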
@@ -20,6 +20,7 @@ from tensorflow.python.data.experimental.ops import random_access
 from tensorflow.python.data.kernel_tests import checkpoint_test_base
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import options as options_lib
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors

@@ -149,16 +150,23 @@ class RangeTest(test_base.DatasetTestBase, parameterized.TestCase):
 class RangeCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):
 
-  def _build_range_dataset(self, start, stop):
-    return dataset_ops.Dataset.range(start, stop)
+  def _build_range_dataset(self, start, stop, options=None):
+    dataset = dataset_ops.Dataset.range(start, stop)
+    if options:
+      dataset = dataset.with_options(options)
+    return dataset
 
   @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         checkpoint_test_base.default_test_combinations()))
-  def test(self, verify_fn):
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(symbolic_checkpoint=[False, True])))
+  def test(self, verify_fn, symbolic_checkpoint):
     start = 2
     stop = 10
-    verify_fn(self, lambda: self._build_range_dataset(start, stop),
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
+    verify_fn(self, lambda: self._build_range_dataset(start, stop, options),
               stop - start)

@@ -20,6 +20,7 @@ from tensorflow.python.data.experimental.ops import random_access
 from tensorflow.python.data.kernel_tests import checkpoint_test_base
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import options as options_lib
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors

@@ -77,41 +78,83 @@ class RepeatTest(test_base.DatasetTestBase, parameterized.TestCase):
 class RepeatDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                   parameterized.TestCase):
 
-  def _build_repeat_dataset(self, count, take_count=3):
-    components = (np.arange(10),)
-    return dataset_ops.Dataset.from_tensor_slices(components).take(
-        take_count).repeat(count)
+  def _build_repeat_dataset(self,
+                            num_elements,
+                            num_epochs,
+                            num_outputs=None,
+                            options=None):
+    dataset = dataset_ops.Dataset.range(num_elements).repeat(num_epochs)
+    if num_outputs:
+      range_dataset = dataset_ops.Dataset.range(num_outputs)
+      dataset = dataset_ops.Dataset.zip((dataset, range_dataset))
+    if options:
+      dataset = dataset.with_options(options)
+    return dataset
 
   @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         checkpoint_test_base.default_test_combinations()))
-  def testFiniteRepeat(self, verify_fn):
-    count = 10
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(symbolic_checkpoint=[False, True])))
+  def testFiniteRepeat(self, verify_fn, symbolic_checkpoint):
+    num_elements = 10
+    num_epochs = 10
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
     verify_fn(
         self,
-        lambda: self._build_repeat_dataset(count),
-        num_outputs=(3 * count))
+        lambda: self._build_repeat_dataset(
+            num_elements, num_epochs, options=options),
+        num_outputs=(num_elements * num_epochs))
 
   @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         checkpoint_test_base.default_test_combinations()))
-  def testEmptyRepeat(self, verify_fn):
-    verify_fn(self, lambda: self._build_repeat_dataset(0), num_outputs=0)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testInfiniteRepeat(self):
-    self.verify_unused_iterator(
-        lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
-    self.verify_multiple_breaks(
-        lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
-    self.verify_reset_restored_iterator(
-        lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(symbolic_checkpoint=[False, True])))
+  def testEmptyRepeat(self, verify_fn, symbolic_checkpoint):
+    num_elements = 10
+    num_epochs = 0
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
+    verify_fn(
+        self,
+        lambda: self._build_repeat_dataset(
+            num_elements, num_epochs, options=options),
+        num_outputs=0)
 
   @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         checkpoint_test_base.default_test_combinations()))
-  def testInfiniteEmptyRepeat(self, verify_fn):
-    verify_fn(self, lambda: self._build_repeat_dataset(-1, 0), num_outputs=0)
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(symbolic_checkpoint=[False, True])))
+  def testInfiniteRepeat(self, verify_fn, symbolic_checkpoint):
+    num_elements = 10
+    num_epochs = -1
+    num_outputs = 100
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
+    verify_fn(
+        self,
+        lambda: self._build_repeat_dataset(
+            num_elements, num_epochs, num_outputs=num_outputs, options=options),
+        num_outputs=num_outputs)
+
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(symbolic_checkpoint=[False, True])))
+  def testInfiniteEmptyRepeat(self, verify_fn, symbolic_checkpoint):
+    num_elements = 0
+    num_epochs = -1
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
+    verify_fn(
+        self,
+        lambda: self._build_repeat_dataset(
+            num_elements, num_epochs, options=options),
+        num_outputs=0)
 
 
 class RepeatRandomAccessTest(test_base.DatasetTestBase, parameterized.TestCase):

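The reworked `_build_repeat_dataset` bounds an otherwise infinite `repeat(-1)` by zipping it with a finite range, relying on the fact that `Dataset.zip` stops as soon as its shortest input is exhausted. A minimal standalone sketch of that trick (public API assumed):

    import tensorflow as tf

    infinite = tf.data.Dataset.range(10).repeat(-1)  # never ends on its own
    bounded = tf.data.Dataset.zip((infinite, tf.data.Dataset.range(100)))
    assert sum(1 for _ in bounded) == 100  # zip terminates with the shorter input
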
@@ -19,6 +19,7 @@ from tensorflow.python.data.experimental.ops import random_access
 from tensorflow.python.data.kernel_tests import checkpoint_test_base
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import options as options_lib
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
 from tensorflow.python.platform import test

@@ -101,19 +102,25 @@ class ShardTest(test_base.DatasetTestBase, parameterized.TestCase):
 class ShardCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):
 
-  def _build_dataset(self, num_elements, num_shards, index):
-    return dataset_ops.Dataset.range(num_elements).shard(num_shards, index)
+  def _build_dataset(self, num_elements, num_shards, index, options=None):
+    dataset = dataset_ops.Dataset.range(num_elements).shard(num_shards, index)
+    if options:
+      dataset = dataset.with_options(options)
+    return dataset
 
   @combinations.generate(
       combinations.times(
           test_base.default_test_combinations(),
           checkpoint_test_base.default_test_combinations(),
+          combinations.combine(symbolic_checkpoint=[False, True]),
           combinations.combine(
               elems=[10, 100], num_shards=[2, 5], index=[0, 1])))
-  def test(self, verify_fn, elems, num_shards, index):
+  def test(self, verify_fn, symbolic_checkpoint, elems, num_shards, index):
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
     verify_fn(
         self,
-        lambda: self._build_dataset(elems, num_shards, index),
+        lambda: self._build_dataset(elems, num_shards, index, options),
         num_outputs=elems // num_shards)

@@ -20,6 +20,7 @@ from tensorflow.python.data.experimental.ops import random_access
 from tensorflow.python.data.kernel_tests import checkpoint_test_base
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import options as options_lib
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
 from tensorflow.python.platform import test

@@ -50,19 +51,25 @@ class SkipTest(test_base.DatasetTestBase, parameterized.TestCase):
 class SkipDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                 parameterized.TestCase):
 
-  def _build_skip_dataset(self, count):
-    components = (np.arange(10),)
-    return dataset_ops.Dataset.from_tensor_slices(components).skip(count)
+  def _build_skip_dataset(self, count, options=None):
+    dataset = dataset_ops.Dataset.range(100).skip(count)
+    if options:
+      dataset = dataset.with_options(options)
+    return dataset
 
   @combinations.generate(
       combinations.times(
           test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
-          combinations.combine(count=[5], num_outputs=[5]) +
-          combinations.combine(count=[20, 10, -1], num_outputs=[0]) +
-          combinations.combine(count=[0], num_outputs=[10])))
-  def test(self, verify_fn, count, num_outputs):
-    verify_fn(self, lambda: self._build_skip_dataset(count), num_outputs)
+          combinations.combine(symbolic_checkpoint=[False, True]),
+          combinations.combine(count=[50], num_outputs=[50]) +
+          combinations.combine(count=[200, 100, -1], num_outputs=[0]) +
+          combinations.combine(count=[0], num_outputs=[100])))
+  def test(self, verify_fn, count, num_outputs, symbolic_checkpoint):
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
+    verify_fn(self, lambda: self._build_skip_dataset(count, options),
+              num_outputs)
 
 
 class SkipRandomAccessTest(test_base.DatasetTestBase, parameterized.TestCase):

@@ -20,6 +20,7 @@ from tensorflow.python.data.experimental.ops import random_access
 from tensorflow.python.data.kernel_tests import checkpoint_test_base
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import options as options_lib
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
 from tensorflow.python.platform import test

@@ -49,19 +50,25 @@ class TakeTest(test_base.DatasetTestBase, parameterized.TestCase):
 class TakeDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                 parameterized.TestCase):
 
-  def _build_take_dataset(self, count):
-    components = (np.arange(10),)
-    return dataset_ops.Dataset.from_tensor_slices(components).take(count)
+  def _build_take_dataset(self, count, options=None):
+    dataset = dataset_ops.Dataset.range(100).take(count)
+    if options:
+      dataset = dataset.with_options(options)
+    return dataset
 
   @combinations.generate(
       combinations.times(
           test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
-          combinations.combine(count=[5], num_outputs=[5]) +
-          combinations.combine(count=[20, 10, -1], num_outputs=[10]) +
+          combinations.combine(symbolic_checkpoint=[False, True]),
+          combinations.combine(count=[50], num_outputs=[50]) +
+          combinations.combine(count=[200, 100, -1], num_outputs=[100]) +
           combinations.combine(count=[0], num_outputs=[0])))
-  def test(self, verify_fn, count, num_outputs):
-    verify_fn(self, lambda: self._build_take_dataset(count), num_outputs)
+  def test(self, verify_fn, symbolic_checkpoint, count, num_outputs):
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
+    verify_fn(self, lambda: self._build_take_dataset(count, options),
+              num_outputs)
 
 
 class TakeRandomAccessTest(test_base.DatasetTestBase, parameterized.TestCase):

@@ -19,6 +19,7 @@ import numpy as np
 from tensorflow.python.data.kernel_tests import checkpoint_test_base
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import options as options_lib
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors

@@ -127,17 +128,24 @@ class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
 class TakeWhileCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                               parameterized.TestCase):
 
-  def _build_dataset(self, num_elements, upper_bound):
-    return dataset_ops.Dataset.range(num_elements).take_while(
-        predicate=lambda x: x < upper_bound)
+  def _build_dataset(self, num_elements, upper_bound, options=None):
+    dataset = dataset_ops.Dataset.range(num_elements)
+    dataset = dataset.take_while(predicate=lambda x: x < upper_bound)
+    if options:
+      dataset = dataset.with_options(options)
+    return dataset
 
   @combinations.generate(
       combinations.times(
           test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
+          combinations.combine(symbolic_checkpoint=[False, True]),
           combinations.combine(num_elements=[10, 23], upper_bound=[10, 23])))
-  def test(self, verify_fn, num_elements, upper_bound):
-    verify_fn(self, lambda: self._build_dataset(num_elements, upper_bound),
+  def test(self, verify_fn, symbolic_checkpoint, num_elements, upper_bound):
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
+    verify_fn(self,
+              lambda: self._build_dataset(num_elements, upper_bound, options),
               min(num_elements, upper_bound))

@@ -22,6 +22,7 @@ from tensorflow.python.data.experimental.ops import random_access
 from tensorflow.python.data.kernel_tests import checkpoint_test_base
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import options as options_lib
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import tensor_shape

@@ -137,7 +138,7 @@ class ZipTest(test_base.DatasetTestBase, parameterized.TestCase):
 class ZipCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                         parameterized.TestCase):
 
-  def _build_dataset(self, arr):
+  def _build_dataset(self, arr, options=None):
     components = [
         np.tile(np.array([[1], [2], [3], [4]]), 20),
         np.tile(np.array([[12], [13], [14], [15]]), 22),

@@ -147,16 +148,22 @@ class ZipCheckpointTest(checkpoint_test_base.CheckpointTestBase,
         dataset_ops.Dataset.from_tensor_slices(component)
         for component in components
     ]
-    return dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
+    dataset = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
+    if options:
+      dataset = dataset.with_options(options)
+    return dataset
 
   @combinations.generate(
       combinations.times(
           test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
-          combinations.combine(elements=[[37.0, 38.0, 39.0, 40.0], [1.0, 2.0]]))
-  )
-  def test(self, verify_fn, elements):
-    verify_fn(self, lambda: self._build_dataset(elements), len(elements))
+          combinations.combine(elements=[[37.0, 38.0, 39.0, 40.0], [1.0, 2.0]]),
+          combinations.combine(symbolic_checkpoint=[False, True])))
+  def test(self, verify_fn, elements, symbolic_checkpoint):
+    options = options_lib.Options()
+    options.experimental_symbolic_checkpoint = symbolic_checkpoint
+    verify_fn(self, lambda: self._build_dataset(elements, options),
+              len(elements))
 
 
 class ZipRandomAccessTest(test_base.DatasetTestBase, parameterized.TestCase):

@@ -24,6 +24,9 @@ py_library(
 py_library(
     name = "counter_op",
     srcs = ["counter_op.py"],
+    deps = [
+        "//third_party/py/numpy",
+    ],
 )
 
 py_library(

@@ -14,13 +14,17 @@
 # ==============================================================================
 """The implementation of `tf.data.Dataset.counter`."""
 
+import numpy as np
+
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 
 
 def counter(start, step, dtype, name=None):
   with ops.name_scope("counter"):
     start = ops.convert_to_tensor(start, dtype=dtype, name="start")
     step = ops.convert_to_tensor(step, dtype=dtype, name="step")
-    return (dataset_ops.Dataset.from_tensors(0, name=name).repeat(None).scan(
-        start, lambda state, _: (state + step, state)))
+    min_value = np.iinfo(dtypes.int64.as_numpy_dtype).min
+    max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
+    stop = max_value if step >= 0 else min_value
+    return dataset_ops.Dataset.range(
+        start, stop, step, output_type=dtype, name=name)

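The rewritten `counter` drops the `from_tensors(...).repeat(...).scan(...)` pipeline in favor of a plain `Dataset.range` whose stop value saturates at the int64 extreme in the direction of travel. A quick sketch of the resulting behavior through the public wrapper (assuming, as seems likely, that `tf.data.experimental.Counter` routes through this helper):

    import tensorflow as tf

    ds = tf.data.experimental.Counter(start=5, step=-1)  # counts down toward int64 min
    print(list(ds.take(3).as_numpy_iterator()))  # [5, 4, 3]
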
@@ -5130,7 +5130,7 @@ class RangeDataset(DatasetSource):
       self._stop = self._build_tensor(args[1], "stop")
       self._step = self._build_tensor(args[2], "step")
     else:
-      raise ValueError(f"Invalid `args`. The lenght of `args` should be "
+      raise ValueError(f"Invalid `args`. The length of `args` should be "
                        f"between 1 and 3 but was {len(args)}.")
     if "output_type" in kwargs:
       self._output_type = kwargs["output_type"]

@@ -571,6 +571,17 @@ class Options(options_lib.OptionsBase):
       "frequency is determined by the number of devices attached to this "
       "input pipeline. If None, defaults to False.")
 
+  experimental_symbolic_checkpoint = options_lib.create_option(
+      name="experimental_symbolic_checkpoint",
+      ty=bool,
+      docstring="Whether to checkpoint internal input pipeline state "
+      "maintaining cursors into data sources that identify last "
+      "element(s) produced as output to the tf.data consumer. This "
+      "is alternative to the default 'explicit' checkpointing which "
+      "stores the internal input pipeline state in the checkpoint. "
+      "Note that symbolic checkpointing is not supported for "
+      "transformations that can reorder elements.")
+
   experimental_threading = options_lib.create_option(
       name="experimental_threading",
       ty=ThreadingOptions,

@@ -622,6 +633,8 @@ class Options(options_lib.OptionsBase):
     pb.optimization_options.CopyFrom(self.experimental_optimization._to_proto())  # pylint: disable=protected-access
     if self.experimental_slack is not None:
       pb.slack = self.experimental_slack
+    if self.experimental_symbolic_checkpoint is not None:
+      pb.symbolic_checkpoint = self.experimental_symbolic_checkpoint
     pb.threading_options.CopyFrom(self.threading._to_proto())  # pylint: disable=protected-access
     return pb

@@ -637,6 +650,8 @@ class Options(options_lib.OptionsBase):
     self.experimental_optimization._from_proto(pb.optimization_options)  # pylint: disable=protected-access
     if pb.WhichOneof("optional_slack") is not None:
       self.experimental_slack = pb.slack
+    if pb.WhichOneof("optional_symbolic_checkpoint") is not None:
+      self.experimental_symbolic_checkpoint = pb.symbolic_checkpoint
     self.threading._from_proto(pb.threading_options)  # pylint: disable=protected-access
 
   def _set_mutable(self, mutable):

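The two small `_to_proto`/`_from_proto` hunks above wire the new option into the Options-to-proto round trip, so it survives serialization like the other experimental options. A sketch of what that buys, using the internal API exactly as shown in this diff:

    from tensorflow.python.data.ops import options as options_lib

    opts = options_lib.Options()
    opts.experimental_symbolic_checkpoint = True
    pb = opts._to_proto()  # sets the optional symbolic_checkpoint proto field
    restored = options_lib.Options()
    restored._from_proto(pb)
    assert restored.experimental_symbolic_checkpoint is True
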
@@ -31,6 +31,10 @@ tf_class {
     name: "experimental_slack"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "experimental_symbolic_checkpoint"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "experimental_threading"
     mtype: "<type \'property\'>"

@@ -31,6 +31,10 @@ tf_class {
     name: "experimental_slack"
    mtype: "<type \'property\'>"
   }
+  member {
+    name: "experimental_symbolic_checkpoint"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "experimental_threading"
     mtype: "<type \'property\'>"