Add a mechanism for switching between multiple iterators by feeding a handle.
With this change, you can do the following:
1. Fetch a string handle for any iterator, by evaluating the result of
`Iterator.string_handle()`.
2. Define an `Iterator` object based on a `tf.string` placeholder handle.
3. Feed the placeholder using an evaluated string handle to use a particular
iterator in a particular step.
Concretely, this allows you to define two iterators for a training dataset and
a test dataset, and choose which one to use on a per-run basis:
```python
train_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator()
train_iterator_handle = sess.run(train_iterator.string_handle())
test_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator()
test_iterator_handle = sess.run(test_iterator.string_handle())
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.contrib.data.Iterator.from_string_handle(
    handle, train_iterator.output_types)
next_element = iterator.get_next()
loss = f(next_element)
train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
```
PiperOrigin-RevId: 161719836
This commit is contained in:
parent 6d6dda807c
commit 71c4ec8ed6
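The pattern in the commit message also composes with initializable iterators, which can be re-initialized between epochs. The following is a minimal sketch, not part of the commit, assuming `make_initializable_iterator()` is available on `tf.contrib.data.Dataset` at this revision; the dataset contents and variable names are illustrative only:

```python
import tensorflow as tf

# Two datasets with the same element structure (scalar int32).
train_dataset = tf.contrib.data.Dataset.from_tensor_slices([1, 2, 3])
test_dataset = tf.contrib.data.Dataset.from_tensor_slices([10, 20, 30])

# Initializable iterators (assumed available), unlike one-shot iterators,
# must be initialized explicitly but can be reused across epochs.
train_iterator = train_dataset.make_initializable_iterator()
test_iterator = test_dataset.make_initializable_iterator()

# A single "feedable" iterator that dispatches to whichever concrete
# iterator's handle is fed at run time.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.contrib.data.Iterator.from_string_handle(
    handle, train_dataset.output_types, train_dataset.output_shapes)
next_element = iterator.get_next()

with tf.Session() as sess:
  sess.run([train_iterator.initializer, test_iterator.initializer])
  train_handle = sess.run(train_iterator.string_handle())
  test_handle = sess.run(test_iterator.string_handle())

  print(sess.run(next_element, feed_dict={handle: train_handle}))  # 1
  print(sess.run(next_element, feed_dict={handle: test_handle}))   # 10
  print(sess.run(next_element, feed_dict={handle: train_handle}))  # 2
```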
@@ -328,6 +328,54 @@ class IteratorTest(test.TestCase):
```python
            [1, 2, 3], dtype=dtypes.int64), constant_op.constant(
                [4., 5., 6., 7.], dtype=dtypes.float64))))

  def testIteratorStringHandle(self):
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_3.make_one_shot_iterator()
    iterator_4 = dataset_4.make_one_shot_iterator()

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = dataset_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
    next_element = feedable_iterator.get_next()

    self.assertEqual(dataset_3.output_types, feedable_iterator.output_types)
    self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
    self.assertEqual([], feedable_iterator.output_shapes)

    with self.test_session() as sess:
      iterator_3_handle = sess.run(iterator_3.string_handle())
      iterator_4_handle = sess.run(iterator_4.string_handle())

      self.assertEqual(
          10, sess.run(next_element,
                       feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          1, sess.run(next_element,
                      feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          20, sess.run(next_element,
                       feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          2, sess.run(next_element,
                      feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          30, sess.run(next_element,
                       feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          3, sess.run(next_element,
                      feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          40, sess.run(next_element,
                       feed_dict={handle_placeholder: iterator_4_handle}))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_3_handle})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle})


if __name__ == "__main__":
  test.main()
```
@@ -182,6 +182,62 @@ class Iterator(object):
````python
        output_shapes=nest.flatten(output_shapes))
    return Iterator(iterator_resource, None, output_types, output_shapes)

  @staticmethod
  def from_string_handle(string_handle, output_types, output_shapes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a @{tf.Session.run} call.
    In that case, `string_handle` would be a @{tf.placeholder}, and you would
    feed it with the value of @{tf.contrib.data.Iterator.string_handle} in each
    step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.contrib.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this iterator.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    nest.assert_same_structure(output_types, output_shapes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    iterator_resource = gen_dataset_ops.iterator_from_string_handle(
        string_handle)
    return Iterator(iterator_resource, None, output_types, output_shapes)

  @property
  def initializer(self):
    """A `tf.Operation` that should be run to initialize this iterator.
````
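Note that `output_shapes` is optional here: the test above passes `dataset_3.output_shapes` explicitly, while the docstring example omits it, leaving every component with an unconstrained shape. A minimal sketch of the two call forms (the dataset and variable names are illustrative only):

```python
import tensorflow as tf

dataset = tf.contrib.data.Dataset.from_tensor_slices([1, 2, 3])
handle = tf.placeholder(tf.string, shape=[])

# With explicit shapes: the feedable iterator reports the dataset's shapes.
shaped_iterator = tf.contrib.data.Iterator.from_string_handle(
    handle, dataset.output_types, dataset.output_shapes)

# Without shapes: each component gets an unconstrained (unknown) shape.
unshaped_iterator = tf.contrib.data.Iterator.from_string_handle(
    handle, dataset.output_types)

print(shaped_iterator.output_shapes)    # scalar shape: ()
print(unshaped_iterator.output_shapes)  # unknown shape
```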
@@ -261,6 +317,18 @@ class Iterator(object):
```python
    """
    return gen_dataset_ops.iterator_dispose(self._iterator_resource, name=name)

  def string_handle(self, name=None):
    """Returns a string-valued `tf.Tensor` that represents this iterator.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.string`.
    """
    return gen_dataset_ops.iterator_to_string_handle(self._iterator_resource,
                                                     name=name)

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.
```
@@ -415,6 +415,69 @@ class IteratorDisposeOp : public OpKernel {
```cpp
  }
};

class IteratorToStringHandleOp : public OpKernel {
 public:
  explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& resource_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
                errors::InvalidArgument("resource_handle must be a scalar"));

    // Validate that the handle corresponds to a real resource, and
    // that it is an IteratorResource.
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
    iterator_resource->Unref();

    Tensor* string_handle_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({}), &string_handle_t));
    string_handle_t->scalar<string>()() =
        resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
  }
};

class IteratorFromStringHandleOp : public OpKernel {
 public:
  explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& string_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
                errors::InvalidArgument("string_handle must be a scalar"));

    ResourceHandle resource_handle;
    OP_REQUIRES(
        ctx,
        resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
        errors::InvalidArgument(
            "Could not parse string_handle as a valid ResourceHandle"));

    // Validate that the handle corresponds to a real resource, and
    // that it is an IteratorResource.
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, resource_handle, &iterator_resource));
    iterator_resource->Unref();

    OP_REQUIRES(
        ctx, resource_handle.device() == ctx->device()->attributes().name(),
        errors::InvalidArgument("Attempted to create an iterator on device \"",
                                ctx->device()->attributes().name(),
                                "\" from handle defined on device \"",
                                resource_handle.device(), "\""));

    Tensor* resource_handle_t;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
    resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
  }
};

REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
                        MakeIteratorOp);
```
@@ -424,6 +487,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
```cpp
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorDispose").Device(DEVICE_CPU),
                        IteratorDisposeOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
                        IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                        IteratorFromStringHandleOp);

}  // namespace
```
@@ -533,4 +533,26 @@ REGISTER_OP("IteratorDispose")
```cpp
Releases any resources used by the given iterator.
)doc");

REGISTER_OP("IteratorToStringHandle")
    .Input("resource_handle: resource")
    .Output("string_handle: string")
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
Converts the given `resource_handle` representing an iterator to a string.

resource_handle: A handle to an iterator resource.
string_handle: A string representation of the given handle.
)doc");

REGISTER_OP("IteratorFromStringHandle")
    .Input("string_handle: string")
    .Output("resource_handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
Converts the given string representing a handle to an iterator to a resource.

string_handle: A string representation of the given handle.
resource_handle: A handle to an iterator resource.
)doc");

}  // namespace tensorflow
```
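These two ops back the new `Iterator.string_handle()` and `Iterator.from_string_handle()` methods in the Python hunk above. A minimal sketch of the round trip, with hypothetical names, assuming the handle tensor and the original iterator live in the same session so the serialized `ResourceHandle` can be looked up again:

```python
import tensorflow as tf

dataset = tf.contrib.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = dataset.make_one_shot_iterator()

# IteratorToStringHandle underlies Iterator.string_handle(): it serializes
# the iterator's ResourceHandle into a scalar tf.string tensor.
string_handle = iterator.string_handle()

# IteratorFromStringHandle underlies Iterator.from_string_handle(): it parses
# the string back into a resource handle on the same device. The handle need
# not come from a placeholder; any scalar tf.string tensor is accepted.
restored = tf.contrib.data.Iterator.from_string_handle(
    string_handle, dataset.output_types, dataset.output_shapes)

with tf.Session() as sess:
  print(sess.run(restored.get_next()))  # 1
```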