Introduce tf.data namespace.

PiperOrigin-RevId: 170939033
A. Unique TensorFlower 2017-10-03 17:08:50 -07:00 committed by TensorFlower Gardener
parent 0c8dbc1fda
commit 0068086b9a
14 changed files with 641 additions and 57 deletions

View File

@ -1,8 +1,10 @@
`tf.contrib.data` API
=====================
+NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead.
This directory contains the Python API for the `tf.contrib.data.Dataset` and
`tf.contrib.data.Iterator` classes, which can be used to build input pipelines.
-The documentation for this API has moved to the programmers'
+The documentation for the `tf.data` API has moved to the programmers'
guide, [here](../../docs_src/programmers_guide/datasets.md).
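For existing code, the move is mostly a matter of spelling; a minimal sketch of a pipeline written against the new namespace:

```python
import tensorflow as tf

# Previously spelled tf.contrib.data.Dataset.range(10); the same class now
# also lives under the core tf.data namespace introduced by this change.
dataset = tf.data.Dataset.range(10)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
```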

View File

@ -12,7 +12,7 @@ complicated transformations.
The `Dataset` API introduces two new abstractions to TensorFlow:
-* A `tf.contrib.data.Dataset` represents a sequence of elements, in which
+* A `tf.data.Dataset` represents a sequence of elements, in which
each element contains one or more `Tensor` objects. For example, in an image
pipeline, an element might be a single training example, with a pair of
tensors representing the image data and a label. There are two distinct
@ -23,9 +23,9 @@ The `Dataset` API introduces two new abstractions to TensorFlow:
one or more `tf.Tensor` objects.
* Applying a **transformation** (e.g. `Dataset.batch()`) constructs a dataset
-from one or more `tf.contrib.data.Dataset` objects.
+from one or more `tf.data.Dataset` objects.
-* A `tf.contrib.data.Iterator` provides the main way to extract elements from a
+* A `tf.data.Iterator` provides the main way to extract elements from a
dataset. The operation returned by `Iterator.get_next()` yields the next
element of a `Dataset` when executed, and typically acts as the interface
between input pipeline code and your model. The simplest iterator is a
@ -42,22 +42,22 @@ of `Dataset` and `Iterator` objects, and how to extract data from them.
To start an input pipeline, you must define a *source*. For example, to
construct a `Dataset` from some tensors in memory, you can use
-`tf.contrib.data.Dataset.from_tensors()` or
-`tf.contrib.data.Dataset.from_tensor_slices()`. Alternatively, if your input
+`tf.data.Dataset.from_tensors()` or
+`tf.data.Dataset.from_tensor_slices()`. Alternatively, if your input
data are on disk in the recommended TFRecord format, you can construct a
-`tf.contrib.data.TFRecordDataset`.
+`tf.data.TFRecordDataset`.
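As a minimal sketch of the two kinds of sources (the TFRecord path is a placeholder):

```python
import tensorflow as tf

# In-memory source: slices a tensor along its first dimension into 4 elements.
in_memory = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))

# File-backed source: streams serialized records from TFRecord files on disk.
on_disk = tf.data.TFRecordDataset(["/var/data/file1.tfrecord"])
```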
Once you have a `Dataset` object, you can *transform* it into a new `Dataset` by
-chaining method calls on the `tf.contrib.data.Dataset` object. For example, you
+chaining method calls on the `tf.data.Dataset` object. For example, you
can apply per-element transformations such as `Dataset.map()` (to apply a
function to each element), and multi-element transformations such as
-`Dataset.batch()`. See the documentation for @{tf.contrib.data.Dataset}
+`Dataset.batch()`. See the documentation for @{tf.data.Dataset}
for a complete list of transformations.
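A minimal sketch of such chaining:

```python
dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: x * 2)  # per-element transformation
dataset = dataset.batch(32)             # multi-element transformation
```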
The most common way to consume values from a `Dataset` is to make an
**iterator** object that provides access to one element of the dataset at a time
(for example, by calling `Dataset.make_one_shot_iterator()`). A
-`tf.contrib.data.Iterator` provides two operations: `Iterator.initializer`,
+`tf.data.Iterator` provides two operations: `Iterator.initializer`,
which enables you to (re)initialize the iterator's state; and
`Iterator.get_next()`, which returns `tf.Tensor` objects that correspond to the
symbolic next element. Depending on your use case, you might choose a different
@ -76,17 +76,17 @@ of an element, which may be a single tensor, a tuple of tensors, or a nested
tuple of tensors. For example:
```python
-dataset1 = tf.contrib.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
+dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types) # ==> "tf.float32"
print(dataset1.output_shapes) # ==> "(10,)"
-dataset2 = tf.contrib.data.Dataset.from_tensor_slices(
+dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes) # ==> "((), (100,))"
-dataset3 = tf.contrib.data.Dataset.zip((dataset1, dataset2))
+dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes) # ==> "(10, ((), (100,)))"
```
@ -97,7 +97,7 @@ to tuples, you can use `collections.namedtuple` or a dictionary mapping strings
to tensors to represent a single element of a `Dataset`.
```python
-dataset = tf.contrib.data.Dataset.from_tensor_slices(
+dataset = tf.data.Dataset.from_tensor_slices(
{"a": tf.random_uniform([4]),
"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}"
@ -137,7 +137,7 @@ input pipelines support, but they do not support parameterization. Using the
example of `Dataset.range()`:
```python
-dataset = tf.contrib.data.Dataset.range(100)
+dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@ -157,7 +157,7 @@ initialize the iterator. Continuing the `Dataset.range()` example:
```python
max_value = tf.placeholder(tf.int64, shape=[])
-dataset = tf.contrib.data.Dataset.range(max_value)
+dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
@ -183,9 +183,9 @@ structure (i.e. the same types and compatible shapes for each component).
```python
# Define training and validation datasets with the same structure.
-training_dataset = tf.contrib.data.Dataset.range(100).map(
+training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
-validation_dataset = tf.contrib.data.Dataset.range(50)
+validation_dataset = tf.data.Dataset.range(50)
# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
@ -217,21 +217,21 @@ what `Iterator` to use in each call to @{tf.Session.run}, via the familiar
iterator, but it does not require you to initialize the iterator from the start
of a dataset when you switch between iterators. For example, using the same
training and validation example from above, you can use
-@{tf.contrib.data.Iterator.from_string_handle} to define a feedable iterator
+@{tf.data.Iterator.from_string_handle} to define a feedable iterator
that allows you to switch between the two datasets:
```python
# Define training and validation datasets with the same structure.
-training_dataset = tf.contrib.data.Dataset.range(100).map(
+training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
-validation_dataset = tf.contrib.data.Dataset.range(50)
+validation_dataset = tf.data.Dataset.range(50)
# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
-iterator = tf.contrib.data.Iterator.from_string_handle(
+iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
@ -276,7 +276,7 @@ After this point the iterator will be in an unusable state, and you must
initialize it again if you want to use it further.
```python
-dataset = tf.contrib.data.Dataset.range(5)
+dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
@ -312,9 +312,9 @@ If each element of the dataset has a nested structure, the return value of
nested structure:
```python
-dataset1 = tf.contrib.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
-dataset2 = tf.contrib.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
-dataset3 = tf.contrib.data.Dataset.zip((dataset1, dataset2))
+dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
+dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
+dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
iterator = dataset3.make_initializable_iterator()
@ -343,7 +343,7 @@ with np.load("/var/data/training_data.npy") as data:
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
-dataset = tf.contrib.data.Dataset.from_tensor_slices((features, labels))
+dataset = tf.data.Dataset.from_tensor_slices((features, labels))
```
Note that the above code snippet will embed the `features` and `labels` arrays
@ -368,7 +368,7 @@ assert features.shape[0] == labels.shape[0]
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
-dataset = tf.contrib.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
+dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()
@ -382,14 +382,14 @@ sess.run(iterator.initializer, feed_dict={features_placeholder: features,
The `Dataset` API supports a variety of file formats so that you can process
large datasets that do not fit in memory. For example, the TFRecord file format
is a simple record-oriented binary format that many TensorFlow applications use
-for training data. The `tf.contrib.data.TFRecordDataset` class enables you to
+for training data. The `tf.data.TFRecordDataset` class enables you to
stream over the contents of one or more TFRecord files as part of an input
pipeline.
```python
# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.contrib.data.TFRecordDataset(filenames)
+dataset = tf.data.TFRecordDataset(filenames)
```
The `filenames` argument to the `TFRecordDataset` initializer can either be a
@ -400,7 +400,7 @@ iterator from the appropriate filenames:
```python
filenames = tf.placeholder(tf.string, shape=[None])
-dataset = tf.contrib.data.TFRecordDataset(filenames)
+dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...) # Parse the record into tensors.
dataset = dataset.repeat() # Repeat the input indefinitely.
dataset = dataset.batch(32)
@ -421,7 +421,7 @@ sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
### Consuming text data
Many datasets are distributed as one or more text files. The
-`tf.contrib.data.TextLineDataset` provides an easy way to extract lines from
+`tf.data.TextLineDataset` provides an easy way to extract lines from
one or more text files. Given one or more filenames, a `TextLineDataset` will
produce one string-valued element per line of those files. Like a
`TFRecordDataset`, `TextLineDataset` accepts `filenames` as a `tf.Tensor`, so
@ -429,7 +429,7 @@ you can parameterize it by passing a `tf.placeholder(tf.string)`.
```python
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
-dataset = tf.contrib.data.TextLineDataset(filenames)
+dataset = tf.data.TextLineDataset(filenames)
```
By default, a `TextLineDataset` yields *every* line of each file, which may
@ -442,7 +442,7 @@ each file.
```python
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
-dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames)
+dataset = tf.data.Dataset.from_tensor_slices(filenames)
# Use `Dataset.flat_map()` to transform each file as a separate nested dataset,
# and then concatenate their contents sequentially into a single "flat" dataset.
@ -450,7 +450,7 @@ dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames)
# * Filter out lines beginning with "#" (comments).
dataset = dataset.flat_map(
lambda filename: (
-tf.contrib.data.TextLineDataset(filename)
+tf.data.TextLineDataset(filename)
.skip(1)
.filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))
```
@ -498,7 +498,7 @@ def _parse_function(example_proto):
# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.contrib.data.TFRecordDataset(filenames)
+dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
```
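One plausible shape for the `_parse_function` used above, assuming each record carries an encoded image string and an integer label (the feature keys are hypothetical):

```python
def _parse_function(example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]
```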
@ -523,7 +523,7 @@ filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])
# `labels[i]` is the label for the image in `filenames[i]`.
labels = tf.constant([0, 37, ...])
-dataset = tf.contrib.data.Dataset.from_tensor_slices((filenames, labels))
+dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
```
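A sketch of what the image-decoding `_parse_function` here might look like (JPEG input and the 28x28 target size are assumptions):

```python
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_jpeg(image_string, channels=3)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label
```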
@ -552,7 +552,7 @@ def _resize_function(image_decoded, label):
filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]
-dataset = tf.contrib.data.Dataset.from_tensor_slices((filenames, labels))
+dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
lambda filename, label: tuple(tf.py_func(
_read_py_function, [filename, label], [tf.uint8, label.dtype])))
@ -576,9 +576,9 @@ of the elements: i.e. for each component *i*, all elements must have a tensor
of the exact same shape.
```python
-inc_dataset = tf.contrib.data.Dataset.range(100)
-dec_dataset = tf.contrib.data.Dataset.range(0, -100, -1)
-dataset = tf.contrib.data.Dataset.zip((inc_dataset, dec_dataset))
+inc_dataset = tf.data.Dataset.range(100)
+dec_dataset = tf.data.Dataset.range(0, -100, -1)
+dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)
iterator = batched_dataset.make_one_shot_iterator()
@ -599,7 +599,7 @@ different shape by specifying one or more dimensions in which they may be
padded.
```python
-dataset = tf.contrib.data.Dataset.range(100)
+dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])
@ -637,7 +637,7 @@ its input for 10 epochs:
```python
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.contrib.data.TFRecordDataset(filenames)
+dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)
@ -655,7 +655,7 @@ error) for the epoch.
```python
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.contrib.data.TFRecordDataset(filenames)
+dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
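
# A minimal sketch of a single-epoch loop over this pipeline, assuming
# `next_element = iterator.get_next()` and an active `tf.Session` named `sess`:
sess.run(iterator.initializer)
while True:
  try:
    sess.run(next_element)
  except tf.errors.OutOfRangeError:
    break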
@ -681,7 +681,7 @@ buffer and chooses the next element uniformly at random from that buffer.
```python
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.contrib.data.TFRecordDataset(filenames)
+dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
@ -698,7 +698,7 @@ with the `Dataset` API, we recommend using
```python
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.contrib.data.TFRecordDataset(filenames)
+dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
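
# A sketch of driving this pipeline from tf.train.MonitoredTrainingSession
# (constructor arguments omitted): a one-shot iterator raises
# tf.errors.OutOfRangeError when the input is exhausted, which ends the loop
# and lets the monitored session shut down cleanly.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.train.MonitoredTrainingSession() as sess:
  while not sess.should_stop():
    sess.run(next_element)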
@ -721,7 +721,7 @@ recommend using `Dataset.make_one_shot_iterator()`. For example:
```python
def dataset_input_fn():
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.contrib.data.TFRecordDataset(filenames)
+dataset = tf.data.TFRecordDataset(filenames)
# Use `tf.parse_single_example()` to extract data from a `tf.Example`
# protocol buffer, and perform any additional per-record preprocessing.
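
  # A sketch of how such an input_fn might continue (the feature spec below is
  # hypothetical): parse each record, batch, and return a (features, labels)
  # pair in the form a `tf.estimator.Estimator` expects from an input_fn.
  def parser(record):
    keys_to_features = {
        "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
        "label": tf.FixedLenFeature((), tf.int64, default_value=0),
    }
    parsed = tf.parse_single_example(record, keys_to_features)
    return {"image_data": parsed["image_data"]}, parsed["label"]

  dataset = dataset.map(parser)
  dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.batch(32)
  iterator = dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels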

View File

@ -78,9 +78,10 @@ from tensorflow.python.ops import linalg_ns as linalg
# pylint: enable=wildcard-import
# Bring in subpackages.
+from tensorflow.python import data
+from tensorflow.python import keras
from tensorflow.python.estimator import estimator_lib as estimator
from tensorflow.python.feature_column import feature_column_lib as feature_column
-from tensorflow.python import keras
from tensorflow.python.layers import layers
from tensorflow.python.ops import bitwise_ops as bitwise
from tensorflow.python.ops import image_ops as image
@ -91,10 +92,11 @@ from tensorflow.python.ops import spectral_ops as spectral
from tensorflow.python.ops.distributions import distributions
from tensorflow.python.ops.losses import losses
from tensorflow.python.profiler import profiler
-from tensorflow.python.user_ops import user_ops
-from tensorflow.python.util import compat
from tensorflow.python.saved_model import saved_model
from tensorflow.python.summary import summary
+from tensorflow.python.user_ops import user_ops
+from tensorflow.python.util import compat
# Import the names from python/training.py as train.Name.
from tensorflow.python.training import training as train
@ -222,6 +224,7 @@ _allowed_symbols.extend([
'app',
'bitwise',
'compat',
+'data',
'distributions',
'errors',
'estimator',
@ -231,12 +234,15 @@ _allowed_symbols.extend([
'graph_util',
'image',
'initializers',
+'keras',
+'layers',
'linalg',
'logging',
'losses',
'metrics',
'newaxis',
'nn',
+'profiler',
'python_io',
'resource_loader',
'saved_model',
@ -247,9 +253,6 @@ _allowed_symbols.extend([
'test',
'train',
'user_ops',
-'layers',
-'profiler',
-'keras',
])
# Variables framework.versions:
@ -263,11 +266,11 @@ _allowed_symbols.extend([
# referenced in the whitelist.
remove_undocumented(__name__, _allowed_symbols, [
framework_lib, array_ops, check_ops, client_lib, compat, constant_op,
-control_flow_ops, confusion_matrix_m, distributions,
-functional_ops, histogram_ops, io_ops,
-losses, math_ops, metrics, nn, resource_loader, sets, script_ops,
+control_flow_ops, confusion_matrix_m, data, distributions,
+functional_ops, histogram_ops, io_ops, keras, layers,
+losses, math_ops, metrics, nn, profiler, resource_loader, sets, script_ops,
session_ops, sparse_ops, state_ops, string_ops, summary, tensor_array_ops,
-train, layers, profiler, keras
+train
])
# Special dunders that we choose to export:

View File

@ -0,0 +1,14 @@
path: "tensorflow.data.Dataset.__metaclass__"
tf_class {
is_instance: "<class \'abc.ABCMeta\'>"
member_method {
name: "__init__"
}
member_method {
name: "mro"
}
member_method {
name: "register"
argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,113 @@
path: "tensorflow.data.Dataset"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
is_instance: "<type \'object\'>"
member {
name: "output_shapes"
mtype: "<class \'abc.abstractproperty\'>"
}
member {
name: "output_types"
mtype: "<class \'abc.abstractproperty\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "filter"
argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "flat_map"
argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"
argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_tensor_slices"
argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_tensors"
argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "interleave"
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
}
member_method {
name: "list_files"
argspec: "args=[\'file_pattern\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_initializable_iterator"
argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "make_one_shot_iterator"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "map"
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "padded_batch"
argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "prefetch"
argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "shard"
argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "skip"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "take"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
}
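
Taken together, these members span source creation through iteration; a small sketch exercising a few of them:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(8)           # `range` source
dataset = dataset.map(lambda x: x * x)       # `map`
dataset = dataset.batch(4)                   # `batch`
iterator = dataset.make_one_shot_iterator()  # `make_one_shot_iterator`
next_element = iterator.get_next()

with tf.Session() as sess:
  print(sess.run(next_element))  # ==> the first batch: [0 1 4 9]
```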

View File

@ -0,0 +1,14 @@
path: "tensorflow.data.FixedLengthRecordDataset.__metaclass__"
tf_class {
is_instance: "<class \'abc.ABCMeta\'>"
member_method {
name: "__init__"
}
member_method {
name: "mro"
}
member_method {
name: "register"
argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,114 @@
path: "tensorflow.data.FixedLengthRecordDataset"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.readers.FixedLengthRecordDataset\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
is_instance: "<type \'object\'>"
member {
name: "output_shapes"
mtype: "<type \'property\'>"
}
member {
name: "output_types"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filenames\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'buffer_size\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "filter"
argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "flat_map"
argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"
argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_tensor_slices"
argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_tensors"
argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "interleave"
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
}
member_method {
name: "list_files"
argspec: "args=[\'file_pattern\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_initializable_iterator"
argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "make_one_shot_iterator"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "map"
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "padded_batch"
argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "prefetch"
argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "shard"
argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "skip"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "take"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,41 @@
path: "tensorflow.data.Iterator"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
is_instance: "<type \'object\'>"
member {
name: "initializer"
mtype: "<type \'property\'>"
}
member {
name: "output_shapes"
mtype: "<type \'property\'>"
}
member {
name: "output_types"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'iterator_resource\', \'initializer\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_string_handle"
argspec: "args=[\'string_handle\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_structure"
argspec: "args=[\'output_types\', \'output_shapes\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "get_next"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "make_initializer"
argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "string_handle"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -0,0 +1,14 @@
path: "tensorflow.data.TFRecordDataset.__metaclass__"
tf_class {
is_instance: "<class \'abc.ABCMeta\'>"
member_method {
name: "__init__"
}
member_method {
name: "mro"
}
member_method {
name: "register"
argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,114 @@
path: "tensorflow.data.TFRecordDataset"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.readers.TFRecordDataset\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
is_instance: "<type \'object\'>"
member {
name: "output_shapes"
mtype: "<type \'property\'>"
}
member {
name: "output_types"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "filter"
argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "flat_map"
argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"
argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_tensor_slices"
argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_tensors"
argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "interleave"
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
}
member_method {
name: "list_files"
argspec: "args=[\'file_pattern\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_initializable_iterator"
argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "make_one_shot_iterator"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "map"
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "padded_batch"
argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "prefetch"
argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "shard"
argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "skip"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "take"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,14 @@
path: "tensorflow.data.TextLineDataset.__metaclass__"
tf_class {
is_instance: "<class \'abc.ABCMeta\'>"
member_method {
name: "__init__"
}
member_method {
name: "mro"
}
member_method {
name: "register"
argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,114 @@
path: "tensorflow.data.TextLineDataset"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.readers.TextLineDataset\'>"
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
is_instance: "<type \'object\'>"
member {
name: "output_shapes"
mtype: "<type \'property\'>"
}
member {
name: "output_types"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "batch"
argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "cache"
argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
}
member_method {
name: "concatenate"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "filter"
argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "flat_map"
argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"
argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_tensor_slices"
argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_tensors"
argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "interleave"
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
}
member_method {
name: "list_files"
argspec: "args=[\'file_pattern\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_initializable_iterator"
argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "make_one_shot_iterator"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "map"
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "padded_batch"
argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "prefetch"
argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "shard"
argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "skip"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "take"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,23 @@
path: "tensorflow.data"
tf_module {
member {
name: "Dataset"
mtype: "<class \'abc.ABCMeta\'>"
}
member {
name: "FixedLengthRecordDataset"
mtype: "<class \'abc.ABCMeta\'>"
}
member {
name: "Iterator"
mtype: "<type \'type\'>"
}
member {
name: "TFRecordDataset"
mtype: "<class \'abc.ABCMeta\'>"
}
member {
name: "TextLineDataset"
mtype: "<class \'abc.ABCMeta\'>"
}
}
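
A minimal sketch of what this module surface looks like from user code (file paths are placeholders):

```python
import tensorflow as tf

lines = tf.data.TextLineDataset(["/var/data/file1.txt"])
records = tf.data.TFRecordDataset(["/var/data/file1.tfrecord"])
fixed = tf.data.FixedLengthRecordDataset(["/var/data/file1.bin"], record_bytes=8)
numbers = tf.data.Dataset.range(10)
iterator = tf.data.Iterator.from_structure(numbers.output_types,
                                           numbers.output_shapes)
```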

View File

@ -292,6 +292,10 @@ tf_module {
name: "contrib"
mtype: "<class \'tensorflow.python.util.lazy_loader.LazyLoader\'>"
}
+member {
+name: "data"
+mtype: "<type \'module\'>"
+}
member {
name: "distributions"
mtype: "<type \'module\'>"