mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Graduate tf.data.experimental.dense_to_ragged_batch to tf.data.Dataset.ragged_batch.
PiperOrigin-RevId: 482614367
This commit is contained in:
parent
503a749f6c
commit
6beecdaf55
|
|
@ -142,6 +142,10 @@ This release contains contributions from many people at Google, as well as:
|
|||
the [RFC](https://github.com/tensorflow/community/pull/415) for more
|
||||
details regarding its design and properties.
|
||||
|
||||
* `tf.data`:
|
||||
* Graduated experimental APIs:
|
||||
* [`tf.data.Dataset.ragged_batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset/#ragged_batch), which batches elements of `tf.data.Dataset`s into `tf.RaggedTensor`s.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
* `tf.image`
|
||||
|
|
|
|||
|
|
@ -147,22 +147,6 @@ tf_py_test(
|
|||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "dense_to_ragged_batch_test",
|
||||
size = "small",
|
||||
srcs = ["dense_to_ragged_batch_test.py"],
|
||||
shard_count = 4,
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "dense_to_sparse_batch_test",
|
||||
size = "medium",
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Batching dataset transformations."""
|
||||
from tensorflow.python.data.ops import batch_op
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import structured_function
|
||||
from tensorflow.python.data.util import convert
|
||||
|
|
@ -22,15 +21,14 @@ from tensorflow.python.framework import dtypes
|
|||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("data.experimental.dense_to_ragged_batch")
|
||||
@deprecation.deprecated(None, "Use `tf.data.Dataset.ragged_batch` instead.")
|
||||
def dense_to_ragged_batch(batch_size,
|
||||
drop_remainder=False,
|
||||
row_splits_dtype=dtypes.int64):
|
||||
|
|
@ -87,11 +85,8 @@ def dense_to_ragged_batch(batch_size,
|
|||
Returns:
|
||||
Dataset: A `Dataset`.
|
||||
"""
|
||||
|
||||
def _apply_fn(dataset):
|
||||
ragged_dataset = _DenseToRaggedDataset(dataset, row_splits_dtype)
|
||||
return batch_op.BatchDataset(
|
||||
ragged_dataset, batch_size=batch_size, drop_remainder=drop_remainder)
|
||||
return dataset.ragged_batch(batch_size, drop_remainder, row_splits_dtype)
|
||||
|
||||
return _apply_fn
|
||||
|
||||
|
|
@ -381,78 +376,3 @@ class _MapAndBatchDataset(dataset_ops.UnaryDataset):
|
|||
@property
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
|
||||
class _DenseToRaggedDataset(dataset_ops.UnaryDataset):
|
||||
"""A `Dataset` that encodes dense inputs as ragged (w/ ragged_rank=0).
|
||||
|
||||
In particular:
|
||||
|
||||
* Any tf.Tensor elements with rank>0 are encoded as ragged tensors with
|
||||
ragged_rank=0. This allows tensors with varying shape to be batched
|
||||
together.
|
||||
* Any other elements are left as-is.
|
||||
"""
|
||||
|
||||
def __init__(self, input_dataset, row_splits_dtype):
|
||||
"""Constructs a new _DenseToRaggedDataset.
|
||||
|
||||
Args:
|
||||
input_dataset: The dataset whose tf.Tensor elements should be made ragged.
|
||||
row_splits_dtype: The dtype that should be used for the `row_splits` of
|
||||
any new ragged tensors. Existing `tf.RaggedTensor` elements do *not*
|
||||
have their row_splits dtype changed.
|
||||
"""
|
||||
# Replace each TensorSpec in the input dataset's structure with a
|
||||
# corresponding RaggedTensorSpec.
|
||||
def to_ragged_spec(spec):
|
||||
"""Returns the new spec based on RaggedTensors."""
|
||||
if (not isinstance(spec, tensor_spec.TensorSpec) or
|
||||
spec.shape.rank is None or
|
||||
spec.shape.is_fully_defined()):
|
||||
return spec
|
||||
else:
|
||||
ragged_rank = max([
|
||||
axis for (axis, size) in enumerate(spec.shape.as_list())
|
||||
if size is None
|
||||
])
|
||||
return ragged_tensor.RaggedTensorSpec(
|
||||
shape=spec.shape,
|
||||
dtype=spec.dtype,
|
||||
ragged_rank=ragged_rank,
|
||||
row_splits_dtype=row_splits_dtype)
|
||||
|
||||
self._structure = nest.map_structure(to_ragged_spec,
|
||||
input_dataset.element_spec)
|
||||
|
||||
# Replace each tf.Tensor value in the input dataset with a variant-encoded
|
||||
# RaggedTensor. Since we're updating the corresponding structure to be
|
||||
# a RaggedTensorSpec, this variant-encoded tensor will be decoded with
|
||||
# RaggedTensorSpec._from_tensor_list.
|
||||
def to_ragged_variant(value):
|
||||
"""Re-encode Tensors as RaggedTensors."""
|
||||
if (not isinstance(value, ops.Tensor) or
|
||||
value.shape.rank is None or
|
||||
value.shape.is_fully_defined()):
|
||||
return value
|
||||
else:
|
||||
spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
|
||||
if spec._ragged_rank > 0: # pylint: disable=protected-access
|
||||
value = ragged_tensor.RaggedTensor.from_tensor(
|
||||
value, ragged_rank=spec._ragged_rank) # pylint: disable=protected-access
|
||||
return spec._to_tensor_list(value)[0] # pylint: disable=protected-access
|
||||
|
||||
# Tuples are automatically unpacked by `dataset.map` so we repack them.
|
||||
if structured_function._should_unpack(input_dataset.element_spec): # pylint: disable=protected-access
|
||||
map_fn = lambda *value: nest.map_structure(to_ragged_variant, value)
|
||||
else:
|
||||
map_fn = lambda value: nest.map_structure(to_ragged_variant, value)
|
||||
|
||||
self._mapped_dataset = input_dataset.map(map_fn)
|
||||
|
||||
variant = self._mapped_dataset._variant_tensor # pylint: disable=protected-access
|
||||
super(_DenseToRaggedDataset, self).__init__(input_dataset, variant)
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
|
|
|||
|
|
@ -754,6 +754,22 @@ tf_py_test(
|
|||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "ragged_batch_test",
|
||||
size = "small",
|
||||
srcs = ["ragged_batch_test.py"],
|
||||
shard_count = 4,
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "random_test",
|
||||
size = "small",
|
||||
|
|
|
|||
|
|
@ -12,11 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.dense_to_ragged_batch`."""
|
||||
"""Tests for `tf.data.Dataset.ragged_batch`."""
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
|
|
@ -118,8 +117,7 @@ class RaggedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
for _ in range(nrows)]
|
||||
|
||||
# Batch the dataset, and check that batches match slices from `rows`.
|
||||
batched_dataset = dataset.apply(
|
||||
batching.dense_to_ragged_batch(batch_size, drop_remainder))
|
||||
batched_dataset = dataset.ragged_batch(batch_size, drop_remainder)
|
||||
get_next = self.getNext(batched_dataset)
|
||||
for start_row in range(0, nrows, batch_size):
|
||||
end_row = start_row + batch_size
|
||||
|
|
@ -155,7 +153,7 @@ class RaggedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(np.arange(nrows))
|
||||
dataset = dataset.map(make_structure)
|
||||
dataset = dataset.apply(batching.dense_to_ragged_batch(batch_size))
|
||||
dataset = dataset.ragged_batch(batch_size)
|
||||
get_next = self.getNext(dataset)
|
||||
|
||||
for i in range(0, nrows, batch_size):
|
||||
|
|
@ -53,6 +53,7 @@ py_library(
|
|||
":iterator_ops",
|
||||
":load_ops",
|
||||
":options",
|
||||
":ragged_batch_op",
|
||||
":rebatch_op",
|
||||
":save_ops",
|
||||
":structured_function",
|
||||
|
|
@ -137,6 +138,23 @@ py_library(
|
|||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "ragged_batch_op",
|
||||
srcs = ["ragged_batch_op.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:experimental_dataset_ops_gen",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/util:convert",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//tensorflow/python/data/util:structure",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "save_ops",
|
||||
srcs = ["save_op.py"],
|
||||
|
|
|
|||
|
|
@ -2126,6 +2126,67 @@ class DatasetV2(
|
|||
drop_remainder,
|
||||
name=name)
|
||||
|
||||
def ragged_batch(self,
|
||||
batch_size,
|
||||
drop_remainder=False,
|
||||
row_splits_dtype=dtypes.int64,
|
||||
name=None):
|
||||
"""Combines consecutive elements of this dataset into `tf.RaggedTensor`s.
|
||||
|
||||
Like `tf.data.Dataset.batch`, the components of the resulting element will
|
||||
have an additional outer dimension, which will be `batch_size` (or
|
||||
`N % batch_size` for the last element if `batch_size` does not divide the
|
||||
number of input elements `N` evenly and `drop_remainder` is `False`). If
|
||||
your program depends on the batches having the same outer dimension, you
|
||||
should set the `drop_remainder` argument to `True` to prevent the smaller
|
||||
batch from being produced.
|
||||
|
||||
Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
|
||||
different shapes:
|
||||
|
||||
* If an input element is a `tf.Tensor` whose static `tf.TensorShape` is
|
||||
fully defined, then it is batched as normal.
|
||||
* If an input element is a `tf.Tensor` whose static `tf.TensorShape`
|
||||
contains one or more axes with unknown size (i.e., `shape[i]=None`), then
|
||||
the output will contain a `tf.RaggedTensor` that is ragged up to any of such
|
||||
dimensions.
|
||||
* If an input element is a `tf.RaggedTensor` or any other type, then it is
|
||||
batched as normal.
|
||||
|
||||
Example:
|
||||
|
||||
>>> dataset = tf.data.Dataset.range(6)
|
||||
>>> dataset = dataset.map(lambda x: tf.range(x))
|
||||
>>> dataset.element_spec.shape
|
||||
TensorShape([None])
|
||||
>>> dataset = dataset.ragged_batch(2)
|
||||
>>> for batch in dataset:
|
||||
... print(batch)
|
||||
<tf.RaggedTensor [[], [0]]>
|
||||
<tf.RaggedTensor [[0, 1], [0, 1, 2]]>
|
||||
<tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>
|
||||
|
||||
Args:
|
||||
batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
|
||||
consecutive elements of this dataset to combine in a single batch.
|
||||
drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
|
||||
whether the last batch should be dropped in the case it has fewer than
|
||||
`batch_size` elements; the default behavior is not to drop the smaller
|
||||
batch.
|
||||
row_splits_dtype: The dtype that should be used for the `row_splits` of
|
||||
any new ragged tensors. Existing `tf.RaggedTensor` elements do not have
|
||||
their row_splits dtype changed.
|
||||
name: (Optional.) A string indicating a name for the `tf.data` operation.
|
||||
|
||||
Returns:
|
||||
A new `Dataset` with the transformation applied as described above.
|
||||
"""
|
||||
# Loaded lazily due to a circular dependency (dataset_ops ->
|
||||
# ragged_batch_op -> dataset_ops).
|
||||
from tensorflow.python.data.ops import ragged_batch_op # pylint: disable=g-import-not-at-top
|
||||
return ragged_batch_op.ragged_batch(self, batch_size, drop_remainder,
|
||||
row_splits_dtype, name)
|
||||
|
||||
def map(self,
|
||||
map_func,
|
||||
num_parallel_calls=None,
|
||||
|
|
|
|||
109
tensorflow/python/data/ops/ragged_batch_op.py
Normal file
109
tensorflow/python/data/ops/ragged_batch_op.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The implementation of `tf.data.Dataset.ragged_batch`."""
|
||||
from tensorflow.python.data.ops import batch_op
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import structured_function
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
|
||||
|
||||
def ragged_batch(self,
|
||||
batch_size,
|
||||
drop_remainder=False,
|
||||
row_splits_dtype=dtypes.int64,
|
||||
name=None):
|
||||
ragged_dataset = _DenseToRaggedDataset(self, row_splits_dtype, name)
|
||||
return batch_op.BatchDataset(
|
||||
ragged_dataset, batch_size=batch_size, drop_remainder=drop_remainder)
|
||||
|
||||
|
||||
class _DenseToRaggedDataset(dataset_ops.UnaryDataset):
|
||||
"""A `Dataset` that encodes dense inputs as ragged (w/ ragged_rank=0).
|
||||
|
||||
In particular:
|
||||
|
||||
* Any tf.Tensor elements with rank>0 are encoded as ragged tensors with
|
||||
ragged_rank=0. This allows tensors with varying shape to be batched
|
||||
together.
|
||||
* Any other elements are left as-is.
|
||||
"""
|
||||
|
||||
def __init__(self, input_dataset, row_splits_dtype, name=None):
|
||||
"""Constructs a new _DenseToRaggedDataset.
|
||||
|
||||
Args:
|
||||
input_dataset: The dataset whose tf.Tensor elements should be made ragged.
|
||||
row_splits_dtype: The dtype that should be used for the `row_splits` of
|
||||
any new ragged tensors. Existing `tf.RaggedTensor` elements do *not*
|
||||
have their row_splits dtype changed.
|
||||
name: (Optional.) A string indicating a name for the `tf.data` operation.
|
||||
"""
|
||||
# Replace each TensorSpec in the input dataset's structure with a
|
||||
# corresponding RaggedTensorSpec.
|
||||
def to_ragged_spec(spec):
|
||||
"""Returns the new spec based on RaggedTensors."""
|
||||
if (not isinstance(spec, tensor_spec.TensorSpec) or
|
||||
spec.shape.rank is None or
|
||||
spec.shape.is_fully_defined()):
|
||||
return spec
|
||||
else:
|
||||
ragged_rank = max([
|
||||
axis for (axis, size) in enumerate(spec.shape.as_list())
|
||||
if size is None
|
||||
])
|
||||
return ragged_tensor.RaggedTensorSpec(
|
||||
shape=spec.shape,
|
||||
dtype=spec.dtype,
|
||||
ragged_rank=ragged_rank,
|
||||
row_splits_dtype=row_splits_dtype)
|
||||
|
||||
self._structure = nest.map_structure(to_ragged_spec,
|
||||
input_dataset.element_spec)
|
||||
|
||||
# Replace each tf.Tensor value in the input dataset with a variant-encoded
|
||||
# RaggedTensor. Since we're updating the corresponding structure to be
|
||||
# a RaggedTensorSpec, this variant-encoded tensor will be decoded with
|
||||
# RaggedTensorSpec._from_tensor_list.
|
||||
def to_ragged_variant(value):
|
||||
"""Re-encode Tensors as RaggedTensors."""
|
||||
if (not isinstance(value, ops.Tensor) or
|
||||
value.shape.rank is None or
|
||||
value.shape.is_fully_defined()):
|
||||
return value
|
||||
else:
|
||||
spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
|
||||
if spec._ragged_rank > 0: # pylint: disable=protected-access
|
||||
value = ragged_tensor.RaggedTensor.from_tensor(
|
||||
value, ragged_rank=spec._ragged_rank) # pylint: disable=protected-access
|
||||
return spec._to_tensor_list(value)[0] # pylint: disable=protected-access
|
||||
|
||||
# Tuples are automatically unpacked by `dataset.map` so we repack them.
|
||||
if structured_function._should_unpack(input_dataset.element_spec): # pylint: disable=protected-access
|
||||
map_fn = lambda *value: nest.map_structure(to_ragged_variant, value)
|
||||
else:
|
||||
map_fn = lambda value: nest.map_structure(to_ragged_variant, value)
|
||||
|
||||
self._mapped_dataset = input_dataset.map(map_fn)
|
||||
self._name = name
|
||||
variant = self._mapped_dataset._variant_tensor # pylint: disable=protected-access
|
||||
super(_DenseToRaggedDataset, self).__init__(input_dataset, variant)
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
return self._structure
|
||||
|
|
@ -143,6 +143,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -145,6 +145,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -145,6 +145,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -145,6 +145,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -145,6 +145,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -145,6 +145,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -145,6 +145,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -110,6 +110,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -112,6 +112,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -111,6 +111,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -112,6 +112,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -112,6 +112,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -113,6 +113,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -112,6 +112,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
|
|
@ -113,6 +113,10 @@ tf_class {
|
|||
name: "prefetch"
|
||||
argspec: "args=[\'self\', \'buffer_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ragged_batch"
|
||||
argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\', \'row_splits_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \"<dtype: \'int64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "random"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user