mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[tf.data] Iterator Save and Restore for Dataset.from_tensors(..), Dataset.from_tensor_slices(..) and dataset.concatenate(..).
PiperOrigin-RevId: 173971324
This commit is contained in:
parent
09f62ab38b
commit
72be26dc82
|
|
@ -74,9 +74,12 @@ py_test(
|
|||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:dataset_ops",
|
||||
"//tensorflow/contrib/data/python/ops:iterator_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
|
|
@ -93,6 +96,7 @@ py_test(
|
|||
],
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:dataset_ops",
|
||||
"//tensorflow/contrib/data/python/ops:iterator_ops",
|
||||
"//tensorflow/contrib/data/python/ops:transformation_ops",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
|
|
@ -104,6 +108,7 @@ py_test(
|
|||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
|
|
@ -241,6 +246,7 @@ py_test(
|
|||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:dataset_ops",
|
||||
"//tensorflow/contrib/data/python/ops:iterator_ops",
|
||||
"//tensorflow/contrib/data/python/ops:transformation_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
|
@ -248,6 +254,7 @@ py_test(
|
|||
"//tensorflow/python:data_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:functional_ops",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:lookup_ops",
|
||||
|
|
@ -255,6 +262,7 @@ py_test(
|
|||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//third_party/py/numpy",
|
||||
|
|
@ -396,10 +404,14 @@ py_test(
|
|||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:dataset_ops",
|
||||
"//tensorflow/contrib/data/python/ops:iterator_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,13 +17,17 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.contrib.data.python.ops import iterator_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
|
||||
class ConcatenateDatasetTest(test.TestCase):
|
||||
|
|
@ -129,6 +133,140 @@ class ConcatenateDatasetTest(test.TestCase):
|
|||
with self.assertRaisesRegexp(TypeError, "have different types"):
|
||||
input_dataset.concatenate(dataset_to_concatenate)
|
||||
|
||||
def _iterator_checkpoint_prefix(self):
|
||||
return os.path.join(self.get_temp_dir(), "iterator")
|
||||
|
||||
def _build_graph(self, input_components, to_concatenate_components):
|
||||
input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components)
|
||||
dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
|
||||
to_concatenate_components)
|
||||
iterator = input_dataset.concatenate(
|
||||
dataset_to_concatenate).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
saveable = iterator_ops.make_saveable_from_iterator(iterator)
|
||||
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
|
||||
# TODO(shivaniagrawal) : non-intuitive way, add support in mata_graph
|
||||
for t in nest.flatten(get_next):
|
||||
ops.add_to_collection("get_next", t)
|
||||
return init_op, get_next
|
||||
|
||||
def _testSaveRestoreUtility(self, start, break_range, stop):
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
step = 0
|
||||
meta_filename = path + "-%d.meta" % step
|
||||
|
||||
input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
|
||||
np.array([[12], [13], [14], [15]]), 4))
|
||||
to_concatenate_components = (np.tile(
|
||||
np.array([[5], [6], [7], [8], [9]]), 20), np.tile(
|
||||
np.array([[16], [17], [18], [19], [20]]), 15))
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next = self._build_graph(input_components,
|
||||
to_concatenate_components)
|
||||
saver = saver_lib.Saver()
|
||||
with self.test_session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
for i in range(start, break_range):
|
||||
result = sess.run(get_next)
|
||||
if i < 4:
|
||||
for component, result_component in zip(input_components, result):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
else:
|
||||
for component, result_component in zip(to_concatenate_components,
|
||||
result):
|
||||
self.assertAllEqual(component[i - 4], result_component)
|
||||
saver.save(sess, path, step)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
saver = saver_lib.import_meta_graph(meta_filename)
|
||||
with self.test_session(graph=g) as sess:
|
||||
get_next = nest.pack_sequence_as(("a", "b"),
|
||||
ops.get_collection("get_next"))
|
||||
saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
|
||||
for i in range(break_range, stop):
|
||||
result = sess.run(get_next)
|
||||
if i < 4:
|
||||
for component, result_component in zip(input_components, result):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
else:
|
||||
for component, result_component in zip(to_concatenate_components,
|
||||
result):
|
||||
self.assertAllEqual(component[i - 4], result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testRestoreAtFirstDataset(self):
|
||||
start = 0
|
||||
stop = 9
|
||||
break_range = 3
|
||||
self._testSaveRestoreUtility(start, break_range, stop)
|
||||
|
||||
def testRestoreAtSecondDataset(self):
|
||||
start = 0
|
||||
stop = 9
|
||||
break_range = 6
|
||||
self._testSaveRestoreUtility(start, break_range, stop)
|
||||
|
||||
def testRestoreAtBetweenDatasets(self):
|
||||
start = 0
|
||||
stop = 9
|
||||
break_range = 4
|
||||
self._testSaveRestoreUtility(start, break_range, stop)
|
||||
|
||||
def testRestoreExhaustedIterator(self):
|
||||
start = 0
|
||||
stop = 9
|
||||
break_range = 9
|
||||
self._testSaveRestoreUtility(start, break_range, stop)
|
||||
|
||||
def testRestoreInModifiedGraph(self):
|
||||
start = 0
|
||||
stop = 9
|
||||
break_range = 6
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
step = 0
|
||||
|
||||
input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
|
||||
np.array([[12], [13], [14], [15]]), 4))
|
||||
to_concatenate_components = (np.tile(
|
||||
np.array([[5], [6], [7], [8], [9]]), 20), np.tile(
|
||||
np.array([[16], [17], [18], [19], [20]]), 15))
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next = self._build_graph(input_components,
|
||||
to_concatenate_components)
|
||||
saver = saver_lib.Saver(allow_empty=True)
|
||||
with self.test_session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
for i in range(start, break_range):
|
||||
result = sess.run(get_next)
|
||||
if i < 4:
|
||||
for component, result_component in zip(input_components, result):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
else:
|
||||
for component, result_component in zip(to_concatenate_components,
|
||||
result):
|
||||
self.assertAllEqual(component[i - 4], result_component)
|
||||
saver.save(sess, path, step)
|
||||
|
||||
new_to_concatenate_components = (np.array([[5], [6], [7], [8], [9]]),
|
||||
np.array([[16], [17], [18], [19], [20]]))
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next = self._build_graph(input_components,
|
||||
new_to_concatenate_components)
|
||||
saver = saver_lib.Saver()
|
||||
with self.test_session(graph=g) as sess:
|
||||
saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
|
||||
for i in range(break_range, stop):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(to_concatenate_components,
|
||||
result):
|
||||
self.assertAllEqual(component[i - 4], result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import batching
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.contrib.data.python.ops import iterator_ops
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.data.util import nest
|
||||
|
|
@ -34,6 +36,7 @@ from tensorflow.python.ops import array_ops
|
|||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
|
||||
class DatasetConstructorTest(test.TestCase):
|
||||
|
|
@ -571,6 +574,136 @@ class DatasetConstructorTest(test.TestCase):
|
|||
new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _iterator_checkpoint_prefix(self):
|
||||
return os.path.join(self.get_temp_dir(), "iterator")
|
||||
|
||||
def _testSaveRestoreFromTensorsUtility(self, start, break_range, stop):
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
step = 0
|
||||
meta_filename = path + "-%d.meta" % step
|
||||
|
||||
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
iterator = (
|
||||
dataset_ops.Dataset.from_tensors(components)
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
saveable = iterator_ops.make_saveable_from_iterator(iterator)
|
||||
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
|
||||
for t in nest.flatten(get_next):
|
||||
ops.add_to_collection("get_next", t)
|
||||
saver = saver_lib.Saver()
|
||||
with self.test_session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
for _ in range(start, break_range):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component, result_component)
|
||||
saver.save(sess, path, step)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
saver = saver_lib.import_meta_graph(meta_filename)
|
||||
with self.test_session(graph=g) as sess:
|
||||
get_next = nest.pack_sequence_as(("a", "b", "c"),
|
||||
ops.get_collection("get_next"))
|
||||
saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
|
||||
for _ in range(break_range, stop):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testRestoreFromTensors(self):
|
||||
self._testSaveRestoreFromTensorsUtility(0, 0, 1)
|
||||
|
||||
def testRestoreExhuatedIteratorFromTensors(self):
|
||||
self._testSaveRestoreFromTensorsUtility(0, 1, 1)
|
||||
|
||||
def _build_graph_tensor_slices(self, components):
|
||||
iterator = dataset_ops.Dataset.from_tensor_slices(
|
||||
components).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
saveable = iterator_ops.make_saveable_from_iterator(iterator)
|
||||
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
|
||||
for t in nest.flatten(get_next):
|
||||
ops.add_to_collection("get_next", t)
|
||||
return init_op, get_next
|
||||
|
||||
def _testSaveRestoreFromTensorSlicesUtility(self, start, break_range, stop):
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
step = 0
|
||||
meta_filename = path + "-%d.meta" % step
|
||||
|
||||
components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
|
||||
np.array([[12], [13], [14], [15]]), 22),
|
||||
np.array([37.0, 38.0, 39.0, 40.0]))
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next = self._build_graph_tensor_slices(components)
|
||||
saver = saver_lib.Saver()
|
||||
with self.test_session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
for i in range(start, break_range):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
saver.save(sess, path, step)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
saver = saver_lib.import_meta_graph(meta_filename)
|
||||
with self.test_session(graph=g) as sess:
|
||||
get_next = nest.pack_sequence_as(("a", "b", "c"),
|
||||
ops.get_collection("get_next"))
|
||||
saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
|
||||
for i in range(break_range, stop):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testRestoreFromTensorSlices(self):
|
||||
self._testSaveRestoreFromTensorSlicesUtility(0, 4, 2)
|
||||
|
||||
def testRestoreExhaustedIteratorFromTensorSlices(self):
|
||||
self._testSaveRestoreFromTensorSlicesUtility(0, 4, 4)
|
||||
|
||||
def tesRestoreFromTensorSlicesWithDict(self):
|
||||
|
||||
path = self._iterator_checkpoint_prefix()
|
||||
step = 0
|
||||
meta_filename = path + "-%d.meta" % step
|
||||
|
||||
components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next = self._build_graph_tensor_slices(components)
|
||||
saver = saver_lib.Saver()
|
||||
with self.test_session(graph=g) as sess:
|
||||
sess.run(init_op)
|
||||
for i in range(2):
|
||||
results = sess.run(get_next)
|
||||
self.assertEqual(components["foo"][i], results["foo"])
|
||||
self.assertEqual(components["bar"][i], results["bar"])
|
||||
saver.save(sess, path, step)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
saver = saver_lib.import_meta_graph(meta_filename)
|
||||
with self.test_session(graph=g) as sess:
|
||||
get_next = nest.pack_sequence_as(("a", "b"),
|
||||
ops.get_collection("get_next"))
|
||||
saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
|
||||
for i in range(2, 3):
|
||||
results = sess.run(get_next)
|
||||
self.assertEqual(components["foo"][i], results["foo"])
|
||||
self.assertEqual(components["bar"][i], results["bar"])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
|||
|
|
@ -36,15 +36,17 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
|
|||
" have different output_types %s and %s",
|
||||
(DataTypeVectorString(input->output_dtypes()),
|
||||
DataTypeVectorString(to_concatenate->output_dtypes()))));
|
||||
*output = new Dataset(input, to_concatenate);
|
||||
*output = new Dataset(ctx, input, to_concatenate);
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
class Dataset : public GraphDatasetBase {
|
||||
public:
|
||||
explicit Dataset(const DatasetBase* input,
|
||||
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
const DatasetBase* to_concatenate)
|
||||
: input_(input), to_concatenate_(to_concatenate) {
|
||||
: GraphDatasetBase(ctx),
|
||||
input_(input),
|
||||
to_concatenate_(to_concatenate) {
|
||||
input_->Ref();
|
||||
to_concatenate_->Ref();
|
||||
|
||||
|
|
@ -76,6 +78,19 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
|
|||
|
||||
string DebugString() override { return "ConcatenateDatasetOp::Dataset"; }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_graph = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph));
|
||||
Node* to_concatenate_graph = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddParentDataset(to_concatenate_, &to_concatenate_graph));
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this, {input_graph, to_concatenate_graph}, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
|
|
@ -105,6 +120,30 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
|
||||
if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2))
|
||||
return errors::InvalidArgument("i_ must be in range [0, 2].");
|
||||
if (i_ == 1) {
|
||||
input_impl_ = dataset()->to_concatenate_->MakeIterator(
|
||||
strings::StrCat(prefix(), "[1]"));
|
||||
} else if (i_ == 2) {
|
||||
input_impl_.reset();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
int64 i_ GUARDED_BY(mu_);
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class IteratorStateReader {
|
|||
// Used for saving iterator state.
|
||||
class IteratorStateWriter {
|
||||
public:
|
||||
virtual Status WriteScalar(StringPiece key, const int64& val) = 0;
|
||||
virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
|
||||
virtual Status WriteScalar(StringPiece key, const string& val) = 0;
|
||||
|
||||
virtual ~IteratorStateWriter() {}
|
||||
|
|
@ -75,10 +75,7 @@ class GraphDefBuilderWrapper {
|
|||
Status AddScalar(const T& val, Node** output) {
|
||||
Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
|
||||
val_t.scalar<T>()() = val;
|
||||
*output =
|
||||
ops::SourceOp("Const", b_->opts()
|
||||
.WithAttr("dtype", DataTypeToEnum<T>::v())
|
||||
.WithAttr("value", val_t));
|
||||
AddTensorInternal(val_t, output);
|
||||
if (*output == nullptr) {
|
||||
return errors::Internal("AddScalar: Failed to build Const op.");
|
||||
}
|
||||
|
|
@ -96,16 +93,25 @@ class GraphDefBuilderWrapper {
|
|||
for (int i = 0; i < val.size(); i++) {
|
||||
val_t.flat<T>()(i) = val[i];
|
||||
}
|
||||
*output =
|
||||
ops::SourceOp("Const", b_->opts()
|
||||
.WithAttr("dtype", DataTypeToEnum<T>::v())
|
||||
.WithAttr("value", val_t));
|
||||
AddTensorInternal(val_t, output);
|
||||
if (*output == nullptr) {
|
||||
return errors::Internal("AddVector: Failed to build Const op.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Adds a Const node with Tensor value to the Graph.
|
||||
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
|
||||
// non-null if the method returns with an OK status.
|
||||
// The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
|
||||
Status AddTensor(const Tensor& val, Node** output) {
|
||||
AddTensorInternal(val, output);
|
||||
if (*output == nullptr) {
|
||||
return errors::Internal("AddTesor: Failed to build Const op.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Adds a node corresponding to the `DatasetType` to the Graph.
|
||||
// Return value of `DatasetType::op_name()` is used as the op type for the
|
||||
// node.
|
||||
|
|
@ -148,7 +154,46 @@ class GraphDefBuilderWrapper {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// TODO(shivaniagrawal): Single method for AddDataset for
|
||||
// NodeOut/ArrraySlice<NodeOut>
|
||||
template <class DatasetType>
|
||||
Status AddDatasetWithInputAsList(const DatasetType* dataset,
|
||||
gtl::ArraySlice<NodeBuilder::NodeOut> input,
|
||||
Node** output) {
|
||||
const string& op_type_name = dataset->op_name();
|
||||
std::unique_ptr<const GraphDefBuilder::Options> opts(
|
||||
new GraphDefBuilder::Options(b_->opts()));
|
||||
bool has_output_types_attr = HasAttr(op_type_name, "output_types");
|
||||
bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes");
|
||||
if (has_output_shapes_attr) {
|
||||
opts.reset(new GraphDefBuilder::Options(
|
||||
opts->WithAttr("output_shapes", dataset->output_shapes())));
|
||||
}
|
||||
if (has_output_types_attr) {
|
||||
opts.reset(new GraphDefBuilder::Options(
|
||||
opts->WithAttr("output_types", dataset->output_dtypes())));
|
||||
}
|
||||
if (opts->HaveError()) {
|
||||
return errors::Internal("AddDataset: Error building Options.");
|
||||
}
|
||||
NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name,
|
||||
opts->op_registry());
|
||||
node_builder.Input(input);
|
||||
*output = opts->FinalizeBuilder(&node_builder);
|
||||
if (*output == nullptr) {
|
||||
return errors::Internal("AddDataset: Failed to build ", op_type_name,
|
||||
" op.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
void AddTensorInternal(const Tensor& val, Node** output) {
|
||||
*output = ops::SourceOp(
|
||||
"Const",
|
||||
b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
|
||||
}
|
||||
|
||||
bool HasAttr(const string& op_type_name, const string& attr_name) {
|
||||
const OpDef* op_def = nullptr;
|
||||
Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def);
|
||||
|
|
|
|||
|
|
@ -228,7 +228,7 @@ class VariantTensorDataWriter : public IteratorStateWriter {
|
|||
// Does not take ownership of data.
|
||||
explicit VariantTensorDataWriter(VariantTensorData* data) : data_(data) {}
|
||||
|
||||
Status WriteScalar(StringPiece key, const int64& val) override {
|
||||
Status WriteScalar(StringPiece key, const int64 val) override {
|
||||
return WriteScalarInternal(key, val);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -40,14 +40,14 @@ class TensorDatasetOp : public DatasetOpKernel {
|
|||
for (const Tensor& t : inputs) {
|
||||
components.push_back(t);
|
||||
}
|
||||
*output = new Dataset(std::move(components));
|
||||
*output = new Dataset(ctx, std::move(components));
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
class Dataset : public GraphDatasetBase {
|
||||
public:
|
||||
explicit Dataset(std::vector<Tensor> tensors)
|
||||
: tensors_(std::move(tensors)) {
|
||||
Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
|
||||
: GraphDatasetBase(ctx), tensors_(std::move(tensors)) {
|
||||
for (const Tensor& t : tensors_) {
|
||||
dtypes_.push_back(t.dtype());
|
||||
shapes_.emplace_back(t.shape().dim_sizes());
|
||||
|
|
@ -67,6 +67,21 @@ class TensorDatasetOp : public DatasetOpKernel {
|
|||
|
||||
string DebugString() override { return "TensorDatasetOp::Dataset"; }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
std::vector<NodeBuilder::NodeOut> components;
|
||||
components.reserve(tensors_.size());
|
||||
for (const Tensor& t : tensors_) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
components.emplace_back(node);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDatasetWithInputAsList(this, components, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
|
|
@ -88,6 +103,21 @@ class TensorDatasetOp : public DatasetOpKernel {
|
|||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (produced_)
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("produced"), ""));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
produced_ = reader->Contains(full_name("produced"));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
bool produced_ GUARDED_BY(mu_);
|
||||
|
|
|
|||
|
|
@ -50,14 +50,14 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
|
|||
errors::InvalidArgument(
|
||||
"All components must have the same size in the 0th dimension"));
|
||||
}
|
||||
*output = new Dataset(std::move(components));
|
||||
*output = new Dataset(ctx, std::move(components));
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
class Dataset : public GraphDatasetBase {
|
||||
public:
|
||||
explicit Dataset(std::vector<Tensor> tensors)
|
||||
: tensors_(std::move(tensors)) {
|
||||
explicit Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
|
||||
: GraphDatasetBase(ctx), tensors_(std::move(tensors)) {
|
||||
for (const Tensor& t : tensors_) {
|
||||
dtypes_.push_back(t.dtype());
|
||||
gtl::InlinedVector<int64, 4> partial_dim_sizes;
|
||||
|
|
@ -83,6 +83,21 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
|
|||
|
||||
string DebugString() override { return "TensorSliceDatasetOp::Dataset"; }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
std::vector<NodeBuilder::NodeOut> components;
|
||||
components.reserve(tensors_.size());
|
||||
for (const Tensor& t : tensors_) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
components.emplace_back(node);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDatasetWithInputAsList(this, components, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
static Status HandleSliceToElement(const Tensor& parent, Tensor* element,
|
||||
|
|
@ -148,10 +163,24 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
int i_ GUARDED_BY(mu_);
|
||||
const int n_;
|
||||
int64 i_ GUARDED_BY(mu_);
|
||||
const int64 n_;
|
||||
};
|
||||
|
||||
const std::vector<Tensor> tensors_;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user