[tf.data] Iterator Save and Restore for Dataset.from_tensors(..), Dataset.from_tensor_slices(..), and dataset.concatenate(..).

PiperOrigin-RevId: 173971324
A. Unique TensorFlower 2017-10-30 16:56:03 -07:00 committed by TensorFlower Gardener
parent 09f62ab38b
commit 72be26dc82
8 changed files with 450 additions and 24 deletions
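At a high level, this change makes iterator state for these three datasets checkpointable with a standard tf.train.Saver: each dataset can now serialize itself into the GraphDef (GraphDatasetBase / AsGraphDefInternal), and each iterator can save and restore its position (SaveInternal / RestoreInternal). The user-facing entry point exercised by the tests below is iterator_ops.make_saveable_from_iterator. A minimal usage sketch, assuming the contrib import paths from this commit and a hypothetical checkpoint path:

    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import dataset_ops
    from tensorflow.contrib.data.python.ops import iterator_ops

    dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4])
    iterator = dataset.make_initializable_iterator()
    get_next = iterator.get_next()

    # Wrap the iterator state in a SaveableObject so the Saver checkpoints
    # it alongside variables.
    saveable = iterator_ops.make_saveable_from_iterator(iterator)
    tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
    saver = tf.train.Saver()

    with tf.Session() as sess:
      sess.run(iterator.initializer)
      print(sess.run(get_next))  # 1
      print(sess.run(get_next))  # 2
      saver.save(sess, "/tmp/iterator_ckpt")  # hypothetical path

    with tf.Session() as sess:
      saver.restore(sess, "/tmp/iterator_ckpt")
      print(sess.run(get_next))  # resumes at 3; no initializer run needed

Registering the saveable in the SAVEABLE_OBJECTS collection is what lets a plain Saver() pick it up without listing it explicitly, which is exactly the pattern the tests below follow.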

View File

@@ -74,9 +74,12 @@ py_test(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:training",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
@@ -93,6 +96,7 @@ py_test(
],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@@ -104,6 +108,7 @@ py_test(
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:training",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
@@ -241,6 +246,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -248,6 +254,7 @@
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:functional_ops",
"//tensorflow/python:io_ops",
"//tensorflow/python:lookup_ops",
@@ -255,6 +262,7 @@
"//tensorflow/python:random_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//third_party/py/numpy",
@@ -396,10 +404,14 @@ py_test(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
)

View File

@@ -17,13 +17,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
class ConcatenateDatasetTest(test.TestCase):
@@ -129,6 +133,140 @@ class ConcatenateDatasetTest(test.TestCase):
with self.assertRaisesRegexp(TypeError, "have different types"):
input_dataset.concatenate(dataset_to_concatenate)
def _iterator_checkpoint_prefix(self):
return os.path.join(self.get_temp_dir(), "iterator")
def _build_graph(self, input_components, to_concatenate_components):
input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components)
dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
to_concatenate_components)
iterator = input_dataset.concatenate(
dataset_to_concatenate).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
saveable = iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
# TODO(shivaniagrawal): this is non-intuitive; add support in meta_graph.
for t in nest.flatten(get_next):
ops.add_to_collection("get_next", t)
return init_op, get_next
def _testSaveRestoreUtility(self, start, break_range, stop):
path = self._iterator_checkpoint_prefix()
step = 0
meta_filename = path + "-%d.meta" % step
input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
np.array([[12], [13], [14], [15]]), 4))
to_concatenate_components = (np.tile(
np.array([[5], [6], [7], [8], [9]]), 20), np.tile(
np.array([[16], [17], [18], [19], [20]]), 15))
with ops.Graph().as_default() as g:
init_op, get_next = self._build_graph(input_components,
to_concatenate_components)
saver = saver_lib.Saver()
with self.test_session(graph=g) as sess:
sess.run(init_op)
for i in range(start, break_range):
result = sess.run(get_next)
if i < 4:
for component, result_component in zip(input_components, result):
self.assertAllEqual(component[i], result_component)
else:
for component, result_component in zip(to_concatenate_components,
result):
self.assertAllEqual(component[i - 4], result_component)
saver.save(sess, path, step)
with ops.Graph().as_default() as g:
saver = saver_lib.import_meta_graph(meta_filename)
with self.test_session(graph=g) as sess:
get_next = nest.pack_sequence_as(("a", "b"),
ops.get_collection("get_next"))
saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
for i in range(break_range, stop):
result = sess.run(get_next)
if i < 4:
for component, result_component in zip(input_components, result):
self.assertAllEqual(component[i], result_component)
else:
for component, result_component in zip(to_concatenate_components,
result):
self.assertAllEqual(component[i - 4], result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testRestoreAtFirstDataset(self):
start = 0
stop = 9
break_range = 3
self._testSaveRestoreUtility(start, break_range, stop)
def testRestoreAtSecondDataset(self):
start = 0
stop = 9
break_range = 6
self._testSaveRestoreUtility(start, break_range, stop)
def testRestoreAtBetweenDatasets(self):
start = 0
stop = 9
break_range = 4
self._testSaveRestoreUtility(start, break_range, stop)
def testRestoreExhaustedIterator(self):
start = 0
stop = 9
break_range = 9
self._testSaveRestoreUtility(start, break_range, stop)
def testRestoreInModifiedGraph(self):
start = 0
stop = 9
break_range = 6
path = self._iterator_checkpoint_prefix()
step = 0
input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
np.array([[12], [13], [14], [15]]), 4))
to_concatenate_components = (np.tile(
np.array([[5], [6], [7], [8], [9]]), 20), np.tile(
np.array([[16], [17], [18], [19], [20]]), 15))
with ops.Graph().as_default() as g:
init_op, get_next = self._build_graph(input_components,
to_concatenate_components)
saver = saver_lib.Saver(allow_empty=True)
with self.test_session(graph=g) as sess:
sess.run(init_op)
for i in range(start, break_range):
result = sess.run(get_next)
if i < 4:
for component, result_component in zip(input_components, result):
self.assertAllEqual(component[i], result_component)
else:
for component, result_component in zip(to_concatenate_components,
result):
self.assertAllEqual(component[i - 4], result_component)
saver.save(sess, path, step)
new_to_concatenate_components = (np.array([[5], [6], [7], [8], [9]]),
np.array([[16], [17], [18], [19], [20]]))
with ops.Graph().as_default() as g:
init_op, get_next = self._build_graph(input_components,
new_to_concatenate_components)
saver = saver_lib.Saver()
with self.test_session(graph=g) as sess:
saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
for i in range(break_range, stop):
result = sess.run(get_next)
for component, result_component in zip(to_concatenate_components,
result):
self.assertAllEqual(component[i - 4], result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
if __name__ == "__main__":
test.main()
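A note on the restore pattern used throughout these tests: a graph collection can only hold a flat list of tensors, while get_next here is a 2-tuple, so the build step flattens get_next into the "get_next" collection and the restore step re-packs it with nest.pack_sequence_as after import_meta_graph. A minimal sketch of that round trip, with plain Python values standing in for tensors:

    from tensorflow.python.data.util import nest

    structure = ("a", "b")               # shape of the original nest
    flat = nest.flatten(("x_t", "y_t"))  # what goes into the collection
    restored = nest.pack_sequence_as(structure, flat)
    assert restored == ("x_t", "y_t")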

View File

@@ -17,12 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import threading
import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import iterator_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.util import nest
@@ -34,6 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
class DatasetConstructorTest(test.TestCase):
@@ -571,6 +574,136 @@ class DatasetConstructorTest(test.TestCase):
new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
# pylint: enable=protected-access
def _iterator_checkpoint_prefix(self):
return os.path.join(self.get_temp_dir(), "iterator")
def _testSaveRestoreFromTensorsUtility(self, start, break_range, stop):
path = self._iterator_checkpoint_prefix()
step = 0
meta_filename = path + "-%d.meta" % step
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
with ops.Graph().as_default() as g:
iterator = (
dataset_ops.Dataset.from_tensors(components)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
saveable = iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
for t in nest.flatten(get_next):
ops.add_to_collection("get_next", t)
saver = saver_lib.Saver()
with self.test_session(graph=g) as sess:
sess.run(init_op)
for _ in range(start, break_range):
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component, result_component)
saver.save(sess, path, step)
with ops.Graph().as_default() as g:
saver = saver_lib.import_meta_graph(meta_filename)
with self.test_session(graph=g) as sess:
get_next = nest.pack_sequence_as(("a", "b", "c"),
ops.get_collection("get_next"))
saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
for _ in range(break_range, stop):
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testRestoreFromTensors(self):
self._testSaveRestoreFromTensorsUtility(0, 0, 1)
def testRestoreExhaustedIteratorFromTensors(self):
self._testSaveRestoreFromTensorsUtility(0, 1, 1)
def _build_graph_tensor_slices(self, components):
iterator = dataset_ops.Dataset.from_tensor_slices(
components).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
saveable = iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
for t in nest.flatten(get_next):
ops.add_to_collection("get_next", t)
return init_op, get_next
def _testSaveRestoreFromTensorSlicesUtility(self, start, break_range, stop):
path = self._iterator_checkpoint_prefix()
step = 0
meta_filename = path + "-%d.meta" % step
components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
np.array([[12], [13], [14], [15]]), 22),
np.array([37.0, 38.0, 39.0, 40.0]))
with ops.Graph().as_default() as g:
init_op, get_next = self._build_graph_tensor_slices(components)
saver = saver_lib.Saver()
with self.test_session(graph=g) as sess:
sess.run(init_op)
for i in range(start, break_range):
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i], result_component)
saver.save(sess, path, step)
with ops.Graph().as_default() as g:
saver = saver_lib.import_meta_graph(meta_filename)
with self.test_session(graph=g) as sess:
get_next = nest.pack_sequence_as(("a", "b", "c"),
ops.get_collection("get_next"))
saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
for i in range(break_range, stop):
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i], result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testRestoreFromTensorSlices(self):
self._testSaveRestoreFromTensorSlicesUtility(0, 2, 4)
def testRestoreExhaustedIteratorFromTensorSlices(self):
self._testSaveRestoreFromTensorSlicesUtility(0, 4, 4)
def testRestoreFromTensorSlicesWithDict(self):
path = self._iterator_checkpoint_prefix()
step = 0
meta_filename = path + "-%d.meta" % step
components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
with ops.Graph().as_default() as g:
init_op, get_next = self._build_graph_tensor_slices(components)
saver = saver_lib.Saver()
with self.test_session(graph=g) as sess:
sess.run(init_op)
for i in range(2):
results = sess.run(get_next)
self.assertEqual(components["foo"][i], results["foo"])
self.assertEqual(components["bar"][i], results["bar"])
saver.save(sess, path, step)
with ops.Graph().as_default() as g:
saver = saver_lib.import_meta_graph(meta_filename)
with self.test_session(graph=g) as sess:
get_next = nest.pack_sequence_as({"foo": "", "bar": ""},
ops.get_collection("get_next"))
saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
for i in range(2, 3):
results = sess.run(get_next)
self.assertEqual(components["foo"][i], results["foo"])
self.assertEqual(components["bar"][i], results["bar"])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
if __name__ == "__main__":
test.main()
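One subtlety in the dict-valued test above: nest.flatten sorts dictionary keys, so the tensors enter the "get_next" collection in key order, and the structure handed to pack_sequence_as must itself be a dict for results["foo"] and results["bar"] to resolve on restore. A small sketch, with plain values in place of tensors:

    from tensorflow.python.data.util import nest

    flat = nest.flatten({"foo": 1, "bar": 2})  # keys sorted: [2, 1] (bar, foo)
    restored = nest.pack_sequence_as({"foo": None, "bar": None}, flat)
    assert restored == {"foo": 1, "bar": 2}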

View File

@@ -36,15 +36,17 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
" have different output_types %s and %s",
(DataTypeVectorString(input->output_dtypes()),
DataTypeVectorString(to_concatenate->output_dtypes()))));
*output = new Dataset(input, to_concatenate);
*output = new Dataset(ctx, input, to_concatenate);
}
private:
class Dataset : public DatasetBase {
class Dataset : public GraphDatasetBase {
public:
explicit Dataset(const DatasetBase* input,
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
const DatasetBase* to_concatenate)
: input_(input), to_concatenate_(to_concatenate) {
: GraphDatasetBase(ctx),
input_(input),
to_concatenate_(to_concatenate) {
input_->Ref();
to_concatenate_->Ref();
@@ -76,6 +78,19 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
string DebugString() override { return "ConcatenateDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph = nullptr;
TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph));
Node* to_concatenate_graph = nullptr;
TF_RETURN_IF_ERROR(
b->AddParentDataset(to_concatenate_, &to_concatenate_graph));
TF_RETURN_IF_ERROR(
b->AddDataset(this, {input_graph, to_concatenate_graph}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
@@ -105,6 +120,30 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
if (input_impl_) {
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
}
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2))
return errors::InvalidArgument("i_ must be in range [0, 2].");
if (i_ == 1) {
input_impl_ = dataset()->to_concatenate_->MakeIterator(
strings::StrCat(prefix(), "[1]"));
} else if (i_ == 2) {
input_impl_.reset();
}
if (input_impl_) {
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
}
return Status::OK();
}
private:
mutex mu_;
int64 i_ GUARDED_BY(mu_);
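The checkpointable state of the concatenating iterator is deliberately small: the index i_ of the input currently being drained (0 = first dataset, 1 = second, 2 = both exhausted) plus the nested state of the live child iterator, which SaveParent and RestoreParent delegate to. A toy Python model of the RestoreInternal branching, not the kernel itself:

    def restore_concatenate(i, first_iter, second_iter):
      """Models RestoreInternal: reattach the child iterator for saved index i."""
      if not 0 <= i <= 2:
        raise ValueError("i_ must be in range [0, 2].")
      if i == 1:
        return second_iter  # already draining the concatenated dataset
      if i == 2:
        return None         # exhausted: no child iterator to restore into
      return first_iter     # i == 0: still on the first input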

View File

@@ -56,7 +56,7 @@ class IteratorStateReader {
// Used for saving iterator state.
class IteratorStateWriter {
public:
virtual Status WriteScalar(StringPiece key, const int64& val) = 0;
virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
virtual Status WriteScalar(StringPiece key, const string& val) = 0;
virtual ~IteratorStateWriter() {}
@@ -75,10 +75,7 @@ class GraphDefBuilderWrapper {
Status AddScalar(const T& val, Node** output) {
Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
val_t.scalar<T>()() = val;
*output =
ops::SourceOp("Const", b_->opts()
.WithAttr("dtype", DataTypeToEnum<T>::v())
.WithAttr("value", val_t));
AddTensorInternal(val_t, output);
if (*output == nullptr) {
return errors::Internal("AddScalar: Failed to build Const op.");
}
@@ -96,16 +93,25 @@ class GraphDefBuilderWrapper {
for (int i = 0; i < val.size(); i++) {
val_t.flat<T>()(i) = val[i];
}
*output =
ops::SourceOp("Const", b_->opts()
.WithAttr("dtype", DataTypeToEnum<T>::v())
.WithAttr("value", val_t));
AddTensorInternal(val_t, output);
if (*output == nullptr) {
return errors::Internal("AddVector: Failed to build Const op.");
}
return Status::OK();
}
// Adds a Const node with Tensor value to the Graph.
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
// non-null if the method returns with an OK status.
// The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
Status AddTensor(const Tensor& val, Node** output) {
AddTensorInternal(val, output);
if (*output == nullptr) {
return errors::Internal("AddTesor: Failed to build Const op.");
}
return Status::OK();
}
// Adds a node corresponding to the `DatasetType` to the Graph.
// Return value of `DatasetType::op_name()` is used as the op type for the
// node.
@@ -148,7 +154,46 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
// TODO(shivaniagrawal): Single method for AddDataset for
// NodeOut/ArraySlice<NodeOut>.
template <class DatasetType>
Status AddDatasetWithInputAsList(const DatasetType* dataset,
gtl::ArraySlice<NodeBuilder::NodeOut> input,
Node** output) {
const string& op_type_name = dataset->op_name();
std::unique_ptr<const GraphDefBuilder::Options> opts(
new GraphDefBuilder::Options(b_->opts()));
bool has_output_types_attr = HasAttr(op_type_name, "output_types");
bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes");
if (has_output_shapes_attr) {
opts.reset(new GraphDefBuilder::Options(
opts->WithAttr("output_shapes", dataset->output_shapes())));
}
if (has_output_types_attr) {
opts.reset(new GraphDefBuilder::Options(
opts->WithAttr("output_types", dataset->output_dtypes())));
}
if (opts->HaveError()) {
return errors::Internal("AddDataset: Error building Options.");
}
NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name,
opts->op_registry());
node_builder.Input(input);
*output = opts->FinalizeBuilder(&node_builder);
if (*output == nullptr) {
return errors::Internal("AddDataset: Failed to build ", op_type_name,
" op.");
}
return Status::OK();
}
private:
void AddTensorInternal(const Tensor& val, Node** output) {
*output = ops::SourceOp(
"Const",
b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
}
bool HasAttr(const string& op_type_name, const string& attr_name) {
const OpDef* op_def = nullptr;
Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def);

View File

@@ -228,7 +228,7 @@ class VariantTensorDataWriter : public IteratorStateWriter {
// Does not take ownership of data.
explicit VariantTensorDataWriter(VariantTensorData* data) : data_(data) {}
Status WriteScalar(StringPiece key, const int64& val) override {
Status WriteScalar(StringPiece key, const int64 val) override {
return WriteScalarInternal(key, val);
}

View File

@@ -40,14 +40,14 @@ class TensorDatasetOp : public DatasetOpKernel {
for (const Tensor& t : inputs) {
components.push_back(t);
}
*output = new Dataset(std::move(components));
*output = new Dataset(ctx, std::move(components));
}
private:
class Dataset : public DatasetBase {
class Dataset : public GraphDatasetBase {
public:
explicit Dataset(std::vector<Tensor> tensors)
: tensors_(std::move(tensors)) {
Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
: GraphDatasetBase(ctx), tensors_(std::move(tensors)) {
for (const Tensor& t : tensors_) {
dtypes_.push_back(t.dtype());
shapes_.emplace_back(t.shape().dim_sizes());
@@ -67,6 +67,21 @@ class TensorDatasetOp : public DatasetOpKernel {
string DebugString() override { return "TensorDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
Node** output) const override {
std::vector<NodeBuilder::NodeOut> components;
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
components.emplace_back(node);
}
TF_RETURN_IF_ERROR(
b->AddDatasetWithInputAsList(this, components, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
@@ -88,6 +103,21 @@ class TensorDatasetOp : public DatasetOpKernel {
}
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
if (produced_)
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("produced"), ""));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
produced_ = reader->Contains(full_name("produced"));
return Status::OK();
}
private:
mutex mu_;
bool produced_ GUARDED_BY(mu_);
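Because Dataset.from_tensors produces exactly one element, the iterator's entire state is the produced_ bit, and it is encoded by the mere presence of the "produced" key rather than by a stored value. A toy model of that encoding, with a dict standing in for the IteratorStateWriter/Reader:

    def save_produced(state, produced):
      """Models SaveInternal: write the key only when the element was produced."""
      if produced:
        state["produced"] = ""  # presence of the key is the whole bit

    def restore_produced(state):
      """Models RestoreInternal: reader->Contains(full_name("produced"))."""
      return "produced" in state

    state = {}
    save_produced(state, True)
    assert restore_produced(state)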

View File

@@ -50,14 +50,14 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
errors::InvalidArgument(
"All components must have the same size in the 0th dimension"));
}
*output = new Dataset(std::move(components));
*output = new Dataset(ctx, std::move(components));
}
private:
class Dataset : public DatasetBase {
class Dataset : public GraphDatasetBase {
public:
explicit Dataset(std::vector<Tensor> tensors)
: tensors_(std::move(tensors)) {
explicit Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
: GraphDatasetBase(ctx), tensors_(std::move(tensors)) {
for (const Tensor& t : tensors_) {
dtypes_.push_back(t.dtype());
gtl::InlinedVector<int64, 4> partial_dim_sizes;
@@ -83,6 +83,21 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
string DebugString() override { return "TensorSliceDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
Node** output) const override {
std::vector<NodeBuilder::NodeOut> components;
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
components.emplace_back(node);
}
TF_RETURN_IF_ERROR(
b->AddDatasetWithInputAsList(this, components, output));
return Status::OK();
}
private:
template <typename T>
static Status HandleSliceToElement(const Tensor& parent, Tensor* element,
@@ -148,10 +163,24 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
return Status::OK();
}
private:
mutex mu_;
int i_ GUARDED_BY(mu_);
const int n_;
int64 i_ GUARDED_BY(mu_);
const int64 n_;
};
const std::vector<Tensor> tensors_;