mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
eager: Documentation and example models.
- Updated README - A preliminary "User's Guide" - A few example models, some with benchmarks PiperOrigin-RevId: 173996303
This commit is contained in:
parent
cd81bc8e09
commit
a6a6188439
|
|
@ -18,7 +18,6 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
|
|
@ -26,12 +25,8 @@ from tensorflow.python.layers import base as base_layer
|
|||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
_cudnn_rnn_ops_so = loader.load_op_library(
|
||||
resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so"))
|
||||
|
||||
CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION
|
||||
CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION
|
||||
CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM
|
||||
|
|
|
|||
15
tensorflow/contrib/eager/README.OPENSOURCE.md
Normal file
15
tensorflow/contrib/eager/README.OPENSOURCE.md
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
TensorFlow has many kernels for doing (deep) learning and data manipulation.
|
||||
There are typically assembled into computational graphs which can run
|
||||
efficiently in a variety of environments.
|
||||
|
||||
We are exploring an alternative interaction, where kernels are invoked
|
||||
immediately and call this "eager execution". We are hoping to retain the
|
||||
benefits of graphs while improving usability with benefits like:
|
||||
|
||||
- Immediate error messages and easier debugging
|
||||
- Flexibility to use Python datastructures and control flow
|
||||
- Reduced boilerplate
|
||||
|
||||
Eager execution is under active development.
|
||||
There are not many developer-facing materials yet, but stay tuned for updates
|
||||
in this directory.
|
||||
|
|
@ -1,15 +1,78 @@
|
|||
TensorFlow has many kernels for doing (deep) learning and data manipulation.
|
||||
There are typically assembled into computational graphs which can run
|
||||
efficiently in a variety of environments.
|
||||
# TensorFlow Eager Execution
|
||||
|
||||
We are exploring an alternative interaction, where kernels are invoked
|
||||
immediately and call this "eager execution". We are hoping to retain the
|
||||
benefits of graphs while improving usability with benefits like:
|
||||
> *WARNING*: This is a preview/pre-alpha version. The API and performance
|
||||
> characteristics are subject to change.
|
||||
|
||||
- Immediate error messages and easier debugging
|
||||
- Flexibility to use Python datastructures and control flow
|
||||
- Reduced boilerplate
|
||||
Eager execution is an experimental interface to TensorFlow that provides an
|
||||
imperative programming style (à la [NumPy](http://www.numpy.org)). When you
|
||||
enable eager execution, TensorFlow operations execute immediately; you do not
|
||||
execute a pre-constructed graph with
|
||||
[`Session.run()`](https://www.tensorflow.org/api_docs/python/tf/Session).
|
||||
|
||||
Eager execution is under active development.
|
||||
There are not many developer-facing materials yet, but stay tuned for updates
|
||||
in this directory.
|
||||
For example, consider a simple computation in TensorFlow:
|
||||
|
||||
```python
|
||||
x = tf.placeholder(tf.float32, shape=[1, 1])
|
||||
m = tf.matmul(x, x)
|
||||
|
||||
with tf.Session() as sess:
|
||||
print(sess.run(m, feed_dict={x: [[2.]]}))
|
||||
|
||||
# Will print [[4.]]
|
||||
```
|
||||
|
||||
Eager execution makes this much simpler:
|
||||
|
||||
```python
|
||||
x = [[2.]]
|
||||
m = tf.matmul(x, x)
|
||||
|
||||
print(m)
|
||||
```
|
||||
|
||||
## Caveats
|
||||
|
||||
This feature is in early stages and work remains to be done in terms of smooth
|
||||
support for distributed and multi-GPU training and CPU performance.
|
||||
|
||||
- [Known issues](https://github.com/tensorflow/tensorflow/issues?q=is%3Aissue%20is%3Aopen%20label%3Aproj%3Aeager)
|
||||
- Feedback is welcome, please consider
|
||||
[filing an issue](https://github.com/tensorflow/tensorflow/issues/new) to provide it.
|
||||
|
||||
## Installation
|
||||
|
||||
Since eager execution is not yet part of a TensorFlow release, using it requires
|
||||
either [building from source](https://www.tensorflow.org/install/install_sources)
|
||||
or the latest nightly builds. The nightly builds are available as:
|
||||
|
||||
- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and
|
||||
|
||||
- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images.
|
||||
|
||||
For example, to run the latest nightly docker image:
|
||||
|
||||
```sh
|
||||
# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker
|
||||
nvidia-docker pull tensorflow/tensorflow:nightly-gpu
|
||||
nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu
|
||||
|
||||
# If you do not have a GPU, use the CPU-only image
|
||||
docker pull tensorflow/tensorflow:nightly
|
||||
docker run -it -p 8888:8888 tensorflow/tensorflow:nightly
|
||||
```
|
||||
|
||||
And then visit http://localhost:8888 in your browser for a Jupyter notebook
|
||||
environment. Try out the notebooks below.
|
||||
|
||||
## Documentation
|
||||
|
||||
For an introduction to eager execution in TensorFlow, see:
|
||||
|
||||
- [User Guide](python/g3doc/guide.md)
|
||||
- Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb)
|
||||
- Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb)
|
||||
- Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb)
|
||||
|
||||
## Changelog
|
||||
|
||||
- 2017/10/31: Initial preview release.
|
||||
|
|
|
|||
15
tensorflow/contrib/eager/python/examples/BUILD
Normal file
15
tensorflow/contrib/eager/python/examples/BUILD
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# TensorFlow code for training gradient boosted trees.
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
py_library(
|
||||
name = "examples_pip",
|
||||
deps = [
|
||||
"//tensorflow/contrib/eager/python/examples/linear_regression",
|
||||
"//tensorflow/contrib/eager/python/examples/mnist",
|
||||
"//tensorflow/contrib/eager/python/examples/resnet50",
|
||||
"//tensorflow/contrib/eager/python/examples/rnn_colorbot",
|
||||
"//tensorflow/contrib/eager/python/examples/rnn_ptb",
|
||||
],
|
||||
)
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_binary(
|
||||
name = "linear_regression",
|
||||
srcs = ["linear_regression.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "linear_regression_test",
|
||||
size = "small",
|
||||
srcs = ["linear_regression_test.py"],
|
||||
additional_deps = [
|
||||
":linear_regression",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""TensorFlow Eager Execution Example: Linear Regression.
|
||||
|
||||
This example shows how to use TensorFlow Eager Execution to fit a simple linear
|
||||
regression model using some synthesized data. Specifically, it illustrates how
|
||||
to define the forward path of the linear model and the loss function, as well
|
||||
as how to obtain the gradients of the loss function with respect to the
|
||||
variables and update the variables with the gradients.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import tensorflow.contrib.eager as tfe
|
||||
|
||||
|
||||
class LinearModel(tfe.Network):
|
||||
"""A TensorFlow linear regression model.
|
||||
|
||||
Uses TensorFlow's eager execution.
|
||||
|
||||
For those familiar with TensorFlow graphs, notice the absence of
|
||||
`tf.Session`. The `forward()` method here immediately executes and
|
||||
returns output values. The `loss()` method immediately compares the
|
||||
output of `forward()` with the target adn returns the MSE loss value.
|
||||
The `fit()` performs gradient-descent training on the model's weights
|
||||
and bias.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Constructs a LinearModel object."""
|
||||
super(LinearModel, self).__init__()
|
||||
self._hidden_layer = self.track_layer(tf.layers.Dense(1))
|
||||
|
||||
def call(self, xs):
|
||||
"""Invoke the linear model.
|
||||
|
||||
Args:
|
||||
xs: input features, as a tensor of size [batch_size, ndims].
|
||||
|
||||
Returns:
|
||||
ys: the predictions of the linear mode, as a tensor of size [batch_size]
|
||||
"""
|
||||
return self._hidden_layer(xs)
|
||||
|
||||
|
||||
def fit(model, dataset, optimizer, verbose=False, logdir=None):
|
||||
"""Fit the linear-regression model.
|
||||
|
||||
Args:
|
||||
model: The LinearModel to fit.
|
||||
dataset: The tf.data.Dataset to use for training data.
|
||||
optimizer: The TensorFlow Optimizer object to be used.
|
||||
verbose: If true, will print out loss values at every iteration.
|
||||
logdir: The directory in which summaries will be written for TensorBoard
|
||||
(optional).
|
||||
"""
|
||||
|
||||
# The loss function to optimize.
|
||||
def mean_square_loss(xs, ys):
|
||||
return tf.reduce_mean(tf.square(model(xs) - ys))
|
||||
|
||||
loss_and_grads = tfe.implicit_value_and_gradients(mean_square_loss)
|
||||
|
||||
tf.train.get_or_create_global_step()
|
||||
if logdir:
|
||||
# Support for TensorBoard summaries. Once training has started, use:
|
||||
# tensorboard --logdir=<logdir>
|
||||
summary_writer = tf.contrib.summary.create_summary_file_writer(logdir)
|
||||
|
||||
# Training loop.
|
||||
for i, (xs, ys) in enumerate(tfe.Iterator(dataset)):
|
||||
loss, grads = loss_and_grads(xs, ys)
|
||||
if verbose:
|
||||
print("Iteration %d: loss = %s" % (i, loss.numpy()))
|
||||
|
||||
optimizer.apply_gradients(grads, global_step=tf.train.get_global_step())
|
||||
|
||||
if logdir:
|
||||
with summary_writer.as_default():
|
||||
with tf.contrib.summary.always_record_summaries():
|
||||
tf.contrib.summary.scalar("loss", loss)
|
||||
|
||||
|
||||
def synthetic_dataset(w, b, noise_level, batch_size, num_batches):
|
||||
"""tf.data.Dataset that yields synthetic data for linear regression."""
|
||||
|
||||
# w is a matrix with shape [N, M]
|
||||
# b is a vector with shape [M]
|
||||
# So:
|
||||
# - Generate x's as vectors with shape [batch_size N]
|
||||
# - y = tf.matmul(x, W) + b + noise
|
||||
def batch(_):
|
||||
x = tf.random_normal([batch_size, tf.shape(w)[0]])
|
||||
y = tf.matmul(x, w) + b + noise_level * tf.random_normal([])
|
||||
return x, y
|
||||
|
||||
with tf.device("/device:CPU:0"):
|
||||
return tf.data.Dataset.range(num_batches).map(batch)
|
||||
|
||||
|
||||
def main(_):
|
||||
tfe.enable_eager_execution()
|
||||
# Ground-truth constants.
|
||||
true_w = [[-2.0], [4.0], [1.0]]
|
||||
true_b = [0.5]
|
||||
noise_level = 0.01
|
||||
|
||||
# Training constants.
|
||||
batch_size = 64
|
||||
learning_rate = 0.1
|
||||
|
||||
print("True w: %s" % true_w)
|
||||
print("True b: %s\n" % true_b)
|
||||
|
||||
model = LinearModel()
|
||||
dataset = synthetic_dataset(true_w, true_b, noise_level, batch_size, 20)
|
||||
|
||||
device = "gpu:0" if tfe.num_gpus() else "cpu:0"
|
||||
print("Using device: %s" % device)
|
||||
with tf.device(device):
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
|
||||
fit(model, dataset, optimizer, verbose=True, logdir=FLAGS.logdir)
|
||||
|
||||
print("\nAfter training: w = %s" % model.variables[0].numpy())
|
||||
print("\nAfter training: b = %s" % model.variables[1].numpy())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--logdir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="logdir in which TensorBoard summaries will be written (optional).")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Unit tests for linear regression example under TensorFlow eager execution."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import tensorflow.contrib.eager as tfe
|
||||
from tensorflow.contrib.eager.python.examples.linear_regression import linear_regression
|
||||
|
||||
|
||||
def device():
|
||||
return "/device:GPU:0" if tfe.num_gpus() > 0 else "/device:CPU:0"
|
||||
|
||||
|
||||
class LinearRegressionTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(LinearRegressionTest, self).setUp()
|
||||
self._tmp_logdir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self._tmp_logdir)
|
||||
super(LinearRegressionTest, self).tearDown()
|
||||
|
||||
def testSyntheticDataset(self):
|
||||
true_w = tf.random_uniform([3, 1])
|
||||
true_b = [1.0]
|
||||
batch_size = 10
|
||||
num_batches = 2
|
||||
noise_level = 0.
|
||||
dataset = linear_regression.synthetic_dataset(true_w, true_b, noise_level,
|
||||
batch_size, num_batches)
|
||||
|
||||
it = tfe.Iterator(dataset)
|
||||
for _ in range(2):
|
||||
(xs, ys) = it.next()
|
||||
self.assertEqual((batch_size, 3), xs.shape)
|
||||
self.assertEqual((batch_size, 1), ys.shape)
|
||||
self.assertEqual(tf.float32, xs.dtype)
|
||||
self.assertEqual(tf.float32, ys.dtype)
|
||||
with self.assertRaises(StopIteration):
|
||||
it.next()
|
||||
|
||||
def testLinearRegression(self):
|
||||
true_w = [[1.0], [-0.5], [2.0]]
|
||||
true_b = [1.0]
|
||||
|
||||
model = linear_regression.LinearModel()
|
||||
dataset = linear_regression.synthetic_dataset(
|
||||
true_w, true_b, noise_level=0., batch_size=64, num_batches=40)
|
||||
|
||||
with tf.device(device()):
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
|
||||
linear_regression.fit(model, dataset, optimizer, logdir=self._tmp_logdir)
|
||||
|
||||
self.assertAllClose(true_w, model.variables[0].numpy(), rtol=1e-2)
|
||||
self.assertAllClose(true_b, model.variables[1].numpy(), rtol=1e-2)
|
||||
self.assertTrue(glob.glob(os.path.join(self._tmp_logdir, "events.out.*")))
|
||||
|
||||
|
||||
class EagerLinearRegressionBenchmark(tf.test.Benchmark):
|
||||
|
||||
def benchmarkEagerLinearRegression(self):
|
||||
num_batches = 200
|
||||
batch_size = 64
|
||||
dataset = linear_regression.synthetic_dataset(
|
||||
w=tf.random_uniform([3, 1]),
|
||||
b=tf.random_uniform([1]),
|
||||
noise_level=0.01,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches)
|
||||
burn_in_dataset = dataset.take(10)
|
||||
|
||||
model = linear_regression.LinearModel()
|
||||
|
||||
with tf.device(device()):
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
|
||||
|
||||
# Perform burn-in.
|
||||
linear_regression.fit(model, burn_in_dataset, optimizer)
|
||||
|
||||
start_time = time.time()
|
||||
linear_regression.fit(model, dataset, optimizer)
|
||||
wall_time = time.time() - start_time
|
||||
|
||||
examples_per_sec = num_batches * batch_size / wall_time
|
||||
self.report_benchmark(
|
||||
name="eager_train_%s" %
|
||||
("gpu" if tfe.num_gpus() > 0 else "cpu"),
|
||||
iters=num_batches,
|
||||
extras={"examples_per_sec": examples_per_sec},
|
||||
wall_time=wall_time)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tfe.enable_eager_execution()
|
||||
tf.test.main()
|
||||
36
tensorflow/contrib/eager/python/examples/mnist/BUILD
Normal file
36
tensorflow/contrib/eager/python/examples/mnist/BUILD
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_binary(
|
||||
name = "mnist",
|
||||
srcs = ["mnist.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
"//tensorflow/examples/tutorials/mnist:input_data",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "mnist_test",
|
||||
srcs = ["mnist_test.py"],
|
||||
additional_deps = [
|
||||
":mnist",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "mnist_graph_test",
|
||||
srcs = ["mnist_graph_test.py"],
|
||||
additional_deps = [
|
||||
":mnist",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
10
tensorflow/contrib/eager/python/examples/mnist/README.md
Normal file
10
tensorflow/contrib/eager/python/examples/mnist/README.md
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
Classification model for the MNIST dataset using eager execution.
|
||||
|
||||
To run:
|
||||
|
||||
```
|
||||
python mnist.py
|
||||
```
|
||||
|
||||
`mnist_graph_test.py` demonstrates that the same code that is executed eagerly
|
||||
in `mnist.py` is used to construct a TensorFlow graph.
|
||||
270
tensorflow/contrib/eager/python/examples/mnist/mnist.py
Normal file
270
tensorflow/contrib/eager/python/examples/mnist/mnist.py
Normal file
|
|
@ -0,0 +1,270 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A deep MNIST classifier using convolutional layers.
|
||||
|
||||
Sample usage:
|
||||
python mnist.py --help
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import tensorflow.contrib.eager as tfe
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
|
||||
FLAGS = None
|
||||
|
||||
|
||||
class MNISTModel(tfe.Network):
|
||||
"""MNIST Network.
|
||||
|
||||
Network structure is equivalent to:
|
||||
https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/examples/tutorials/mnist/mnist_deep.py
|
||||
and
|
||||
https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
|
||||
|
||||
But written using the tf.layers API.
|
||||
"""
|
||||
|
||||
def __init__(self, data_format):
|
||||
"""Creates a model for classifying a hand-written digit.
|
||||
|
||||
Args:
|
||||
data_format: Either 'channels_first' or 'channels_last'.
|
||||
'channels_first' is typically faster on GPUs while 'channels_last' is
|
||||
typically faster on CPUs. See
|
||||
https://www.tensorflow.org/performance/performance_guide#data_formats
|
||||
"""
|
||||
super(MNISTModel, self).__init__(name='')
|
||||
if data_format == 'channels_first':
|
||||
self._input_shape = [-1, 1, 28, 28]
|
||||
else:
|
||||
assert data_format == 'channels_last'
|
||||
self._input_shape = [-1, 28, 28, 1]
|
||||
self.conv1 = self.track_layer(
|
||||
tf.layers.Conv2D(32, 5, data_format=data_format, activation=tf.nn.relu))
|
||||
self.conv2 = self.track_layer(
|
||||
tf.layers.Conv2D(64, 5, data_format=data_format, activation=tf.nn.relu))
|
||||
self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.nn.relu))
|
||||
self.fc2 = self.track_layer(tf.layers.Dense(10))
|
||||
self.dropout = self.track_layer(tf.layers.Dropout(0.5))
|
||||
self.max_pool2d = self.track_layer(
|
||||
tf.layers.MaxPooling2D(
|
||||
(2, 2), (2, 2), padding='SAME', data_format=data_format))
|
||||
|
||||
def call(self, inputs, training):
|
||||
"""Computes labels from inputs.
|
||||
|
||||
Users should invoke __call__ to run the network, which delegates to this
|
||||
method (and not call this method directly).
|
||||
|
||||
Args:
|
||||
inputs: A batch of images as a Tensor with shape [batch_size, 784].
|
||||
training: True if invoked in the context of training (causing dropout to
|
||||
be applied). False otherwise.
|
||||
|
||||
Returns:
|
||||
A Tensor with shape [batch_size, 10] containing the predicted logits
|
||||
for each image in the batch, for each of the 10 classes.
|
||||
"""
|
||||
|
||||
x = tf.reshape(inputs, self._input_shape)
|
||||
x = self.conv1(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = tf.layers.flatten(x)
|
||||
x = self.fc1(x)
|
||||
if training:
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
def loss(predictions, labels):
|
||||
return tf.reduce_mean(
|
||||
tf.nn.softmax_cross_entropy_with_logits(
|
||||
logits=predictions, labels=labels))
|
||||
|
||||
|
||||
def compute_accuracy(predictions, labels):
|
||||
return tf.reduce_sum(
|
||||
tf.cast(
|
||||
tf.equal(
|
||||
tf.argmax(predictions, axis=1,
|
||||
output_type=tf.int64),
|
||||
tf.argmax(labels, axis=1,
|
||||
output_type=tf.int64)),
|
||||
dtype=tf.float32)) / float(predictions.shape[0].value)
|
||||
|
||||
|
||||
def train_one_epoch(model, optimizer, dataset, log_interval=None):
|
||||
"""Trains model on `dataset` using `optimizer`."""
|
||||
|
||||
tf.train.get_or_create_global_step()
|
||||
|
||||
def model_loss(labels, images):
|
||||
prediction = model(images, training=True)
|
||||
loss_value = loss(prediction, labels)
|
||||
tf.contrib.summary.scalar('loss', loss_value)
|
||||
tf.contrib.summary.scalar('accuracy',
|
||||
compute_accuracy(prediction, labels))
|
||||
return loss_value
|
||||
|
||||
for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
|
||||
with tf.contrib.summary.record_summaries_every_n_global_steps(10):
|
||||
batch_model_loss = functools.partial(model_loss, labels, images)
|
||||
optimizer.minimize(
|
||||
batch_model_loss, global_step=tf.train.get_global_step())
|
||||
if log_interval and batch % log_interval == 0:
|
||||
print('Batch #%d\tLoss: %.6f' % (batch, batch_model_loss()))
|
||||
|
||||
|
||||
def test(model, dataset):
|
||||
"""Perform an evaluation of `model` on the examples from `dataset`."""
|
||||
avg_loss = tfe.metrics.Mean('loss')
|
||||
accuracy = tfe.metrics.Accuracy('accuracy')
|
||||
|
||||
for (images, labels) in tfe.Iterator(dataset):
|
||||
predictions = model(images, training=False)
|
||||
avg_loss(loss(predictions, labels))
|
||||
accuracy(tf.argmax(predictions, axis=1, output_type=tf.int64),
|
||||
tf.argmax(labels, axis=1, output_type=tf.int64))
|
||||
print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' %
|
||||
(avg_loss.result(), 100 * accuracy.result()))
|
||||
with tf.contrib.summary.always_record_summaries():
|
||||
tf.contrib.summary.scalar('loss', avg_loss.result())
|
||||
tf.contrib.summary.scalar('accuracy', accuracy.result())
|
||||
|
||||
|
||||
def load_data(data_dir):
|
||||
"""Returns training and test tf.data.Dataset objects."""
|
||||
data = input_data.read_data_sets(data_dir, one_hot=True)
|
||||
train_ds = tf.data.Dataset.from_tensor_slices((data.train.images,
|
||||
data.train.labels))
|
||||
test_ds = tf.data.Dataset.from_tensors((data.test.images, data.test.labels))
|
||||
return (train_ds, test_ds)
|
||||
|
||||
|
||||
def main(_):
|
||||
tfe.enable_eager_execution()
|
||||
|
||||
(device, data_format) = ('/gpu:0', 'channels_first')
|
||||
if FLAGS.no_gpu or tfe.num_gpus() <= 0:
|
||||
(device, data_format) = ('/cpu:0', 'channels_last')
|
||||
print('Using device %s, and data format %s.' % (device, data_format))
|
||||
|
||||
# Load the datasets
|
||||
(train_ds, test_ds) = load_data(FLAGS.data_dir)
|
||||
train_ds = train_ds.shuffle(60000).batch(FLAGS.batch_size)
|
||||
|
||||
# Create the model and optimizer
|
||||
model = MNISTModel(data_format)
|
||||
optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)
|
||||
|
||||
if FLAGS.output_dir:
|
||||
train_dir = os.path.join(FLAGS.output_dir, 'train')
|
||||
test_dir = os.path.join(FLAGS.output_dir, 'eval')
|
||||
tf.gfile.MakeDirs(FLAGS.output_dir)
|
||||
else:
|
||||
train_dir = None
|
||||
test_dir = None
|
||||
summary_writer = tf.contrib.summary.create_summary_file_writer(
|
||||
train_dir, flush_secs=10)
|
||||
test_summary_writer = tf.contrib.summary.create_summary_file_writer(
|
||||
test_dir, flush_secs=10, name='test')
|
||||
checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
|
||||
|
||||
with tf.device(device):
|
||||
for epoch in range(1, 11):
|
||||
with tfe.restore_variables_on_create(
|
||||
tf.train.latest_checkpoint(FLAGS.checkpoint_dir)):
|
||||
global_step = tf.train.get_or_create_global_step()
|
||||
start = time.time()
|
||||
with summary_writer.as_default():
|
||||
train_one_epoch(model, optimizer, train_ds, FLAGS.log_interval)
|
||||
end = time.time()
|
||||
print('\nTrain time for epoch #%d (global step %d): %f' % (
|
||||
epoch, global_step.numpy(), end - start))
|
||||
with test_summary_writer.as_default():
|
||||
test(model, test_ds)
|
||||
all_variables = (
|
||||
model.variables
|
||||
+ tfe.get_optimizer_variables(optimizer)
|
||||
+ [global_step])
|
||||
tfe.Saver(all_variables).save(
|
||||
checkpoint_prefix, global_step=global_step)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--data-dir',
|
||||
type=str,
|
||||
default='/tmp/tensorflow/mnist/input_data',
|
||||
help='Directory for storing input data')
|
||||
parser.add_argument(
|
||||
'--batch-size',
|
||||
type=int,
|
||||
default=64,
|
||||
metavar='N',
|
||||
help='input batch size for training (default: 64)')
|
||||
parser.add_argument(
|
||||
'--log-interval',
|
||||
type=int,
|
||||
default=10,
|
||||
metavar='N',
|
||||
help='how many batches to wait before logging training status')
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default=None,
|
||||
metavar='N',
|
||||
help='Directory to write TensorBoard summaries')
|
||||
parser.add_argument(
|
||||
'--checkpoint_dir',
|
||||
type=str,
|
||||
default='/tmp/tensorflow/mnist/checkpoints/',
|
||||
metavar='N',
|
||||
help='Directory to save checkpoints in (once per epoch)')
|
||||
parser.add_argument(
|
||||
'--lr',
|
||||
type=float,
|
||||
default=0.01,
|
||||
metavar='LR',
|
||||
help='learning rate (default: 0.01)')
|
||||
parser.add_argument(
|
||||
'--momentum',
|
||||
type=float,
|
||||
default=0.5,
|
||||
metavar='M',
|
||||
help='SGD momentum (default: 0.5)')
|
||||
parser.add_argument(
|
||||
'--no-gpu',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='disables GPU usage even if a GPU is available')
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.eager.python.examples.mnist import mnist
|
||||
|
||||
|
||||
def data_format():
|
||||
return "channels_first" if tf.test.is_gpu_available() else "channels_last"
|
||||
|
||||
|
||||
class MNISTGraphTest(tf.test.TestCase):
|
||||
|
||||
def testTrainGraph(self):
|
||||
# The MNISTModel class can be executed eagerly (as in mnist.py and
|
||||
# mnist_test.py) and also be used to construct a TensorFlow graph, which is
|
||||
# then trained in a session.
|
||||
with tf.Graph().as_default():
|
||||
# Generate some random data.
|
||||
batch_size = 64
|
||||
images = np.random.randn(batch_size, 784).astype(np.float32)
|
||||
digits = np.random.randint(low=0, high=10, size=batch_size)
|
||||
labels = np.zeros((batch_size, 10))
|
||||
labels[np.arange(batch_size), digits] = 1.
|
||||
|
||||
# Create a model, optimizer, and dataset as would be done
|
||||
# for eager execution as well.
|
||||
model = mnist.MNISTModel(data_format())
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
dataset = tf.data.Dataset.from_tensors((images, labels))
|
||||
|
||||
# Define the loss tensor (as opposed to a loss function when
|
||||
# using eager execution).
|
||||
(images, labels) = dataset.make_one_shot_iterator().get_next()
|
||||
predictions = model(images, training=True)
|
||||
loss = mnist.loss(predictions, labels)
|
||||
|
||||
train_op = optimizer.minimize(loss)
|
||||
init = tf.global_variables_initializer()
|
||||
with tf.Session() as sess:
|
||||
# Variables have to be initialized in the session.
|
||||
sess.run(init)
|
||||
# Train using the optimizer.
|
||||
sess.run(train_op)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
62
tensorflow/contrib/eager/python/examples/mnist/mnist_test.py
Normal file
62
tensorflow/contrib/eager/python/examples/mnist/mnist_test.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import tensorflow.contrib.eager as tfe
|
||||
from tensorflow.contrib.eager.python.examples.mnist import mnist
|
||||
|
||||
|
||||
def device():
|
||||
return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0"
|
||||
|
||||
|
||||
def data_format():
|
||||
return "channels_first" if tfe.num_gpus() else "channels_last"
|
||||
|
||||
|
||||
def random_dataset():
|
||||
batch_size = 64
|
||||
images = tf.random_normal([batch_size, 784])
|
||||
digits = tf.random_uniform([batch_size], minval=0, maxval=10, dtype=tf.int32)
|
||||
labels = tf.one_hot(digits, 10)
|
||||
return tf.data.Dataset.from_tensors((images, labels))
|
||||
|
||||
|
||||
class MNISTTest(tf.test.TestCase):
|
||||
|
||||
def testTrainOneEpoch(self):
|
||||
model = mnist.MNISTModel(data_format())
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
|
||||
dataset = random_dataset()
|
||||
with tf.device(device()):
|
||||
tf.train.get_or_create_global_step()
|
||||
mnist.train_one_epoch(model, optimizer, dataset)
|
||||
|
||||
def testTest(self):
|
||||
model = mnist.MNISTModel(data_format())
|
||||
dataset = random_dataset()
|
||||
with tf.device(device()):
|
||||
tf.train.get_or_create_global_step()
|
||||
mnist.test(model, dataset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tfe.enable_eager_execution()
|
||||
tf.test.main()
|
||||
|
|
@ -0,0 +1,529 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "U9i2Dsh-ziXr"
|
||||
},
|
||||
"source": [
|
||||
"# Eager Execution Tutorial: Basics\n",
|
||||
"\n",
|
||||
"This notebook introduces the basics of using TensorFlow's eager execution capabilities. It covers concepts such as:\n",
|
||||
"\n",
|
||||
"* Importing required packages\n",
|
||||
"* Enabling eager execution\n",
|
||||
"* Creating and using TensorFlow Tensors and Variables\n",
|
||||
"* Using TensorFlow interactively\n",
|
||||
"* Using GPUs with eager execution enabled\n",
|
||||
"\n",
|
||||
"This notebook does *not* cover modeling topics, such as gradients."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "z1JcS5iBXMRO"
|
||||
},
|
||||
"source": [
|
||||
"# Step 1: Import Eager\n",
|
||||
"\n",
|
||||
"The key imports for eager execution are the following:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "RlIWhyeLoYnG"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Import TensorFlow.\n",
|
||||
"import tensorflow as tf\n",
|
||||
"\n",
|
||||
"# Import TensorFlow eager execution support (subject to future changes).\n",
|
||||
"import tensorflow.contrib.eager as tfe"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "H9UySOPLXdaw"
|
||||
},
|
||||
"source": [
|
||||
"# Step 2: Enable eager execution\n",
|
||||
"\n",
|
||||
"All future TensorFlow calls will execute the\n",
|
||||
"underlying TensorFlow ops immediately:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "WPTUfGq6kJ5w"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tfe.enable_eager_execution()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "twBfWd5xyu_d"
|
||||
},
|
||||
"source": [
|
||||
"# Step 3: Interactively Use TensorFlow!\n",
|
||||
"\n",
|
||||
"Now you can call TensorFlow functions and get results, immediately! No more `tf.Sessions`!\n",
|
||||
"\n",
|
||||
"TensorFlow will automatically wrap native Python types for you with operator overloading for TensorFlow Tensors."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "ngUe237Wt48W"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(tf.add(1, 2))\n",
|
||||
"print(tf.add([1, 2], [3, 4]))\n",
|
||||
"print(tf.square(5))\n",
|
||||
"print(tf.reduce_sum([1, 2, 3]))\n",
|
||||
"print(tf.encode_base64(\"hello world\"))\n",
|
||||
"print(\"\")\n",
|
||||
"\n",
|
||||
"x = tf.constant(2)\n",
|
||||
"y = tf.constant(3)\n",
|
||||
"print(x * y + 1)\n",
|
||||
"\n",
|
||||
"# Most TensorFlow ops are directly usable with eager execution, giving\n",
|
||||
"# results immediately.\n",
|
||||
"print(tf.contrib.signal.hamming_window(x * y + 1))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "IDY4WsYRhP81"
|
||||
},
|
||||
"source": [
|
||||
"Numpy arrays are supported, too:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "lCUWzso6mbqR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"ones = np.ones([3, 3])\n",
|
||||
"\n",
|
||||
"print(\"numpy 3x3 matrix of 1s:\")\n",
|
||||
"print(ones)\n",
|
||||
"print(\"\")\n",
|
||||
"\n",
|
||||
"print(\"Multiplied by 42:\")\n",
|
||||
"print(tf.multiply(ones, 42))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "PBNP8yTRfu_X"
|
||||
},
|
||||
"source": [
|
||||
"# Step 4: Define and Print TensorFlow Variables\n",
|
||||
"\n",
|
||||
"To define TensorFlow variables, use the `get_variable()` function as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "3Twf_Rw-gQFM"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x = tf.get_variable(name=\"x\", shape=[], dtype=tf.float32, initializer=tf.zeros_initializer)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "45G7094TxsMb"
|
||||
},
|
||||
"source": [
|
||||
"## Printing TensorFlow Variables"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "UJBJeZ5XxuwA"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This does NOT print the Variable's actual value:\n",
|
||||
"print(\"Printing a TensorFlow Variable:\")\n",
|
||||
"print(x)\n",
|
||||
"print(\"\")\n",
|
||||
"\n",
|
||||
"# A TensorFlow variable represents a reference to a tensor.\n",
|
||||
"# The `read_value()` method provides access to the current value of the\n",
|
||||
"# variable. Tensorflow Variables are automatically initialized according to the\n",
|
||||
"# semantics defined in tf.get_variable().\n",
|
||||
"print(\"Printing a TensorFlow Variable's value using .read_value():\")\n",
|
||||
"print(x.read_value())\n",
|
||||
"print(\"\")\n",
|
||||
"\n",
|
||||
"print(\"Printing a TensorFlow Variable's value using .read_value().numpy():\")\n",
|
||||
"print(x.read_value().numpy())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "2njjWHcTpBEn"
|
||||
},
|
||||
"source": [
|
||||
"## Changing a TensorFlow Variable's value\n",
|
||||
"\n",
|
||||
"To change a TensorFlow Variable's value, use its `.assign()` or `.assign_add()` method:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "v3wr6Erbo_hB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x.assign(42)\n",
|
||||
"print(x.read_value())\n",
|
||||
"\n",
|
||||
"x.assign_add(3)\n",
|
||||
"print(x.read_value())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "uhtynjHVpTB5"
|
||||
},
|
||||
"source": [
|
||||
"## Use a Variable just like any other Tensor"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "7PbktdnHoehR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(x + 3)\n",
|
||||
"\n",
|
||||
"# This code will broadcast the value across the list of numbers:\n",
|
||||
"print(x * [1, 2, 4])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "GVChqwlwy1SI"
|
||||
},
|
||||
"source": [
|
||||
"# Step 5: Debug Errors with Instant Feedback\n",
|
||||
"\n",
|
||||
"TensorFlow's eager execution helps you identify and debug runtime issues through interactive exploration of code snippets.\n",
|
||||
"\n",
|
||||
"Below, we'll define a length-4 vector, and attempt two `tf.slice()` operations,\n",
|
||||
"one being legal and the other being illegal, leading to a runtime error that is\n",
|
||||
"raised immediately."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "23ap04N0v4k0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vector = tf.constant([10.0, 20.0, 30.0, 40.0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "FCUMsIYxxRRa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Works, because the values of `begin` and `size` (the 2nd and 3rd input\n",
|
||||
"# arguments) are within the bound of `vector`.\n",
|
||||
"print(tf.slice(vector, [1], [3]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "T8me2oCNxpFp"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# The following does NOT work, because the value of `size` (the 3rd\n",
|
||||
"# argument) causes the indices to go out of the bounds of `vector`. The\n",
|
||||
"# error is raised immediately.\n",
|
||||
"try:\n",
|
||||
" print(tf.slice(vector, [1], [4]))\n",
|
||||
"except tf.OpError as e:\n",
|
||||
" print(\"Caught error: %s\" % e)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "irxJhAgar84v"
|
||||
},
|
||||
"source": [
|
||||
"# Step 6: Using the GPU\n",
|
||||
"\n",
|
||||
"You can place Tensors on the GPU by calling a Tensor's `.gpu()` method.\n",
|
||||
"\n",
|
||||
"The first operation executing on the GPU may be slow as TensorFlow initializes. Subsequent uses will be much faster."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "7J4N9baqaKCL"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# The example code from here on will work only if your notebook\n",
|
||||
"# is running on a machine with a functional CUDA GPU. The following\n",
|
||||
"# line checks that.\n",
|
||||
"is_gpu_available = tfe.num_gpus() \u003e 0\n",
|
||||
"\n",
|
||||
"# Create some Tensors\n",
|
||||
"SIZE = 1000\n",
|
||||
"cpu_tensor = tf.random_normal([SIZE, SIZE])\n",
|
||||
"\n",
|
||||
"if is_gpu_available:\n",
|
||||
" gpu_tensor = cpu_tensor.gpu()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "4E-2n7VbzY1n"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Time a CPU-based matrix multiplication\n",
|
||||
"\n",
|
||||
"print(\"Time to conduct matmul on CPU:\")\n",
|
||||
"%time tf.matmul(cpu_tensor, cpu_tensor)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "vbSFW-T5zhZF"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Time GPU-based matrix multiplications.\n",
|
||||
"\n",
|
||||
"if is_gpu_available:\n",
|
||||
" # First use of the GPU will be slow:\n",
|
||||
" print(\"Time to conduct first matmul on GPU:\")\n",
|
||||
" %time tf.matmul(gpu_tensor, gpu_tensor)\n",
|
||||
" print()\n",
|
||||
"\n",
|
||||
" # Subsequent uses are much faster:\n",
|
||||
" print(\"Time to conduct second matmul on GPU:\")\n",
|
||||
" %time tf.matmul(gpu_tensor, gpu_tensor)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "E5pIOe3Rz7iW"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Second timing demo for GPUs, after it has been used once:\n",
|
||||
"\n",
|
||||
"cpu_tensor = tf.random_normal([SIZE, SIZE])\n",
|
||||
"print(\"Time to conduct CPU matmul:\")\n",
|
||||
"%time tf.matmul(cpu_tensor, cpu_tensor)\n",
|
||||
"print()\n",
|
||||
"\n",
|
||||
"if is_gpu_available:\n",
|
||||
" gpu_tensor = cpu_tensor.gpu()\n",
|
||||
" print(\"Time to conduct GPU matmul:\")\n",
|
||||
" %time tf.matmul(gpu_tensor, gpu_tensor)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"default_view": {},
|
||||
"name": "Eager Execution Tutorial: Basics",
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "0B0kLcpwLFwKEVm9XNkFueGk4bTg",
|
||||
"timestamp": 1504118841551
|
||||
}
|
||||
],
|
||||
"version": "0.3.2",
|
||||
"views": {}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,218 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "U9i2Dsh-ziXr"
|
||||
},
|
||||
"source": [
|
||||
"# Eager Execution Tutorial: Importing Data\n",
|
||||
"\n",
|
||||
"This notebook demonstrates the use of the [`tf.contrib.data.Dataset` API](https://www.tensorflow.org/programmers_guide/datasets) to build pipelines to feed data to your program. It covers:\n",
|
||||
"\n",
|
||||
"* Creating a `Dataset`.\n",
|
||||
"* Iteration over a `Dataset` with eager execution enabled.\n",
|
||||
"\n",
|
||||
"We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n",
|
||||
"\n",
|
||||
"If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly different. You will use a Pythonic `Iterator()` class instead of using `make_one_shot_iterator()` and `get_next()`. As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "z1JcS5iBXMRO"
|
||||
},
|
||||
"source": [
|
||||
"# Setup: Enable eager execution\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "RlIWhyeLoYnG"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Import TensorFlow.\n",
|
||||
"import tensorflow as tf\n",
|
||||
"\n",
|
||||
"# Import TensorFlow eager execution support (subject to future changes).\n",
|
||||
"import tensorflow.contrib.eager as tfe\n",
|
||||
"\n",
|
||||
"# Enable eager execution\n",
|
||||
"tfe.enable_eager_execution()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "H9UySOPLXdaw"
|
||||
},
|
||||
"source": [
|
||||
"# Step 1: Create a source `Dataset`\n",
|
||||
"\n",
|
||||
"Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/TFRecordDataset). See the [Programmer's Guide](https://www.google.com/url?sa=D\u0026q=https%3A%2F%2Fwww.tensorflow.org%2Fprogrammers_guide%2Fdatasets%23reading_input_data) for more information."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "WPTUfGq6kJ5w"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds_tensors = tf.contrib.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n",
|
||||
"\n",
|
||||
"# Create a CSV file\n",
|
||||
"import tempfile\n",
|
||||
"_, filename = tempfile.mkstemp()\n",
|
||||
"with open(filename, 'w') as f:\n",
|
||||
" f.write(\"\"\"Line 1\n",
|
||||
"Line 2\n",
|
||||
"Line 3\n",
|
||||
" \"\"\")\n",
|
||||
"ds_file = tf.contrib.data.TextLineDataset(filename)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "twBfWd5xyu_d"
|
||||
},
|
||||
"source": [
|
||||
"# Step 2: Apply transformations\n",
|
||||
"\n",
|
||||
"Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.contrib.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset) for details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"cellView": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "ngUe237Wt48W"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n",
|
||||
"ds_file = ds_file.batch(2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "IDY4WsYRhP81"
|
||||
},
|
||||
"source": [
|
||||
"# Step 3: Iterate\n",
|
||||
"\n",
|
||||
"Use `tfe.Iterator` on the `Dataset` object to get a Python iterator over the contents of the dataset.\n",
|
||||
"\n",
|
||||
"If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that this process of iteration is different. Here there are no calls to `Dataset.make_one_shot_iterator()` and no `get_next()` calls."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
},
|
||||
"height": 153,
|
||||
"output_extras": [
|
||||
{
|
||||
"item_id": 1
|
||||
}
|
||||
]
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 201,
|
||||
"status": "ok",
|
||||
"timestamp": 1505952405928,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": 420
|
||||
},
|
||||
"id": "lCUWzso6mbqR",
|
||||
"outputId": "ec027d30-96c6-4ea4-9ee1-ef74ec1ae29a"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Elements of ds_tensors:\n",
|
||||
"tf.Tensor([4 9], shape=(2,), dtype=int32)\n",
|
||||
"tf.Tensor([16 25], shape=(2,), dtype=int32)\n",
|
||||
"tf.Tensor([36 1], shape=(2,), dtype=int32)\n",
|
||||
"\n",
|
||||
"Elements in ds_file:\n",
|
||||
"tf.Tensor(['Line 1' 'Line 2'], shape=(2,), dtype=string)\n",
|
||||
"tf.Tensor(['Line 3' ' '], shape=(2,), dtype=string)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print('Elements of ds_tensors:')\n",
|
||||
"for x in tfe.Iterator(ds_tensors):\n",
|
||||
" print(x)\n",
|
||||
"\n",
|
||||
"print('\\nElements in ds_file:')\n",
|
||||
"for x in tfe.Iterator(ds_file):\n",
|
||||
" print(x)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"default_view": {},
|
||||
"last_runtime": {
|
||||
"build_target": "",
|
||||
"kind": "local"
|
||||
},
|
||||
"name": "Eager Execution Tutorial: Importing Data",
|
||||
"provenance": [],
|
||||
"version": "0.3.2",
|
||||
"views": {}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
43
tensorflow/contrib/eager/python/examples/resnet50/BUILD
Normal file
43
tensorflow/contrib/eager/python/examples/resnet50/BUILD
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_library(
|
||||
name = "resnet50",
|
||||
srcs = ["resnet50.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "resnet50_test",
|
||||
size = "large",
|
||||
srcs = ["resnet50_test.py"],
|
||||
additional_deps = [
|
||||
":resnet50",
|
||||
"//tensorflow/contrib/summary:summary_test_util",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "resnet50_graph_test",
|
||||
size = "large",
|
||||
srcs = ["resnet50_graph_test.py"],
|
||||
additional_deps = [
|
||||
":resnet50",
|
||||
"//tensorflow/contrib/summary:summary_test_util",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
tags = [
|
||||
"noasan",
|
||||
"nomsan",
|
||||
],
|
||||
)
|
||||
34
tensorflow/contrib/eager/python/examples/resnet50/README.md
Normal file
34
tensorflow/contrib/eager/python/examples/resnet50/README.md
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
Image classification using the ResNet50 model described in
|
||||
[Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).
|
||||
|
||||
Contents:
|
||||
|
||||
- `resnet50.py`: Model definition
|
||||
- `resnet50_test.py`: Sanity unittests and benchmarks for using the model with
|
||||
eager execution enabled.
|
||||
- `resnet50_graph_test.py`: Sanity unittests and benchmarks when using the same
|
||||
model code to construct a TensorFlow graph.
|
||||
|
||||
# Benchmarks
|
||||
|
||||
Using a synthetic data.
|
||||
|
||||
```
|
||||
# Using eager execution
|
||||
bazel run -c opt --config=cuda :resnet50_test -- --benchmarks=.
|
||||
|
||||
# Using graph execution
|
||||
bazel run -c opt --config=cuda :resnet50_graph_test -- --benchmarks=.
|
||||
```
|
||||
|
||||
(Or remove the `--config=cuda` flag for running on CPU instead of GPU).
|
||||
|
||||
On October 31, 2017, the benchmarks demostrated comparable performance
|
||||
for eager and graph execution of this particular model when using
|
||||
a single NVIDIA Titan X (Pascal) GPU on a host with an
|
||||
Intel Xeon E5-1650 CPU @ 3.50GHz and a batch size of 32.
|
||||
|
||||
| Benchmark name | batch size | images/second |
|
||||
| --------------------------------------- | ------------- | ------------- |
|
||||
| eager_train_gpu_batch_32_channels_first | 32 | 171 |
|
||||
| graph_train_gpu_batch_32_channels_first | 32 | 172 |
|
||||
324
tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
Normal file
324
tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""ResNet50 model definition compatible with TensorFlow's eager execution.
|
||||
|
||||
Reference [Deep Residual Learning for Image
|
||||
Recognition](https://arxiv.org/abs/1512.03385)
|
||||
|
||||
Adapted from tf.keras.applications.ResNet50. A notable difference is that the
|
||||
model here outputs logits while the Keras model outputs probability.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.eager as tfe
|
||||
|
||||
|
||||
class _IdentityBlock(tfe.Network):
|
||||
"""_IdentityBlock is the block that has no conv layer at shortcut.
|
||||
|
||||
Args:
|
||||
kernel_size: the kernel size of middle conv layer at main path
|
||||
filters: list of integers, the filters of 3 conv layer at main path
|
||||
stage: integer, current stage label, used for generating layer names
|
||||
block: 'a','b'..., current block label, used for generating layer names
|
||||
data_format: data_format for the input ('channels_first' or
|
||||
'channels_last').
|
||||
"""
|
||||
|
||||
def __init__(self, kernel_size, filters, stage, block, data_format):
|
||||
super(_IdentityBlock, self).__init__(name='')
|
||||
filters1, filters2, filters3 = filters
|
||||
|
||||
conv_name_base = 'res' + str(stage) + block + '_branch'
|
||||
bn_name_base = 'bn' + str(stage) + block + '_branch'
|
||||
bn_axis = 1 if data_format == 'channels_first' else 3
|
||||
|
||||
self.conv2a = self.track_layer(
|
||||
tf.layers.Conv2D(
|
||||
filters1, (1, 1),
|
||||
name=conv_name_base + '2a',
|
||||
data_format=data_format))
|
||||
self.bn2a = self.track_layer(
|
||||
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a'))
|
||||
|
||||
self.conv2b = self.track_layer(
|
||||
tf.layers.Conv2D(
|
||||
filters2,
|
||||
kernel_size,
|
||||
padding='same',
|
||||
data_format=data_format,
|
||||
name=conv_name_base + '2b'))
|
||||
self.bn2b = self.track_layer(
|
||||
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b'))
|
||||
|
||||
self.conv2c = self.track_layer(
|
||||
tf.layers.Conv2D(
|
||||
filters3, (1, 1),
|
||||
name=conv_name_base + '2c',
|
||||
data_format=data_format))
|
||||
self.bn2c = self.track_layer(
|
||||
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c'))
|
||||
|
||||
def call(self, input_tensor, training=False):
|
||||
x = self.conv2a(input_tensor)
|
||||
x = self.bn2a(x, training=training)
|
||||
x = tf.nn.relu(x)
|
||||
|
||||
x = self.conv2b(x)
|
||||
x = self.bn2b(x, training=training)
|
||||
x = tf.nn.relu(x)
|
||||
|
||||
x = self.conv2c(x)
|
||||
x = self.bn2c(x, training=training)
|
||||
|
||||
x += input_tensor
|
||||
return tf.nn.relu(x)
|
||||
|
||||
|
||||
class _ConvBlock(tfe.Network):
|
||||
"""_ConvBlock is the block that has a conv layer at shortcut.
|
||||
|
||||
Args:
|
||||
kernel_size: the kernel size of middle conv layer at main path
|
||||
filters: list of integers, the filterss of 3 conv layer at main path
|
||||
stage: integer, current stage label, used for generating layer names
|
||||
block: 'a','b'..., current block label, used for generating layer names
|
||||
data_format: data_format for the input ('channels_first' or
|
||||
'channels_last').
|
||||
strides: strides for the convolution. Note that from stage 3, the first
|
||||
conv layer at main path is with strides=(2,2), and the shortcut should
|
||||
have strides=(2,2) as well.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
filters,
|
||||
stage,
|
||||
block,
|
||||
data_format,
|
||||
strides=(2, 2)):
|
||||
super(_ConvBlock, self).__init__(name='')
|
||||
filters1, filters2, filters3 = filters
|
||||
|
||||
conv_name_base = 'res' + str(stage) + block + '_branch'
|
||||
bn_name_base = 'bn' + str(stage) + block + '_branch'
|
||||
bn_axis = 1 if data_format == 'channels_first' else 3
|
||||
|
||||
self.conv2a = self.track_layer(
|
||||
tf.layers.Conv2D(
|
||||
filters1, (1, 1),
|
||||
strides=strides,
|
||||
name=conv_name_base + '2a',
|
||||
data_format=data_format))
|
||||
self.bn2a = self.track_layer(
|
||||
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a'))
|
||||
|
||||
self.conv2b = self.track_layer(
|
||||
tf.layers.Conv2D(
|
||||
filters2,
|
||||
kernel_size,
|
||||
padding='same',
|
||||
name=conv_name_base + '2b',
|
||||
data_format=data_format))
|
||||
self.bn2b = self.track_layer(
|
||||
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b'))
|
||||
|
||||
self.conv2c = self.track_layer(
|
||||
tf.layers.Conv2D(
|
||||
filters3, (1, 1),
|
||||
name=conv_name_base + '2c',
|
||||
data_format=data_format))
|
||||
self.bn2c = self.track_layer(
|
||||
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c'))
|
||||
|
||||
self.conv_shortcut = self.track_layer(
|
||||
tf.layers.Conv2D(
|
||||
filters3, (1, 1),
|
||||
strides=strides,
|
||||
name=conv_name_base + '1',
|
||||
data_format=data_format))
|
||||
self.bn_shortcut = self.track_layer(
|
||||
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '1'))
|
||||
|
||||
def call(self, input_tensor, training=False):
|
||||
x = self.conv2a(input_tensor)
|
||||
x = self.bn2a(x, training=training)
|
||||
x = tf.nn.relu(x)
|
||||
|
||||
x = self.conv2b(x)
|
||||
x = self.bn2b(x, training=training)
|
||||
x = tf.nn.relu(x)
|
||||
|
||||
x = self.conv2c(x)
|
||||
x = self.bn2c(x, training=training)
|
||||
|
||||
shortcut = self.conv_shortcut(input_tensor)
|
||||
shortcut = self.bn_shortcut(shortcut, training=training)
|
||||
|
||||
x += shortcut
|
||||
return tf.nn.relu(x)
|
||||
|
||||
|
||||
class ResNet50(tfe.Network):
|
||||
"""Instantiates the ResNet50 architecture.
|
||||
|
||||
Args:
|
||||
data_format: format for the image. Either 'channels_first' or
|
||||
'channels_last'. 'channels_first' is typically faster on GPUs while
|
||||
'channels_last' is typically faster on CPUs. See
|
||||
https://www.tensorflow.org/performance/performance_guide#data_formats
|
||||
name: Prefix applied to names of variables created in the model.
|
||||
trainable: Is the model trainable? If true, performs backward
|
||||
and optimization after call() method.
|
||||
include_top: whether to include the fully-connected layer at the top of the
|
||||
network.
|
||||
pooling: Optional pooling mode for feature extraction when `include_top`
|
||||
is `False`.
|
||||
- `None` means that the output of the model will be the 4D tensor
|
||||
output of the last convolutional layer.
|
||||
- `avg` means that global average pooling will be applied to the output of
|
||||
the last convolutional layer, and thus the output of the model will be
|
||||
a 2D tensor.
|
||||
- `max` means that global max pooling will be applied.
|
||||
classes: optional number of classes to classify images into, only to be
|
||||
specified if `include_top` is True.
|
||||
|
||||
Raises:
|
||||
ValueError: in case of invalid argument for data_format.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data_format,
|
||||
name=None,
|
||||
trainable=True,
|
||||
include_top=True,
|
||||
pooling=None,
|
||||
classes=1000):
|
||||
super(ResNet50, self).__init__(name='')
|
||||
|
||||
valid_channel_values = ('channels_first', 'channels_last')
|
||||
if data_format not in valid_channel_values:
|
||||
raise ValueError('Unknown data_format: %s. Valid values: %s' %
|
||||
(data_format, valid_channel_values))
|
||||
self.include_top = include_top
|
||||
|
||||
def conv_block(filters, stage, block, strides=(2, 2)):
|
||||
l = _ConvBlock(
|
||||
3,
|
||||
filters,
|
||||
stage=stage,
|
||||
block=block,
|
||||
data_format=data_format,
|
||||
strides=strides)
|
||||
return self.track_layer(l)
|
||||
|
||||
def id_block(filters, stage, block):
|
||||
l = _IdentityBlock(
|
||||
3, filters, stage=stage, block=block, data_format=data_format)
|
||||
return self.track_layer(l)
|
||||
|
||||
self.conv1 = self.track_layer(
|
||||
tf.layers.Conv2D(
|
||||
64, (7, 7),
|
||||
strides=(2, 2),
|
||||
data_format=data_format,
|
||||
padding='same',
|
||||
name='conv1'))
|
||||
bn_axis = 1 if data_format == 'channels_first' else 3
|
||||
self.bn_conv1 = self.track_layer(
|
||||
tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1'))
|
||||
self.max_pool = self.track_layer(
|
||||
tf.layers.MaxPooling2D((3, 3), strides=(2, 2), data_format=data_format))
|
||||
|
||||
self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1))
|
||||
self.l2b = id_block([64, 64, 256], stage=2, block='b')
|
||||
self.l2c = id_block([64, 64, 256], stage=2, block='c')
|
||||
|
||||
self.l3a = conv_block([128, 128, 512], stage=3, block='a')
|
||||
self.l3b = id_block([128, 128, 512], stage=3, block='b')
|
||||
self.l3c = id_block([128, 128, 512], stage=3, block='c')
|
||||
self.l3d = id_block([128, 128, 512], stage=3, block='d')
|
||||
|
||||
self.l4a = conv_block([256, 256, 1024], stage=4, block='a')
|
||||
self.l4b = id_block([256, 256, 1024], stage=4, block='b')
|
||||
self.l4c = id_block([256, 256, 1024], stage=4, block='c')
|
||||
self.l4d = id_block([256, 256, 1024], stage=4, block='d')
|
||||
self.l4e = id_block([256, 256, 1024], stage=4, block='e')
|
||||
self.l4f = id_block([256, 256, 1024], stage=4, block='f')
|
||||
|
||||
self.l5a = conv_block([512, 512, 2048], stage=5, block='a')
|
||||
self.l5b = id_block([512, 512, 2048], stage=5, block='b')
|
||||
self.l5c = id_block([512, 512, 2048], stage=5, block='c')
|
||||
|
||||
self.avg_pool = self.track_layer(
|
||||
tf.layers.AveragePooling2D(
|
||||
(7, 7), strides=(7, 7), data_format=data_format))
|
||||
|
||||
if self.include_top:
|
||||
self.fc1000 = self.track_layer(
|
||||
tf.layers.Dense(classes, name='fc1000'))
|
||||
else:
|
||||
reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3]
|
||||
reduction_indices = tf.constant(reduction_indices)
|
||||
if pooling == 'avg':
|
||||
self.global_pooling = functools.partial(
|
||||
tf.reduce_mean,
|
||||
reduction_indices=reduction_indices,
|
||||
keep_dims=False)
|
||||
elif pooling == 'max':
|
||||
self.global_pooling = functools.partial(
|
||||
tf.reduce_max, reduction_indices=reduction_indices, keep_dims=False)
|
||||
else:
|
||||
self.global_pooling = None
|
||||
|
||||
def call(self, input_tensor, training=False):
|
||||
x = self.conv1(input_tensor)
|
||||
x = self.bn_conv1(x, training=training)
|
||||
x = tf.nn.relu(x)
|
||||
x = self.max_pool(x)
|
||||
|
||||
x = self.l2a(x, training=training)
|
||||
x = self.l2b(x, training=training)
|
||||
x = self.l2c(x, training=training)
|
||||
|
||||
x = self.l3a(x, training=training)
|
||||
x = self.l3b(x, training=training)
|
||||
x = self.l3c(x, training=training)
|
||||
x = self.l3d(x, training=training)
|
||||
|
||||
x = self.l4a(x, training=training)
|
||||
x = self.l4b(x, training=training)
|
||||
x = self.l4c(x, training=training)
|
||||
x = self.l4d(x, training=training)
|
||||
x = self.l4e(x, training=training)
|
||||
x = self.l4f(x, training=training)
|
||||
|
||||
x = self.l5a(x, training=training)
|
||||
x = self.l5b(x, training=training)
|
||||
x = self.l5c(x, training=training)
|
||||
|
||||
x = self.avg_pool(x)
|
||||
|
||||
if self.include_top:
|
||||
return self.fc1000(tf.layers.flatten(x))
|
||||
elif self.global_pooling:
|
||||
return self.global_pooling(x)
|
||||
else:
|
||||
return x
|
||||
|
|
@ -0,0 +1,163 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests and benchmarks for ResNet50 under graph execution."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.eager.python.examples.resnet50 import resnet50
|
||||
from tensorflow.contrib.summary import summary_test_util
|
||||
|
||||
|
||||
def data_format():
|
||||
return 'channels_first' if tf.test.is_gpu_available() else 'channels_last'
|
||||
|
||||
|
||||
def image_shape(batch_size):
|
||||
if data_format() == 'channels_first':
|
||||
return [batch_size, 3, 224, 224]
|
||||
return [batch_size, 224, 224, 3]
|
||||
|
||||
|
||||
def random_batch(batch_size):
|
||||
images = np.random.rand(*image_shape(batch_size)).astype(np.float32)
|
||||
num_classes = 1000
|
||||
labels = np.random.randint(
|
||||
low=0, high=num_classes, size=[batch_size]).astype(np.int32)
|
||||
one_hot = np.zeros((batch_size, num_classes)).astype(np.float32)
|
||||
one_hot[np.arange(batch_size), labels] = 1.
|
||||
return images, one_hot
|
||||
|
||||
|
||||
class ResNet50GraphTest(tf.test.TestCase):
|
||||
|
||||
def testApply(self):
|
||||
batch_size = 64
|
||||
with tf.Graph().as_default():
|
||||
images = tf.placeholder(tf.float32, image_shape(None))
|
||||
model = resnet50.ResNet50(data_format())
|
||||
predictions = model(images)
|
||||
|
||||
init = tf.global_variables_initializer()
|
||||
|
||||
with tf.Session() as sess:
|
||||
sess.run(init)
|
||||
np_images, _ = random_batch(batch_size)
|
||||
out = sess.run(predictions, feed_dict={images: np_images})
|
||||
self.assertAllEqual([64, 1000], out.shape)
|
||||
|
||||
def testTrainWithSummary(self):
|
||||
with tf.Graph().as_default():
|
||||
images = tf.placeholder(tf.float32, image_shape(None), name='images')
|
||||
labels = tf.placeholder(tf.float32, [None, 1000], name='labels')
|
||||
|
||||
tf.train.get_or_create_global_step()
|
||||
logdir = tempfile.mkdtemp()
|
||||
with tf.contrib.summary.always_record_summaries():
|
||||
with tf.contrib.summary.create_summary_file_writer(
|
||||
logdir, max_queue=0,
|
||||
name='t0').as_default():
|
||||
model = resnet50.ResNet50(data_format())
|
||||
logits = model(images, training=True)
|
||||
loss = tf.losses.softmax_cross_entropy(
|
||||
logits=logits, onehot_labels=labels)
|
||||
tf.contrib.summary.scalar(name='loss', tensor=loss)
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
|
||||
train_op = optimizer.minimize(loss)
|
||||
|
||||
init = tf.global_variables_initializer()
|
||||
self.assertEqual(321, len(tf.global_variables()))
|
||||
|
||||
batch_size = 32
|
||||
with tf.Session() as sess:
|
||||
sess.run(init)
|
||||
sess.run(tf.contrib.summary.summary_writer_initializer_op())
|
||||
np_images, np_labels = random_batch(batch_size)
|
||||
sess.run([train_op, tf.contrib.summary.all_summary_ops()],
|
||||
feed_dict={images: np_images, labels: np_labels})
|
||||
|
||||
events = summary_test_util.events_from_file(logdir)
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].tag, 'loss')
|
||||
|
||||
|
||||
class ResNet50Benchmarks(tf.test.Benchmark):
|
||||
|
||||
def _report(self, label, start, num_iters, batch_size):
|
||||
avg_time = (time.time() - start) / num_iters
|
||||
dev = 'gpu' if tf.test.is_gpu_available() else 'cpu'
|
||||
name = 'graph_%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format())
|
||||
extras = {'examples_per_sec': batch_size / avg_time}
|
||||
self.report_benchmark(
|
||||
iters=num_iters, wall_time=avg_time, name=name, extras=extras)
|
||||
|
||||
def benchmark_graph_apply(self):
|
||||
with tf.Graph().as_default():
|
||||
images = tf.placeholder(tf.float32, image_shape(None))
|
||||
model = resnet50.ResNet50(data_format())
|
||||
predictions = model(images)
|
||||
|
||||
init = tf.global_variables_initializer()
|
||||
|
||||
batch_size = 64
|
||||
with tf.Session() as sess:
|
||||
sess.run(init)
|
||||
np_images, _ = random_batch(batch_size)
|
||||
num_burn, num_iters = (3, 30)
|
||||
for _ in range(num_burn):
|
||||
sess.run(predictions, feed_dict={images: np_images})
|
||||
start = time.time()
|
||||
for _ in range(num_iters):
|
||||
# Comparison with the eager execution benchmark in resnet50_test.py
|
||||
# isn't entirely fair as the time here includes the cost of copying
|
||||
# the feeds from CPU memory to GPU.
|
||||
sess.run(predictions, feed_dict={images: np_images})
|
||||
self._report('apply', start, num_iters, batch_size)
|
||||
|
||||
def benchmark_graph_train(self):
|
||||
for batch_size in [16, 32, 64]:
|
||||
with tf.Graph().as_default():
|
||||
np_images, np_labels = random_batch(batch_size)
|
||||
dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat()
|
||||
(images, labels) = dataset.make_one_shot_iterator().get_next()
|
||||
|
||||
model = resnet50.ResNet50(data_format())
|
||||
logits = model(images, training=True)
|
||||
loss = tf.losses.softmax_cross_entropy(
|
||||
logits=logits, onehot_labels=labels)
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
train_op = optimizer.minimize(loss)
|
||||
|
||||
init = tf.global_variables_initializer()
|
||||
with tf.Session() as sess:
|
||||
sess.run(init)
|
||||
(num_burn, num_iters) = (5, 10)
|
||||
for _ in range(num_burn):
|
||||
sess.run(train_op)
|
||||
start = time.time()
|
||||
for _ in range(num_iters):
|
||||
sess.run(train_op)
|
||||
self._report('train', start, num_iters, batch_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
@ -0,0 +1,234 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests and benchmarks for the ResNet50 model, executed eagerly."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import tensorflow.contrib.eager as tfe
|
||||
from tensorflow.contrib.eager.python.examples.resnet50 import resnet50
|
||||
from tensorflow.contrib.summary import summary_test_util
|
||||
from tensorflow.python.client import device_lib
|
||||
|
||||
|
||||
def device_and_data_format():
|
||||
return ('/gpu:0', 'channels_first') if tfe.num_gpus() else ('/cpu:0',
|
||||
'channels_last')
|
||||
|
||||
|
||||
def random_batch(batch_size):
|
||||
_, data_format = device_and_data_format()
|
||||
|
||||
shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3)
|
||||
shape = (batch_size,) + shape
|
||||
|
||||
num_classes = 1000
|
||||
images = tf.random_uniform(shape)
|
||||
labels = tf.random_uniform(
|
||||
[batch_size], minval=0, maxval=num_classes, dtype=tf.int32)
|
||||
one_hot = tf.one_hot(labels, num_classes)
|
||||
|
||||
return images, one_hot
|
||||
|
||||
|
||||
def train_one_step(model, images, labels, optimizer):
|
||||
|
||||
def model_loss():
|
||||
logits = model(images, training=True)
|
||||
loss = tf.losses.softmax_cross_entropy(
|
||||
logits=logits, onehot_labels=labels)
|
||||
tf.contrib.summary.scalar(name='loss', tensor=loss)
|
||||
return loss
|
||||
|
||||
optimizer.minimize(model_loss)
|
||||
|
||||
|
||||
class ResNet50Test(tf.test.TestCase):
|
||||
|
||||
def test_apply(self):
|
||||
device, data_format = device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format)
|
||||
with tf.device(device):
|
||||
images, _ = random_batch(2)
|
||||
output = model(images)
|
||||
self.assertEqual((2, 1000), output.shape)
|
||||
|
||||
def test_apply_no_top(self):
|
||||
device, data_format = device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format, include_top=False)
|
||||
with tf.device(device):
|
||||
images, _ = random_batch(2)
|
||||
output = model(images)
|
||||
output_shape = ((2, 2048, 1, 1)
|
||||
if data_format == 'channels_first' else (2, 1, 1, 2048))
|
||||
self.assertEqual(output_shape, output.shape)
|
||||
|
||||
def test_apply_with_pooling(self):
|
||||
device, data_format = device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
|
||||
with tf.device(device):
|
||||
images, _ = random_batch(2)
|
||||
output = model(images)
|
||||
self.assertEqual((2, 2048), output.shape)
|
||||
|
||||
def test_train(self):
|
||||
device, data_format = device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format)
|
||||
tf.train.get_or_create_global_step()
|
||||
logdir = tempfile.mkdtemp()
|
||||
with tf.contrib.summary.create_summary_file_writer(
|
||||
logdir, max_queue=0,
|
||||
name='t0').as_default(), tf.contrib.summary.always_record_summaries():
|
||||
with tf.device(device):
|
||||
optimizer = tf.train.GradientDescentOptimizer(0.1)
|
||||
images, labels = random_batch(2)
|
||||
train_one_step(model, images, labels, optimizer)
|
||||
self.assertEqual(320, len(model.variables))
|
||||
events = summary_test_util.events_from_file(logdir)
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[1].summary.value[0].tag, 'loss')
|
||||
|
||||
def test_no_garbage(self):
|
||||
device, data_format = device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format)
|
||||
optimizer = tf.train.GradientDescentOptimizer(0.1)
|
||||
with tf.device(device):
|
||||
images, labels = random_batch(2)
|
||||
gc.disable()
|
||||
# Warm up. Note that this first run does create significant amounts of
|
||||
# garbage to be collected. The hope is that this is a build-only effect,
|
||||
# and a subsequent training loop will create nothing which needs to be
|
||||
# collected.
|
||||
train_one_step(model, images, labels, optimizer)
|
||||
gc.collect()
|
||||
previous_gc_debug_flags = gc.get_debug()
|
||||
gc.set_debug(gc.DEBUG_SAVEALL)
|
||||
for _ in range(2):
|
||||
# Run twice to ensure that garbage that is created on the first
|
||||
# iteration is no longer accessible.
|
||||
train_one_step(model, images, labels, optimizer)
|
||||
gc.collect()
|
||||
# There should be no garbage requiring collection.
|
||||
self.assertEqual(0, len(gc.garbage))
|
||||
gc.set_debug(previous_gc_debug_flags)
|
||||
gc.enable()
|
||||
|
||||
|
||||
class MockIterator(object):
|
||||
|
||||
def __init__(self, tensors):
|
||||
self._tensors = [tf.identity(x) for x in tensors]
|
||||
|
||||
def next(self):
|
||||
return self._tensors
|
||||
|
||||
|
||||
class ResNet50Benchmarks(tf.test.Benchmark):
|
||||
|
||||
def _train_batch_sizes(self):
|
||||
"""Choose batch sizes based on GPU capability."""
|
||||
for device in device_lib.list_local_devices():
|
||||
if 'GPU:0' in device.name:
|
||||
# Avoid OOM errors with larger batch sizes, which seem to cause errors
|
||||
# later on even if caught.
|
||||
#
|
||||
# TODO(allenl): Base this on device memory; memory limit information
|
||||
# during the test seems to exclude the amount TensorFlow has allocated,
|
||||
# which isn't useful.
|
||||
if 'K20' in device.physical_device_desc:
|
||||
return (16,)
|
||||
if 'P100' in device.physical_device_desc:
|
||||
return (16, 32, 64)
|
||||
return (16, 32)
|
||||
|
||||
def _report(self, label, start, num_iters, device, batch_size, data_format):
|
||||
avg_time = (time.time() - start) / num_iters
|
||||
dev = 'cpu' if 'cpu' in device else 'gpu'
|
||||
name = '%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format)
|
||||
extras = {'examples_per_sec': batch_size / avg_time}
|
||||
self.report_benchmark(
|
||||
iters=num_iters, wall_time=avg_time, name=name, extras=extras)
|
||||
|
||||
def _force_gpu_sync(self):
|
||||
# If this function is called in the context of a GPU device
|
||||
# (e.g., inside a 'with tf.device("/gpu:0")' block)
|
||||
# then this will force a copy from CPU->GPU->CPU, which forces
|
||||
# a sync. This is a roundabout way, yes.
|
||||
tf.constant(1.).cpu()
|
||||
|
||||
def benchmark_eager_apply(self):
|
||||
device, data_format = device_and_data_format()
|
||||
model = resnet50.ResNet50(data_format)
|
||||
batch_size = 64
|
||||
num_burn = 5
|
||||
num_iters = 30
|
||||
with tf.device(device):
|
||||
images, _ = random_batch(batch_size)
|
||||
for _ in xrange(num_burn):
|
||||
model(images).cpu()
|
||||
gc.collect()
|
||||
start = time.time()
|
||||
for _ in xrange(num_iters):
|
||||
model(images).cpu()
|
||||
self._report('eager_apply', start, num_iters, device, batch_size,
|
||||
data_format)
|
||||
|
||||
def _benchmark_eager_train(self, label, make_iterator):
|
||||
device, data_format = device_and_data_format()
|
||||
for batch_size in self._train_batch_sizes():
|
||||
(images, labels) = random_batch(batch_size)
|
||||
num_burn = 3
|
||||
num_iters = 10
|
||||
model = resnet50.ResNet50(data_format)
|
||||
optimizer = tf.train.GradientDescentOptimizer(0.1)
|
||||
|
||||
with tf.device(device):
|
||||
iterator = make_iterator((images, labels))
|
||||
for _ in xrange(num_burn):
|
||||
(images, labels) = iterator.next()
|
||||
train_one_step(model, images, labels, optimizer)
|
||||
self._force_gpu_sync()
|
||||
gc.collect()
|
||||
|
||||
start = time.time()
|
||||
for _ in xrange(num_iters):
|
||||
(images, labels) = iterator.next()
|
||||
train_one_step(model, images, labels, optimizer)
|
||||
self._force_gpu_sync()
|
||||
self._report(label, start, num_iters, device, batch_size, data_format)
|
||||
|
||||
def benchmark_eager_train(self):
|
||||
self._benchmark_eager_train('eager_train', MockIterator)
|
||||
|
||||
def benchmark_eager_train_datasets(self):
|
||||
|
||||
def make_iterator(tensors):
|
||||
with tf.device('/device:CPU:0'):
|
||||
ds = tf.data.Dataset.from_tensors(tensors).repeat()
|
||||
return tfe.Iterator(ds)
|
||||
|
||||
self._benchmark_eager_train('eager_train_dataset', make_iterator)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tfe.enable_eager_execution()
|
||||
tf.test.main()
|
||||
26
tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD
Normal file
26
tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_binary(
|
||||
name = "rnn_colorbot",
|
||||
srcs = ["rnn_colorbot.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "rnn_colorbot_test",
|
||||
srcs = ["rnn_colorbot_test.py"],
|
||||
additional_deps = [
|
||||
":rnn_colorbot",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
RNN Colorbot: An RNN that predicts colors using eager execution.
|
||||
|
||||
To train and generate colors, run:
|
||||
|
||||
```
|
||||
python rnn_colorbot.py
|
||||
```
|
||||
|
||||
This example shows how to:
|
||||
1. read, process, (one-hot) encode, and pad text data via the
|
||||
Datasets API;
|
||||
2. build a trainable model;
|
||||
3. implement a multi-layer RNN using Python control flow
|
||||
constructs (e.g., a for loop);
|
||||
4. train a model using an iterative gradient-based method; and
|
||||
5. log training and evaluation loss for consumption by TensorBoard
|
||||
(to view summaries, use: tensorboard --log_dir=<dir>/summaries).
|
||||
|
||||
The data used in this example is licensed under the Creative Commons
|
||||
Attribution-ShareAlike License and is available at
|
||||
https://en.wikipedia.org/wiki/List_of_colors:_A-F
|
||||
https://en.wikipedia.org/wiki/List_of_colors:_G-M
|
||||
https://en.wikipedia.org/wiki/List_of_colors:_N-Z
|
||||
|
||||
This example was adapted from
|
||||
https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot
|
||||
|
|
@ -0,0 +1,338 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""TensorFlow Eager Execution Example: RNN Colorbot.
|
||||
|
||||
This example builds, trains, and evaluates a multi-layer RNN that can be
|
||||
run with eager execution enabled. The RNN is trained to map color names to
|
||||
their RGB values: it takes as input a one-hot encoded character sequence and
|
||||
outputs a three-tuple (R, G, B) (scaled by 1/255).
|
||||
|
||||
For example, say we'd like the RNN Colorbot to generate the RGB values for the
|
||||
color white. To represent our query in a form that the Colorbot could
|
||||
understand, we would create a sequence of five 256-long vectors encoding the
|
||||
ASCII values of the characters in "white". The first vector in our sequence
|
||||
would be 0 everywhere except for the ord("w")-th position, where it would be
|
||||
1, the second vector would be 0 everywhere except for the
|
||||
ord("h")-th position, where it would be 1, and similarly for the remaining three
|
||||
vectors. We refer to such indicator vectors as "one-hot encodings" of
|
||||
characters. After consuming these vectors, a well-trained Colorbot would output
|
||||
the three tuple (1, 1, 1), since the RGB values for white are (255, 255, 255).
|
||||
We are of course free to ask the colorbot to generate colors for any string we'd
|
||||
like, such as "steel gray," "tensorflow orange," or "green apple," though
|
||||
your mileage may vary as your queries increase in creativity.
|
||||
|
||||
This example shows how to:
|
||||
1. read, process, (one-hot) encode, and pad text data via the
|
||||
Datasets API;
|
||||
2. build a trainable model;
|
||||
3. implement a multi-layer RNN using Python control flow
|
||||
constructs (e.g., a for loop);
|
||||
4. train a model using an iterative gradient-based method; and
|
||||
|
||||
The data used in this example is licensed under the Creative Commons
|
||||
Attribution-ShareAlike License and is available at
|
||||
https://en.wikipedia.org/wiki/List_of_colors:_A-F
|
||||
https://en.wikipedia.org/wiki/List_of_colors:_G-M
|
||||
https://en.wikipedia.org/wiki/List_of_colors:_N-Z
|
||||
|
||||
This example was adapted from
|
||||
https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import six
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.eager.python import tfe
|
||||
from tensorflow.python.eager import context
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top
|
||||
HAS_MATPLOTLIB = True
|
||||
except ImportError:
|
||||
HAS_MATPLOTLIB = False
|
||||
|
||||
|
||||
def parse(line):
|
||||
"""Parse a line from the colors dataset."""
|
||||
|
||||
# Each line of the dataset is comma-separated and formatted as
|
||||
# color_name, r, g, b
|
||||
# so `items` is a list [color_name, r, g, b].
|
||||
items = tf.string_split([line], ",").values
|
||||
rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.
|
||||
# Represent the color name as a one-hot encoded character sequence.
|
||||
color_name = items[0]
|
||||
chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)
|
||||
# The sequence length is needed by our RNN.
|
||||
length = tf.cast(tf.shape(chars)[0], dtype=tf.int64)
|
||||
return rgb, chars, length
|
||||
|
||||
|
||||
def load_dataset(data_dir, url, batch_size):
|
||||
"""Loads the colors data at path into a PaddedDataset."""
|
||||
|
||||
# Downloads data at url into data_dir/basename(url). The dataset has a header
|
||||
# row (color_name, r, g, b) followed by comma-separated lines.
|
||||
path = tf.contrib.learn.datasets.base.maybe_download(
|
||||
os.path.basename(url), data_dir, url)
|
||||
|
||||
# This chain of commands loads our data by:
|
||||
# 1. skipping the header; (.skip(1))
|
||||
# 2. parsing the subsequent lines; (.map(parse))
|
||||
# 3. shuffling the data; (.shuffle(...))
|
||||
# 3. grouping the data into padded batches (.padded_batch(...)).
|
||||
dataset = tf.data.TextLineDataset(path).skip(1).map(parse).shuffle(
|
||||
buffer_size=10000).padded_batch(
|
||||
batch_size, padded_shapes=([None], [None, None], []))
|
||||
return dataset
|
||||
|
||||
|
||||
# pylint: disable=not-callable
|
||||
class RNNColorbot(tfe.Network):
|
||||
"""Multi-layer (LSTM) RNN that regresses on real-valued vector labels.
|
||||
"""
|
||||
|
||||
def __init__(self, rnn_cell_sizes, label_dimension, keep_prob):
|
||||
"""Constructs an RNNColorbot.
|
||||
|
||||
Args:
|
||||
rnn_cell_sizes: list of integers denoting the size of each LSTM cell in
|
||||
the RNN; rnn_cell_sizes[i] is the size of the i-th layer cell
|
||||
label_dimension: the length of the labels on which to regress
|
||||
keep_prob: (1 - dropout probability); dropout is applied to the outputs of
|
||||
each LSTM layer
|
||||
"""
|
||||
super(RNNColorbot, self).__init__(name="")
|
||||
self.label_dimension = label_dimension
|
||||
self.keep_prob = keep_prob
|
||||
|
||||
# Note the calls to `track_layer` below; these calls register the layers as
|
||||
# network components that house trainable variables.
|
||||
self.cells = [
|
||||
self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(size))
|
||||
for size in rnn_cell_sizes
|
||||
]
|
||||
self.relu = self.track_layer(
|
||||
tf.layers.Dense(label_dimension, activation=tf.nn.relu, name="relu"))
|
||||
|
||||
def call(self, chars, sequence_length, training=False):
|
||||
"""Implements the RNN logic and prediction generation.
|
||||
|
||||
Args:
|
||||
chars: a Tensor of dimension [batch_size, time_steps, 256] holding a
|
||||
batch of one-hot encoded color names
|
||||
sequence_length: a Tensor of dimension [batch_size] holding the length
|
||||
of each character sequence (i.e., color name)
|
||||
training: whether the invocation is happening during training
|
||||
|
||||
Returns:
|
||||
A tensor of dimension [batch_size, label_dimension] that is produced by
|
||||
passing chars through a multi-layer RNN and applying a ReLU to the final
|
||||
hidden state.
|
||||
"""
|
||||
# Transpose the first and second dimensions so that chars is of shape
|
||||
# [time_steps, batch_size, dimension].
|
||||
chars = tf.transpose(chars, [1, 0, 2])
|
||||
# The outer loop cycles through the layers of the RNN; the inner loop
|
||||
# executes the time steps for a particular layer.
|
||||
batch_size = int(chars.shape[1])
|
||||
for l in range(len(self.cells)):
|
||||
cell = self.cells[l]
|
||||
outputs = []
|
||||
state = cell.zero_state(batch_size, tf.float32)
|
||||
# Unstack the inputs to obtain a list of batches, one for each time step.
|
||||
chars = tf.unstack(chars, axis=0)
|
||||
for ch in chars:
|
||||
output, state = cell(ch, state)
|
||||
outputs.append(output)
|
||||
# The outputs of this layer are the inputs of the subsequent layer.
|
||||
chars = tf.stack(outputs, axis=0)
|
||||
if training:
|
||||
chars = tf.nn.dropout(chars, self.keep_prob)
|
||||
# Extract the correct output (i.e., hidden state) for each example. All the
|
||||
# character sequences in this batch were padded to the same fixed length so
|
||||
# that they could be easily fed through the above RNN loop. The
|
||||
# `sequence_length` vector tells us the true lengths of the character
|
||||
# sequences, letting us obtain for each sequence the hidden state that was
|
||||
# generated by its non-padding characters.
|
||||
batch_range = [i for i in range(batch_size)]
|
||||
indices = tf.stack([sequence_length - 1, batch_range], axis=1)
|
||||
hidden_states = tf.gather_nd(chars, indices)
|
||||
return self.relu(hidden_states)
|
||||
|
||||
|
||||
def loss(labels, predictions):
|
||||
"""Computes mean squared loss."""
|
||||
return tf.reduce_mean(tf.square(predictions - labels))
|
||||
|
||||
|
||||
def test(model, eval_data):
|
||||
"""Computes the average loss on eval_data, which should be a Dataset."""
|
||||
avg_loss = tfe.metrics.Mean("loss")
|
||||
for (labels, chars, sequence_length) in tfe.Iterator(eval_data):
|
||||
predictions = model(chars, sequence_length, training=False)
|
||||
avg_loss(loss(labels, predictions))
|
||||
print("eval/loss: %.6f\n" % avg_loss.result())
|
||||
with tf.contrib.summary.always_record_summaries():
|
||||
tf.contrib.summary.scalar("loss", avg_loss.result())
|
||||
|
||||
|
||||
def train_one_epoch(model, optimizer, train_data, log_interval=10):
|
||||
"""Trains model on train_data using optimizer."""
|
||||
|
||||
tf.train.get_or_create_global_step()
|
||||
|
||||
def model_loss(labels, chars, sequence_length):
|
||||
predictions = model(chars, sequence_length, training=True)
|
||||
loss_value = loss(labels, predictions)
|
||||
tf.contrib.summary.scalar("loss", loss_value)
|
||||
return loss_value
|
||||
|
||||
for (batch, (labels, chars, sequence_length)) in enumerate(
|
||||
tfe.Iterator(train_data)):
|
||||
with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval):
|
||||
batch_model_loss = functools.partial(model_loss, labels, chars,
|
||||
sequence_length)
|
||||
optimizer.minimize(
|
||||
batch_model_loss, global_step=tf.train.get_global_step())
|
||||
if log_interval and batch % log_interval == 0:
|
||||
print("train/batch #%d\tloss: %.6f" % (batch, batch_model_loss()))
|
||||
|
||||
|
||||
SOURCE_TRAIN_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv"
|
||||
SOURCE_TEST_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv"
|
||||
|
||||
|
||||
def main(_):
|
||||
data_dir = os.path.join(FLAGS.dir, "data")
|
||||
train_data = load_dataset(
|
||||
data_dir=data_dir, url=SOURCE_TRAIN_URL, batch_size=FLAGS.batch_size)
|
||||
eval_data = load_dataset(
|
||||
data_dir=data_dir, url=SOURCE_TEST_URL, batch_size=FLAGS.batch_size)
|
||||
|
||||
model = RNNColorbot(
|
||||
rnn_cell_sizes=FLAGS.rnn_cell_sizes,
|
||||
label_dimension=3,
|
||||
keep_prob=FLAGS.keep_probability)
|
||||
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
|
||||
|
||||
if FLAGS.no_gpu or tfe.num_gpus() <= 0:
|
||||
print(tfe.num_gpus())
|
||||
device = "/cpu:0"
|
||||
else:
|
||||
device = "/gpu:0"
|
||||
print("Using device %s." % device)
|
||||
|
||||
log_dir = os.path.join(FLAGS.dir, "summaries")
|
||||
tf.gfile.MakeDirs(log_dir)
|
||||
train_summary_writer = tf.contrib.summary.create_summary_file_writer(
|
||||
os.path.join(log_dir, "train"), flush_secs=10)
|
||||
test_summary_writer = tf.contrib.summary.create_summary_file_writer(
|
||||
os.path.join(log_dir, "eval"), flush_secs=10, name="eval")
|
||||
|
||||
with tf.device(device):
|
||||
for epoch in range(FLAGS.num_epochs):
|
||||
start = time.time()
|
||||
with train_summary_writer.as_default():
|
||||
train_one_epoch(model, optimizer, train_data, FLAGS.log_interval)
|
||||
end = time.time()
|
||||
print("train/time for epoch #%d: %.2f" % (epoch, end - start))
|
||||
with test_summary_writer.as_default():
|
||||
test(model, eval_data)
|
||||
|
||||
print("Colorbot is ready to generate colors!")
|
||||
while True:
|
||||
try:
|
||||
color_name = six.moves.input(
|
||||
"Give me a color name (or press enter to exit): ")
|
||||
except EOFError:
|
||||
return
|
||||
|
||||
if not color_name:
|
||||
return
|
||||
|
||||
_, chars, length = parse(color_name)
|
||||
with tf.device(device):
|
||||
(chars, length) = (tf.identity(chars), tf.identity(length))
|
||||
chars = tf.expand_dims(chars, 0)
|
||||
length = tf.expand_dims(length, 0)
|
||||
preds = tf.unstack(model(chars, length, training=False)[0])
|
||||
|
||||
# Predictions cannot be negative, as they are generated by a ReLU layer;
|
||||
# they may, however, be greater than 1.
|
||||
clipped_preds = tuple(min(float(p), 1.0) for p in preds)
|
||||
rgb = tuple(int(p * 255) for p in clipped_preds)
|
||||
print("rgb:", rgb)
|
||||
data = [[clipped_preds]]
|
||||
if HAS_MATPLOTLIB:
|
||||
plt.imshow(data)
|
||||
plt.title(color_name)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--dir",
|
||||
type=str,
|
||||
default="/tmp/rnn_colorbot/",
|
||||
help="Directory to download data files and save logs.")
|
||||
parser.add_argument(
|
||||
"--log_interval",
|
||||
type=int,
|
||||
default=10,
|
||||
metavar="N",
|
||||
help="Log training loss every log_interval batches.")
|
||||
parser.add_argument(
|
||||
"--num_epochs", type=int, default=20, help="Number of epochs to train.")
|
||||
parser.add_argument(
|
||||
"--rnn_cell_sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[256, 128],
|
||||
help="List of sizes for each layer of the RNN.")
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Batch size for training and eval.")
|
||||
parser.add_argument(
|
||||
"--keep_probability",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Keep probability for dropout between layers.")
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="Learning rate to be used during training.")
|
||||
parser.add_argument(
|
||||
"--no_gpu",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disables GPU usage even if a GPU is available.")
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tfe.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.eager.python import tfe
|
||||
from tensorflow.contrib.eager.python.examples.rnn_colorbot import rnn_colorbot
|
||||
|
||||
|
||||
LABEL_DIMENSION = 5
|
||||
|
||||
|
||||
def device():
|
||||
return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0"
|
||||
|
||||
|
||||
def random_dataset():
|
||||
batch_size = 64
|
||||
time_steps = 10
|
||||
alphabet = 50
|
||||
chars = tf.one_hot(
|
||||
tf.random_uniform(
|
||||
[batch_size, time_steps], minval=0, maxval=alphabet, dtype=tf.int32),
|
||||
alphabet)
|
||||
sequence_length = tf.constant(
|
||||
[time_steps for _ in range(batch_size)], dtype=tf.int64)
|
||||
labels = tf.random_normal([batch_size, LABEL_DIMENSION])
|
||||
return tf.data.Dataset.from_tensors((labels, chars, sequence_length))
|
||||
|
||||
|
||||
class RNNColorbotTest(tf.test.TestCase):
|
||||
|
||||
def testTrainOneEpoch(self):
|
||||
model = rnn_colorbot.RNNColorbot(
|
||||
rnn_cell_sizes=[256, 128, 64],
|
||||
label_dimension=LABEL_DIMENSION,
|
||||
keep_prob=1.0)
|
||||
optimizer = tf.train.AdamOptimizer(learning_rate=.01)
|
||||
dataset = random_dataset()
|
||||
with tf.device(device()):
|
||||
rnn_colorbot.train_one_epoch(model, optimizer, dataset)
|
||||
|
||||
def testTest(self):
|
||||
model = rnn_colorbot.RNNColorbot(
|
||||
rnn_cell_sizes=[256],
|
||||
label_dimension=LABEL_DIMENSION,
|
||||
keep_prob=1.0)
|
||||
dataset = random_dataset()
|
||||
with tf.device(device()):
|
||||
rnn_colorbot.test(model, dataset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tfe.enable_eager_execution()
|
||||
tf.test.main()
|
||||
35
tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD
Normal file
35
tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_binary(
|
||||
name = "rnn_ptb",
|
||||
srcs = ["rnn_ptb.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "rnn_ptb_test",
|
||||
srcs = ["rnn_ptb_test.py"],
|
||||
additional_deps = [
|
||||
":rnn_ptb",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "rnn_ptb_graph_test",
|
||||
srcs = ["rnn_ptb_graph_test.py"],
|
||||
additional_deps = [
|
||||
":rnn_ptb",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
42
tensorflow/contrib/eager/python/examples/rnn_ptb/README.md
Normal file
42
tensorflow/contrib/eager/python/examples/rnn_ptb/README.md
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
Recurrent Neural Network model.
|
||||
|
||||
Implements a language modeling network described in
|
||||
https://www.tensorflow.org/tutorials/recurrent
|
||||
that is compatible with (and idiomatic for) eager execution.
|
||||
|
||||
To run:
|
||||
|
||||
- Download and extract the Penn Treebank dataset from
|
||||
http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
|
||||
|
||||
```sh
|
||||
tar xvzf simple-examples.tgz -C /tmp
|
||||
```
|
||||
|
||||
- Run: `python rnn_ptb.py --data-dir=/tmp/simple-examples/data`
|
||||
|
||||
|
||||
Benchmarks (using synthetic data):
|
||||
|
||||
```
|
||||
# Using eager execution
|
||||
bazel run -c opt --config=cuda :rnn_ptb_test -- --benchmarks=.
|
||||
|
||||
# Using graph execution
|
||||
bazel run -c opt --config=cuda :rnn_ptb_graph_test -- --benchmarks=.
|
||||
```
|
||||
|
||||
(Or remove the `--config=cuda` flag for running on CPU instead of GPU).
|
||||
|
||||
On October 31, 2017, the benchmarks demostrated slightly better performance
|
||||
(3-6%) for graph execution over eager execution for this particular model when
|
||||
using a single NVIDIA Titan X (Pascal) GPU on a host with an Intel Xeon E5-1650
|
||||
CPU @ 3.50GHz and a batch size of 32.
|
||||
|
||||
| Benchmark name | examples/second |
|
||||
| ------------------------------------ | --------------- |
|
||||
| eager_cudnn_train_large_gpu_batch_20 | 938 |
|
||||
| graph_cudnn_train_large_gpu_batch_20 | 971 |
|
||||
| eager_cudnn_train_small_gpu_batch_20 | 2433 |
|
||||
| graph_cudnn_train_small_gpu_batch_20 | 2585 |
|
||||
|
||||
348
tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
Normal file
348
tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
Normal file
|
|
@ -0,0 +1,348 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Penn Treebank RNN model definition compatible with eager execution.
|
||||
|
||||
Model similar to
|
||||
https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb
|
||||
|
||||
Usage: python ./rnn_ptb.py --data-path=<path_to_dataset>
|
||||
|
||||
Penn Treebank (PTB) dataset from:
|
||||
http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
|
||||
from tensorflow.contrib.eager.python import tfe
|
||||
|
||||
|
||||
class RNN(tfe.Network):
|
||||
"""A static RNN.
|
||||
|
||||
Similar to tf.nn.static_rnn, implemented as a tf.layer.Layer.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim, num_layers, keep_ratio):
|
||||
super(RNN, self).__init__()
|
||||
self.keep_ratio = keep_ratio
|
||||
for _ in range(num_layers):
|
||||
self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim))
|
||||
|
||||
def call(self, input_seq, training):
|
||||
batch_size = int(input_seq.shape[1])
|
||||
for c in self.layers:
|
||||
state = c.zero_state(batch_size, tf.float32)
|
||||
outputs = []
|
||||
input_seq = tf.unstack(input_seq, num=int(input_seq.shape[0]), axis=0)
|
||||
for inp in input_seq:
|
||||
output, state = c(inp, state)
|
||||
outputs.append(output)
|
||||
|
||||
input_seq = tf.stack(outputs, axis=0)
|
||||
if training:
|
||||
input_seq = tf.nn.dropout(input_seq, self.keep_ratio)
|
||||
return input_seq, None
|
||||
|
||||
|
||||
class Embedding(tf.layers.Layer):
|
||||
"""An Embedding layer."""
|
||||
|
||||
def __init__(self, vocab_size, embedding_dim, **kwargs):
|
||||
super(Embedding, self).__init__(**kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def build(self, _):
|
||||
self.embedding = self.add_variable(
|
||||
"embedding_kernel",
|
||||
shape=[self.vocab_size, self.embedding_dim],
|
||||
dtype=tf.float32,
|
||||
initializer=tf.random_uniform_initializer(-0.1, 0.1),
|
||||
trainable=True)
|
||||
|
||||
def call(self, x):
|
||||
return tf.nn.embedding_lookup(self.embedding, x)
|
||||
|
||||
|
||||
class PTBModel(tfe.Network):
|
||||
"""LSTM for word language modelling.
|
||||
|
||||
Model described in:
|
||||
(Zaremba, et. al.) Recurrent Neural Network Regularization
|
||||
http://arxiv.org/abs/1409.2329
|
||||
|
||||
See also:
|
||||
https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
embedding_dim,
|
||||
hidden_dim,
|
||||
num_layers,
|
||||
dropout_ratio,
|
||||
use_cudnn_rnn=True):
|
||||
super(PTBModel, self).__init__()
|
||||
|
||||
self.keep_ratio = 1 - dropout_ratio
|
||||
self.use_cudnn_rnn = use_cudnn_rnn
|
||||
self.embedding = self.track_layer(Embedding(vocab_size, embedding_dim))
|
||||
|
||||
if self.use_cudnn_rnn:
|
||||
self.rnn = cudnn_rnn.CudnnLSTM(
|
||||
num_layers, hidden_dim, dropout=dropout_ratio)
|
||||
else:
|
||||
self.rnn = RNN(hidden_dim, num_layers, self.keep_ratio)
|
||||
self.track_layer(self.rnn)
|
||||
|
||||
self.linear = self.track_layer(
|
||||
tf.layers.Dense(
|
||||
vocab_size,
|
||||
kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1)))
|
||||
self._output_shape = [-1, embedding_dim]
|
||||
|
||||
def call(self, input_seq, training):
|
||||
"""Run the forward pass of PTBModel.
|
||||
|
||||
Args:
|
||||
input_seq: [length, batch] shape int64 tensor.
|
||||
training: Is this a training call.
|
||||
Returns:
|
||||
outputs tensors of inference.
|
||||
"""
|
||||
y = self.embedding(input_seq)
|
||||
if training:
|
||||
y = tf.nn.dropout(y, self.keep_ratio)
|
||||
y, _ = self.rnn(y, training=training)
|
||||
return self.linear(tf.reshape(y, self._output_shape))
|
||||
|
||||
|
||||
def clip_gradients(grads_and_vars, clip_ratio):
|
||||
gradients, variables = zip(*grads_and_vars)
|
||||
clipped, _ = tf.clip_by_global_norm(gradients, clip_ratio)
|
||||
return zip(clipped, variables)
|
||||
|
||||
|
||||
def loss_fn(model, inputs, targets, training):
|
||||
labels = tf.reshape(targets, [-1])
|
||||
outputs = model(inputs, training)
|
||||
return tf.reduce_mean(
|
||||
tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
labels=labels, logits=outputs))
|
||||
|
||||
|
||||
def _divide_into_batches(data, batch_size):
|
||||
"""Convert a sequence to a batch of sequences."""
|
||||
nbatch = data.shape[0] // batch_size
|
||||
data = data[:nbatch * batch_size]
|
||||
data = data.reshape(batch_size, -1).transpose()
|
||||
return data
|
||||
|
||||
|
||||
def _get_batch(data, i, seq_len):
|
||||
slen = min(seq_len, data.shape[0] - 1 - i)
|
||||
inputs = data[i:i + slen, :]
|
||||
target = data[i + 1:i + 1 + slen, :]
|
||||
return tf.constant(inputs), tf.constant(target)
|
||||
|
||||
|
||||
def evaluate(model, data):
|
||||
"""evaluate an epoch."""
|
||||
total_loss = 0.0
|
||||
total_batches = 0
|
||||
start = time.time()
|
||||
for _, i in enumerate(range(0, data.shape[0] - 1, FLAGS.seq_len)):
|
||||
inp, target = _get_batch(data, i, FLAGS.seq_len)
|
||||
loss = loss_fn(model, inp, target, training=False)
|
||||
total_loss += loss.numpy()
|
||||
total_batches += 1
|
||||
time_in_ms = (time.time() - start) * 1000
|
||||
sys.stderr.write("eval loss %.2f (eval took %d ms)\n" %
|
||||
(total_loss / total_batches, time_in_ms))
|
||||
return total_loss
|
||||
|
||||
|
||||
def train(model, optimizer, train_data, sequence_length, clip_ratio):
|
||||
"""training an epoch."""
|
||||
|
||||
def model_loss(inputs, targets):
|
||||
return loss_fn(model, inputs, targets, training=True)
|
||||
|
||||
grads = tfe.implicit_gradients(model_loss)
|
||||
|
||||
total_time = 0
|
||||
for batch, i in enumerate(range(0, train_data.shape[0] - 1, sequence_length)):
|
||||
train_seq, train_target = _get_batch(train_data, i, sequence_length)
|
||||
start = time.time()
|
||||
optimizer.apply_gradients(
|
||||
clip_gradients(grads(train_seq, train_target), clip_ratio))
|
||||
total_time += (time.time() - start)
|
||||
if batch % 10 == 0:
|
||||
time_in_ms = (total_time * 1000) / (batch + 1)
|
||||
sys.stderr.write("batch %d: training loss %.2f, avg step time %d ms\n" %
|
||||
(batch, model_loss(train_seq, train_target).numpy(),
|
||||
time_in_ms))
|
||||
|
||||
|
||||
class Datasets(object):
|
||||
"""Processed form of the Penn Treebank dataset."""
|
||||
|
||||
def __init__(self, path):
|
||||
"""Load the Penn Treebank dataset.
|
||||
|
||||
Args:
|
||||
path: Path to the data/ directory of the dataset from from Tomas Mikolov's
|
||||
webpage - http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
|
||||
"""
|
||||
|
||||
self.word2idx = {} # string -> integer id
|
||||
self.idx2word = [] # integer id -> word string
|
||||
# Files represented as a list of integer ids (as opposed to list of string
|
||||
# words).
|
||||
self.train = self.tokenize(os.path.join(path, "ptb.train.txt"))
|
||||
self.valid = self.tokenize(os.path.join(path, "ptb.valid.txt"))
|
||||
|
||||
def vocab_size(self):
|
||||
return len(self.idx2word)
|
||||
|
||||
def add(self, word):
|
||||
if word not in self.word2idx:
|
||||
self.idx2word.append(word)
|
||||
self.word2idx[word] = len(self.idx2word) - 1
|
||||
|
||||
def tokenize(self, path):
|
||||
"""Read text file in path and return a list of integer token ids."""
|
||||
tokens = 0
|
||||
with tf.gfile.Open(path, "r") as f:
|
||||
for line in f:
|
||||
words = line.split() + ["<eos>"]
|
||||
tokens += len(words)
|
||||
for word in words:
|
||||
self.add(word)
|
||||
|
||||
# Tokenize file content
|
||||
with tf.gfile.Open(path, "r") as f:
|
||||
ids = np.zeros(tokens).astype(np.int64)
|
||||
token = 0
|
||||
for line in f:
|
||||
words = line.split() + ["<eos>"]
|
||||
for word in words:
|
||||
ids[token] = self.word2idx[word]
|
||||
token += 1
|
||||
|
||||
return ids
|
||||
|
||||
|
||||
def small_model(use_cudnn_rnn):
|
||||
"""Returns a PTBModel with a 'small' configuration."""
|
||||
return PTBModel(
|
||||
vocab_size=10000,
|
||||
embedding_dim=200,
|
||||
hidden_dim=200,
|
||||
num_layers=2,
|
||||
dropout_ratio=0.,
|
||||
use_cudnn_rnn=use_cudnn_rnn)
|
||||
|
||||
|
||||
def large_model(use_cudnn_rnn):
|
||||
"""Returns a PTBModel with a 'large' configuration."""
|
||||
return PTBModel(
|
||||
vocab_size=10000,
|
||||
embedding_dim=650,
|
||||
hidden_dim=650,
|
||||
num_layers=2,
|
||||
dropout_ratio=0.5,
|
||||
use_cudnn_rnn=use_cudnn_rnn)
|
||||
|
||||
|
||||
def main(_):
|
||||
tfe.enable_eager_execution()
|
||||
|
||||
if not FLAGS.data_path:
|
||||
raise ValueError("Must specify --data-path")
|
||||
corpus = Datasets(FLAGS.data_path)
|
||||
train_data = _divide_into_batches(corpus.train, FLAGS.batch_size)
|
||||
eval_data = _divide_into_batches(corpus.valid, 10)
|
||||
|
||||
have_gpu = tfe.num_gpus() > 0
|
||||
use_cudnn_rnn = not FLAGS.no_use_cudnn_rnn and have_gpu
|
||||
|
||||
with tfe.restore_variables_on_create(
|
||||
tf.train.latest_checkpoint(FLAGS.logdir)):
|
||||
with tf.device("/device:GPU:0" if have_gpu else None):
|
||||
# Make learning_rate a Variable so it can be included in the checkpoint
|
||||
# and we can resume training with the last saved learning_rate.
|
||||
learning_rate = tfe.Variable(20.0, name="learning_rate")
|
||||
sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy())
|
||||
model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim,
|
||||
FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout,
|
||||
use_cudnn_rnn)
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
|
||||
|
||||
best_loss = None
|
||||
for _ in range(FLAGS.epoch):
|
||||
train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip)
|
||||
eval_loss = evaluate(model, eval_data)
|
||||
if not best_loss or eval_loss < best_loss:
|
||||
if FLAGS.logdir:
|
||||
tfe.Saver(model.trainable_weights + [learning_rate]).save(
|
||||
os.path.join(FLAGS.logdir, "ckpt"))
|
||||
best_loss = eval_loss
|
||||
else:
|
||||
learning_rate.assign(learning_rate / 4.0)
|
||||
sys.stderr.write("eval_loss did not reduce in this epoch, "
|
||||
"changing learning rate to %f for the next epoch\n" %
|
||||
learning_rate.numpy())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Data directory of the Penn Treebank dataset from "
|
||||
"http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz")
|
||||
parser.add_argument(
|
||||
"--logdir", type=str, default="", help="Directory for checkpoint.")
|
||||
parser.add_argument(
|
||||
"--epoch", type=int, default=20, help="Number of epoches.")
|
||||
parser.add_argument("--batch-size", type=int, default=20, help="Batch size.")
|
||||
parser.add_argument(
|
||||
"--seq-len", type=int, default=35, help="Sequence length.")
|
||||
parser.add_argument(
|
||||
"--embedding-dim", type=int, default=200, help="Embedding dimension.")
|
||||
parser.add_argument(
|
||||
"--hidden-dim", type=int, default=200, help="Hidden layer dimension.")
|
||||
parser.add_argument(
|
||||
"--num-layers", type=int, default=2, help="Number of RNN layers.")
|
||||
parser.add_argument(
|
||||
"--dropout", type=float, default=0.2, help="Drop out ratio.")
|
||||
parser.add_argument(
|
||||
"--clip", type=float, default=0.25, help="Gradient clipping ratio.")
|
||||
parser.add_argument(
|
||||
"--no-use-cudnn-rnn",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable the fast CuDNN RNN (when no gpu)")
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for PTBModel used for graph construction."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gc
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.eager.python.examples.rnn_ptb import rnn_ptb
|
||||
|
||||
|
||||
class PTBTest(tf.test.TestCase):
|
||||
|
||||
def testTrain(self):
|
||||
batch_size = 20
|
||||
sequence_length = 35
|
||||
with tf.Graph().as_default(), tf.device(tf.test.gpu_device_name()):
|
||||
inputs_ph = tf.placeholder(tf.int64, [sequence_length, batch_size],
|
||||
"inputs")
|
||||
labels_ph = tf.placeholder(tf.int64, [sequence_length, batch_size],
|
||||
"labels")
|
||||
|
||||
inputs = np.ones(inputs_ph.shape.as_list(), dtype=np.int64)
|
||||
labels = np.ones(labels_ph.shape.as_list(), dtype=np.int64)
|
||||
|
||||
model = rnn_ptb.small_model(tf.test.is_gpu_available())
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
loss = rnn_ptb.loss_fn(model, inputs_ph, labels_ph, training=True)
|
||||
grads = rnn_ptb.clip_gradients(optimizer.compute_gradients(loss), 0.25)
|
||||
train_op = optimizer.apply_gradients(grads)
|
||||
|
||||
with tf.Session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
sess.run(train_op, feed_dict={inputs_ph: inputs, labels_ph: labels})
|
||||
sess.run(
|
||||
[train_op, loss], feed_dict={
|
||||
inputs_ph: inputs,
|
||||
labels_ph: labels
|
||||
})
|
||||
|
||||
|
||||
class PTBBenchmark(tf.test.Benchmark):
|
||||
|
||||
BATCH_SIZE = 20
|
||||
SEQ_LEN = 35
|
||||
|
||||
def _report(self, label, start, num_iters, device, batch_size):
|
||||
wall_time = (time.time() - start) / num_iters
|
||||
dev = "cpu" if "cpu" in device.lower() else "gpu"
|
||||
name = "%s_%s_batch_%d" % (label, dev, batch_size)
|
||||
examples_per_sec = batch_size / wall_time
|
||||
self.report_benchmark(
|
||||
iters=num_iters,
|
||||
wall_time=wall_time,
|
||||
name=name,
|
||||
extras={
|
||||
"examples_per_sec": examples_per_sec
|
||||
})
|
||||
|
||||
def _benchmark_apply(self, label, model):
|
||||
num_iters = 100
|
||||
num_warmup = 10
|
||||
dataset = tf.data.Dataset.from_tensors(
|
||||
tf.ones(
|
||||
[PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE],
|
||||
dtype=tf.int64)).repeat(num_iters + num_warmup)
|
||||
inputs = dataset.make_one_shot_iterator().get_next()
|
||||
|
||||
with tf.device(tf.test.gpu_device_name()):
|
||||
outputs = model(inputs, training=True)
|
||||
|
||||
with tf.Session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
for _ in range(num_warmup):
|
||||
sess.run(outputs)
|
||||
gc.collect()
|
||||
|
||||
start = time.time()
|
||||
for _ in range(num_iters):
|
||||
sess.run(outputs)
|
||||
self._report(label, start, num_iters,
|
||||
tf.test.gpu_device_name(), PTBBenchmark.BATCH_SIZE)
|
||||
|
||||
def benchmark_apply_small(self):
|
||||
self._benchmark_apply("graph_apply_small", rnn_ptb.small_model(False))
|
||||
|
||||
def benchmark_apply_large(self):
|
||||
self._benchmark_apply("graph_apply_large", rnn_ptb.large_model(False))
|
||||
|
||||
def benchmark_cudnn_apply_small(self):
|
||||
if not tf.test.is_gpu_available():
|
||||
return
|
||||
self._benchmark_apply("graph_cudnn_apply_small", rnn_ptb.small_model(True))
|
||||
|
||||
def benchmark_cudnn_apply_large(self):
|
||||
if not tf.test.is_gpu_available():
|
||||
return
|
||||
self._benchmark_apply("graph_cudnn_apply_large", rnn_ptb.large_model(True))
|
||||
|
||||
def _benchmark_train(self, label, model):
|
||||
num_iters = 100
|
||||
num_warmup = 10
|
||||
dataset = tf.data.Dataset.from_tensors(
|
||||
tf.ones(
|
||||
[PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE],
|
||||
dtype=tf.int64)).repeat(num_iters + num_warmup)
|
||||
# inputs and labels have the same shape
|
||||
dataset = tf.data.Dataset.zip((dataset, dataset))
|
||||
(inputs, labels) = dataset.make_one_shot_iterator().get_next()
|
||||
|
||||
with tf.device(tf.test.gpu_device_name()):
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
loss = rnn_ptb.loss_fn(model, inputs, labels, training=True)
|
||||
grads = rnn_ptb.clip_gradients(optimizer.compute_gradients(loss), 0.25)
|
||||
train_op = optimizer.apply_gradients(grads)
|
||||
|
||||
with tf.Session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
for _ in range(num_warmup):
|
||||
sess.run(train_op)
|
||||
gc.collect()
|
||||
start = time.time()
|
||||
for _ in range(num_iters):
|
||||
sess.run(train_op)
|
||||
self._report(label, start, num_iters,
|
||||
tf.test.gpu_device_name(), PTBBenchmark.BATCH_SIZE)
|
||||
|
||||
def benchmark_train_small(self):
|
||||
self._benchmark_train("graph_train_small", rnn_ptb.small_model(False))
|
||||
|
||||
def benchmark_train_large(self):
|
||||
self._benchmark_train("graph_train_large", rnn_ptb.large_model(False))
|
||||
|
||||
def benchmark_cudnn_train_small(self):
|
||||
if not tf.test.is_gpu_available():
|
||||
return
|
||||
self._benchmark_train("graph_cudnn_train_small", rnn_ptb.small_model(True))
|
||||
|
||||
def benchmark_cudnn_train_large(self):
|
||||
if not tf.test.is_gpu_available():
|
||||
return
|
||||
self._benchmark_train("graph_cudnn_train_large", rnn_ptb.large_model(True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
154
tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_test.py
Normal file
154
tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_test.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for PTBModel with eager execution enabled."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gc
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.eager.python import tfe
|
||||
from tensorflow.contrib.eager.python.examples.rnn_ptb import rnn_ptb
|
||||
|
||||
|
||||
def device():
|
||||
return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0"
|
||||
|
||||
|
||||
class PTBTest(tf.test.TestCase):
|
||||
|
||||
def testTrain(self):
|
||||
model = rnn_ptb.small_model(tfe.num_gpus() > 0)
|
||||
sequence_length = 35
|
||||
data = np.ones([4 * sequence_length, 20], dtype=np.int64)
|
||||
with tf.device(device()):
|
||||
optimizer = tf.train.GradientDescentOptimizer(1.0)
|
||||
# Train two epochs
|
||||
rnn_ptb.train(model, optimizer, data, sequence_length, 0.25)
|
||||
rnn_ptb.train(model, optimizer, data, sequence_length, 0.25)
|
||||
|
||||
def testApply(self):
|
||||
model = rnn_ptb.small_model(tfe.num_gpus() > 0)
|
||||
with tf.device(device()):
|
||||
model(tf.ones([35, 20], dtype=tf.int64), training=False)
|
||||
|
||||
|
||||
def force_gpu_sync():
|
||||
if tfe.num_gpus():
|
||||
tf.constant(1).gpu().cpu()
|
||||
|
||||
|
||||
class PTBBenchmark(tf.test.Benchmark):
|
||||
|
||||
BATCH_SIZE = 20
|
||||
SEQ_LEN = 35
|
||||
|
||||
def _report(self, label, start, num_iters, dev, batch_size):
|
||||
wall_time = (time.time() - start) / num_iters
|
||||
dev = "cpu" if "cpu" in dev.lower() else "gpu"
|
||||
name = "%s_%s_batch_%d" % (label, dev, batch_size)
|
||||
examples_per_sec = batch_size / wall_time
|
||||
self.report_benchmark(
|
||||
iters=num_iters,
|
||||
wall_time=wall_time,
|
||||
name=name,
|
||||
extras={
|
||||
"examples_per_sec": examples_per_sec
|
||||
})
|
||||
|
||||
def _benchmark_apply(self, label, model):
|
||||
with tf.device(device()):
|
||||
sequence_batch = tf.ones(
|
||||
[PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], dtype=tf.int64)
|
||||
|
||||
for _ in range(10): # Warmup
|
||||
model(sequence_batch, training=False).cpu()
|
||||
gc.collect()
|
||||
|
||||
start = time.time()
|
||||
iters = 100
|
||||
for _ in range(iters):
|
||||
model(sequence_batch, training=False).cpu()
|
||||
self._report(label, start, iters, device(), int(sequence_batch.shape[1]))
|
||||
|
||||
def benchmark_apply_small(self):
|
||||
self._benchmark_apply("eager_apply_small", rnn_ptb.small_model(False))
|
||||
|
||||
def benchmark_apply_large(self):
|
||||
self._benchmark_apply("eager_apply_large", rnn_ptb.large_model(False))
|
||||
|
||||
def benchmark_cudnn_apply_small(self):
|
||||
if not tfe.num_gpus():
|
||||
return
|
||||
self._benchmark_apply("eager_cudnn_apply_small", rnn_ptb.small_model(True))
|
||||
|
||||
def benchmark_cudnn_apply_large(self):
|
||||
if not tfe.num_gpus():
|
||||
return
|
||||
self._benchmark_apply("eager_cudnn_apply_large", rnn_ptb.large_model(True))
|
||||
|
||||
def _benchmark_train(self, label, model):
|
||||
with tf.device(device()):
|
||||
optimizer = tf.train.GradientDescentOptimizer(1.)
|
||||
|
||||
def model_loss(inputs, targets):
|
||||
return rnn_ptb.loss_fn(model, inputs, targets, training=True)
|
||||
|
||||
grads = tfe.implicit_gradients(model_loss)
|
||||
|
||||
sequence_batch = tf.ones(
|
||||
[PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], dtype=tf.int64)
|
||||
|
||||
def step():
|
||||
optimizer.apply_gradients(
|
||||
rnn_ptb.clip_gradients(grads(sequence_batch, sequence_batch), 0.25))
|
||||
|
||||
for _ in range(10): # Warmup
|
||||
step()
|
||||
force_gpu_sync()
|
||||
gc.collect()
|
||||
|
||||
start = time.time()
|
||||
iters = 100
|
||||
for _ in range(iters):
|
||||
step()
|
||||
force_gpu_sync()
|
||||
self._report(label, start, iters, device(), int(sequence_batch.shape[1]))
|
||||
|
||||
def benchmark_train_small(self):
|
||||
self._benchmark_train("eager_train_small", rnn_ptb.small_model(False))
|
||||
|
||||
def benchmark_train_large(self):
|
||||
self._benchmark_train("eager_train_large", rnn_ptb.large_model(False))
|
||||
|
||||
def benchmark_cudnn_train_small(self):
|
||||
if not tfe.num_gpus():
|
||||
return
|
||||
self._benchmark_train("eager_cudnn_train_small", rnn_ptb.small_model(True))
|
||||
|
||||
def benchmark_cudnn_train_large(self):
|
||||
if not tfe.num_gpus():
|
||||
return
|
||||
self._benchmark_train("eager_cudnn_train_large", rnn_ptb.large_model(True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tfe.enable_eager_execution()
|
||||
tf.test.main()
|
||||
899
tensorflow/contrib/eager/python/g3doc/guide.md
Normal file
899
tensorflow/contrib/eager/python/g3doc/guide.md
Normal file
|
|
@ -0,0 +1,899 @@
|
|||
# TensorFlow Eager Execution
|
||||
|
||||
## What is this?
|
||||
|
||||
Eager execution is a feature that makes TensorFlow execute operations
|
||||
immediately: concrete values are returned, instead of a computational graph to
|
||||
be executed later.
|
||||
|
||||
As a result, enabling eager execution provides:
|
||||
|
||||
- A [NumPy](http://www.numpy.org/)-like library for numerical computation with
|
||||
support for GPU acceleration and automatic differentiation.
|
||||
- A flexible platform for machine learning research and experimentation.
|
||||
|
||||
Eager execution is under active development. This guide walks through an
|
||||
alpha/preview release. In particular, not all TensorFlow APIs currently work
|
||||
with eager execution enabled, and some models may be slow to execute, compared
|
||||
to models defined without using eager execution.
|
||||
|
||||
## Installation
|
||||
|
||||
Eager execution is **not** included in the latest release (version 1.4) of
|
||||
TensorFlow. To use it, you will need to [build TensorFlow from
|
||||
source](https://www.tensorflow.org/install/install_sources) or install the
|
||||
nightly builds.
|
||||
|
||||
For example, the nightly builds can be installed using `pip`:
|
||||
|
||||
- `pip install tf-nightly` (for CPU-only TensorFlow)
|
||||
- `pip install tf-nightly-gpu` (for GPU-enabled TensorFlow)
|
||||
|
||||
Or using `docker`, with [Jupyter Notebook](http://jupyter.org/) support:
|
||||
|
||||
```sh
|
||||
# For CPU-only TensorFlow
|
||||
docker pull tensorflow/tensorflow:nightly
|
||||
docker run -it -p 8888:8888 tensorflow/tensorflow:nightly
|
||||
|
||||
# For GPU-enabled TensorFlow:
|
||||
# (Requires https://github.com/NVIDIA/nvidia-docker)
|
||||
nvidia-docker pull tensorflow/tensorflow:nightly-gpu
|
||||
nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu
|
||||
```
|
||||
|
||||
## Getting Started
|
||||
|
||||
With TensorFlow installed, eager execution is enabled via a single call:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
|
||||
import tensorflow.contrib.eager as tfe
|
||||
|
||||
tfe.enable_eager_execution()
|
||||
```
|
||||
|
||||
Enabling eager execution changes how TensorFlow functions behave (in particular,
|
||||
`Tensor` objects will reference concrete values instead of being symbolic
|
||||
handles to nodes in a computational graph). As a result, eager execution should
|
||||
be enabled at the beginning of a program and cannot be disabled afterwards in
|
||||
the same program.
|
||||
|
||||
Code examples in the rest of this guide assume that eager execution has been
|
||||
enabled.
|
||||
|
||||
## A library for numerical computation
|
||||
|
||||
A significant fraction of the [TensorFlow
|
||||
API](https://www.tensorflow.org/api_docs/python/) consists of numerical
|
||||
operations:
|
||||
[arithmetic operations](https://www.tensorflow.org/api_docs/python/tf/matmul),
|
||||
[matrix operations](https://www.tensorflow.org/api_docs/python/tf/matmul),
|
||||
[linear algebra operations](https://www.tensorflow.org/api_docs/python/tf/linalg),
|
||||
etc.
|
||||
|
||||
With eager execution enabled, these operations consume and return
|
||||
multi-dimensional arrays as `Tensor` objects, similar to NumPy
|
||||
[`ndarray`s](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ndarray.html).
|
||||
For example:
|
||||
|
||||
```python
|
||||
# Multiply two 2x2 matrices
|
||||
x = tf.matmul([[1, 2],
|
||||
[3, 4]],
|
||||
[[4, 5],
|
||||
[6, 7]])
|
||||
# Add one to each element
|
||||
# (tf.add supports broadcasting)
|
||||
y = tf.add(x, 1)
|
||||
|
||||
# Create a random random 5x3 matrix
|
||||
z = tf.random_uniform([5, 3])
|
||||
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
tf.Tensor(
|
||||
[[16 19]
|
||||
[36 43]], shape=(2, 2), dtype=int32)
|
||||
tf.Tensor(
|
||||
[[17 20]
|
||||
[37 44]], shape=(2, 2), dtype=int32)
|
||||
tf.Tensor(
|
||||
[[ 0.25058532 0.0929395 0.54113817]
|
||||
[ 0.3108716 0.93350542 0.84909797]
|
||||
[ 0.53081679 0.12788558 0.01767385]
|
||||
[ 0.29725885 0.33540785 0.83588314]
|
||||
[ 0.38877153 0.39720535 0.78914213]], shape=(5, 3), dtype=float32)
|
||||
```
|
||||
|
||||
For convenience, these operations can also be triggered via operator overloading
|
||||
of the `Tensor` object. For example, the `+` operator is equivalent to `tf.add`,
|
||||
`-` to `tf.subtract`, `*` to `tf.multiply`, etc.:
|
||||
|
||||
```python
|
||||
x = (tf.ones([1], dtype=tf.float32) + 1) * 2 - 1
|
||||
print(x)
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
tf.Tensor([ 3.], shape=(1,), dtype=float32)
|
||||
```
|
||||
|
||||
### Converting to and from NumPy
|
||||
|
||||
The operations above automatically convert Python objects (like lists of
|
||||
numbers) and NumPy arrays to `Tensor` objects. `Tensor` objects can also be used
|
||||
as NumPy arrays by numpy operations.
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
|
||||
x = tf.add(1, 1) # tf.Tensor with a value of 2
|
||||
y = tf.add(np.array(1), np.array(1)) # tf.Tensor with a value of 2
|
||||
z = np.multiply(x, y) # numpy.int64 with a value of 4
|
||||
```
|
||||
|
||||
Alternatively, they can be explicitly converted using
|
||||
[`tf.constant`](https://www.tensorflow.org/api_docs/python/tf/constant), as
|
||||
shown in the next example.
|
||||
|
||||
Conversely, you can call the `numpy()` method of a `Tensor` object' to obtain
|
||||
its NumPy `ndarray` value. For example:
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
|
||||
np_x = np.array(2., dtype=np.float32)
|
||||
x = tf.constant(np_x)
|
||||
|
||||
py_y = 3.
|
||||
y = tf.constant(py_y)
|
||||
|
||||
z = x + y + 1
|
||||
|
||||
print(z)
|
||||
print(z.numpy())
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
tf.Tensor(6.0, shape=(), dtype=float32)
|
||||
6.0
|
||||
```
|
||||
|
||||
### GPU acceleration
|
||||
|
||||
Many TensorFlow operations support GPU acceleration. With eager execution
|
||||
enabled, [computation is *not* automatically
|
||||
offloaded](https://www.tensorflow.org/tutorials/using_gpu) to GPUs. Instead, you
|
||||
must explicitly specify when GPUs should be used.
|
||||
|
||||
The simplest way to do this is to enclose your computation in a `with
|
||||
tf.device('/gpu:0')` block. Also of interest is the `tfe.num_gpus()` function,
|
||||
which returns the number of available GPUs.
|
||||
|
||||
For example, consider this snippet to measure the time to multiply two 1000x1000
|
||||
matrices on CPU:
|
||||
|
||||
```python
|
||||
import time
|
||||
|
||||
def measure(x):
|
||||
# The very first time a GPU is used by TensorFlow, it is initialized.
|
||||
# So exclude the first run from timing.
|
||||
tf.matmul(x, x)
|
||||
|
||||
start = time.time()
|
||||
for i in range(10):
|
||||
tf.matmul(x, x)
|
||||
end = time.time()
|
||||
|
||||
return "Took %s seconds to multiply a %s matrix by itself 10 times" % (end - start, x.shape)
|
||||
|
||||
# Run on CPU:
|
||||
with tf.device("/cpu:0"):
|
||||
print("CPU: %s" % measure(tf.random_normal([1000, 1000])))
|
||||
|
||||
# If a GPU is available, run on GPU:
|
||||
if tfe.num_gpus() > 0:
|
||||
with tf.device("/gpu:0"):
|
||||
print("GPU: %s" % measure(tf.random_normal([1000, 1000])))
|
||||
```
|
||||
|
||||
Output (exact numbers will depend on the characteristics of the hardware):
|
||||
|
||||
```python
|
||||
CPU: Took 0.145531892776 seconds to multiply a (1000, 1000) matrix by itself 10 times
|
||||
GPU: Took 0.000458955764771 seconds to multiply a (1000, 1000) matrix by itself 10 times
|
||||
```
|
||||
|
||||
Alternatively, methods on the `Tensor` object can be used to explicitly copy the
|
||||
`Tensor` to a different device. Operations are typically executed on the device
|
||||
on which the inputs are placed. For example:
|
||||
|
||||
```python
|
||||
x = tf.random_normal([10, 10])
|
||||
|
||||
x_gpu0 = x.gpu()
|
||||
x_cpu = x.cpu()
|
||||
|
||||
_ = tf.matmul(x_cpu, x_cpu) # Runs on CPU
|
||||
_ = tf.matmul(x_gpu0, x_gpu0) # Runs on GPU:0
|
||||
|
||||
if tfe.num_gpus() > 1:
|
||||
x_gpu1 = x.gpu(1)
|
||||
_ = tf.matmul(x_gpu1, x_gpu1) # Runs on GPU:1
|
||||
```
|
||||
|
||||
### Automatic Differentiation
|
||||
|
||||
[Automatic
|
||||
differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) is
|
||||
very useful when implementing many machine learning algorithms (e.g.,
|
||||
[backpropagation](https://en.wikipedia.org/wiki/Backpropagation) for training
|
||||
neural networks). For this purpose, TensorFlow eager execution provides an
|
||||
[autograd](https://github.com/HIPS/autograd)-style API for automatic
|
||||
differentiation. Specifically, the functions:
|
||||
|
||||
- `tfe.gradients_function(f)`: Returns a Python function that computes the
|
||||
derivatives of the Python function `f` with respect to its arguments. `f`
|
||||
must return a scalar value. When the returned function is invoked, it
|
||||
returns a list of `Tensor` objects (one element for each argument of `f`).
|
||||
- `tfe.value_and_gradients_function(f)`: Similar to `tfe.gradients_function`,
|
||||
except that when the returned function is invoked, it returns the value of
|
||||
`f` in addition to the list of derivatives of `f` with respect to its
|
||||
arguments.
|
||||
|
||||
These functions naturally apply to higher order differentiation as well. For
|
||||
example:
|
||||
|
||||
```python
|
||||
def f(x):
|
||||
return tf.multiply(x, x) # Or x * x
|
||||
assert 9 == f(3.).numpy()
|
||||
|
||||
df = tfe.gradients_function(f)
|
||||
assert 6 == df(3.)[0].numpy()
|
||||
|
||||
# Second order deriviative.
|
||||
d2f = tfe.gradients_function(lambda x: df(x)[0])
|
||||
assert 2 == d2f(3.)[0].numpy()
|
||||
|
||||
# Third order derivative.
|
||||
d3f = tfe.gradients_function(lambda x : d2f(x)[0])
|
||||
assert 0 == d3f(3.)[0].numpy()
|
||||
```
|
||||
|
||||
These functions can be used to train models. For example, consider the following
|
||||
simple linear regression model:
|
||||
|
||||
```python
|
||||
def prediction(input, weight, bias):
|
||||
return input * weight + bias
|
||||
|
||||
# A toy dataset of points around 3 * x + 2
|
||||
NUM_EXAMPLES = 1000
|
||||
training_inputs = tf.random_normal([NUM_EXAMPLES])
|
||||
noise = tf.random_normal([NUM_EXAMPLES])
|
||||
training_outputs = training_inputs * 3 + 2 + noise
|
||||
|
||||
# A loss function: Mean-squared error
|
||||
def loss(weight, bias):
|
||||
error = prediction(training_inputs, weight, bias) - training_outputs
|
||||
return tf.reduce_mean(tf.square(error))
|
||||
|
||||
# Function that returns the the derivative of loss with respect to
|
||||
# weight and bias
|
||||
grad = tfe.gradients_function(loss)
|
||||
|
||||
# Train for 200 steps (starting from some random choice for W and B, on the same
|
||||
# batch of data).
|
||||
W = 5.
|
||||
B = 10.
|
||||
learning_rate = 0.01
|
||||
print("Initial loss: %f" % loss(W, B).numpy())
|
||||
for i in range(200):
|
||||
(dW, dB) = grad(W, B)
|
||||
W -= dW * learning_rate
|
||||
B -= dB * learning_rate
|
||||
if i % 20 == 0:
|
||||
print("Loss at step %d: %f" % (i, loss(W, B).numpy()))
|
||||
print("Final loss: %f" % loss(W, B).numpy())
|
||||
print("W, B = %f, %f" % (W.numpy(), B.numpy()))
|
||||
```
|
||||
|
||||
Output: (the exact numbers may vary depending on the randomness in noise)
|
||||
|
||||
```
|
||||
Initial loss: 66.730003
|
||||
Loss at step 0: 64.200096
|
||||
Loss at step 20: 29.872814
|
||||
Loss at step 40: 14.233772
|
||||
Loss at step 60: 7.090570
|
||||
Loss at step 80: 3.819887
|
||||
Loss at step 100: 2.318821
|
||||
Loss at step 120: 1.628385
|
||||
Loss at step 140: 1.310142
|
||||
Loss at step 160: 1.163167
|
||||
Loss at step 180: 1.095162
|
||||
Final loss: 1.064711
|
||||
W, B = 3.094944, 2.161383
|
||||
```
|
||||
|
||||
To utilize the GPU, place the code above within a `with tf.device("/gpu:0"):`
|
||||
block. (However, this particular model, with only two floating point parameters,
|
||||
is unlikely to benefit from GPU acceleration.)
|
||||
|
||||
### Customizing gradients
|
||||
|
||||
One may want to define custom gradients for an operation, or for a function.
|
||||
This may be useful for multiple reasons, including providing a more efficient
|
||||
or more [numerically stable](https://en.wikipedia.org/wiki/Numerical_stability)
|
||||
gradient for a sequence of operations.
|
||||
|
||||
For example, consider the function `log(1 + e^x)`, which commonly occurs in the
|
||||
computation of cross entropy and log likelihoods.
|
||||
|
||||
```python
|
||||
def log1pexp(x):
|
||||
return tf.log(1 + tf.exp(x))
|
||||
grad_log1pexp = tfe.gradients_function(log1pexp)
|
||||
|
||||
# Works fine at x = 0.
|
||||
assert 0.5 == float(grad_log1pexp(0.)[0])
|
||||
|
||||
# Returns a `nan` at x = 100 due to numerical instability.
|
||||
import math
|
||||
assert math.isnan(float(grad_log1pexp(100.)[0]))
|
||||
```
|
||||
|
||||
We can define a custom gradient for the above function that analytically
|
||||
simplifies the gradient expression.
|
||||
|
||||
```python
|
||||
@tfe.custom_gradient
|
||||
def log1pexp(x):
|
||||
e = tf.exp(x)
|
||||
def grad(dy):
|
||||
return dy * (1 - 1 / (1 + e))
|
||||
return tf.log(1 + e), grad
|
||||
grad_log1pexp = tfe.gradients_function(log1pexp)
|
||||
|
||||
# Works as before at x = 0.
|
||||
assert 0.5 == float(grad_log1pexp(0.)[0])
|
||||
|
||||
# But now works at x = 100 as well.
|
||||
assert 1.0 == float(grad_log1pexp(100.)[0])
|
||||
```
|
||||
Also notice how the gradient function implementation reuses an expression
|
||||
(`tf.exp(x)`) computed during the forward pass, hence making the gradient
|
||||
computation more efficient by avoiding redundant computation.
|
||||
|
||||
## Building and training models
|
||||
|
||||
In practice, your computation may have many parameters to be optimized (by
|
||||
computing derivatives). Encapsulating them into re-usable classes/objects
|
||||
makes the code easier to follow than writing a single top-level function with
|
||||
many arguments.
|
||||
|
||||
In fact, eager execution encourages use of the [Keras](https://keras.io)-style
|
||||
"Layer" classes in the
|
||||
[`tf.layers`](https://www.tensorflow.org/versions/master/api_docs/python/tf/layers)
|
||||
module.
|
||||
|
||||
Furthermore, you may want to apply more sophisticated techniques to compute
|
||||
parameter updates, such as those in
|
||||
[`tf.train.Optimizer`](https://www.tensorflow.org/api_guides/python/train#Optimizers)
|
||||
implementations.
|
||||
|
||||
This next section walks through using the same `Optimizer` and `Layer` APIs used
|
||||
to build trainable TensorFlow graphs in an environment where eager execution is
|
||||
enabled.
|
||||
|
||||
### Variables and Optimizers
|
||||
|
||||
`tfe.Variable` objects store mutable `Tensor` values that can be accessed during
|
||||
training, making automatic differentiation easier. In particular, parameters of
|
||||
a model can be encapsulated in Python classes as variables.
|
||||
|
||||
`tfe.gradients_function(f)` introduced earlier computes the derivatives of `f`
|
||||
with respect to its arguments. However, it requires all parameters of interest
|
||||
to be arguments of `f`, which becomes cumbersome when `f` depends on a large
|
||||
number of trainable parameters.
|
||||
|
||||
`tfe.implicit_gradients` is an alternative function with some useful properties:
|
||||
|
||||
- It computes the derivatives of `f` with respect to all the `tfe.Variable`s
|
||||
used by `f`.
|
||||
- When the returned function is invoked, it returns a list of
|
||||
(gradient value, Variable object) tuples.
|
||||
|
||||
Representing model parameters as `Variable` objects, along with the use of
|
||||
`tfe.implicit_gradients`, typically results in better encapsulation. For
|
||||
example, the linear regression model described above can be written into a
|
||||
class:
|
||||
|
||||
```python
|
||||
class Model(object):
|
||||
def __init__(self):
|
||||
self.W = tfe.Variable(5., name='weight')
|
||||
self.B = tfe.Variable(10., name='bias')
|
||||
|
||||
def predict(self, inputs):
|
||||
return inputs * self.W + self.B
|
||||
|
||||
|
||||
# The loss function to be optimized
|
||||
def loss(model, inputs, targets):
|
||||
error = model.predict(inputs) - targets
|
||||
return tf.reduce_mean(tf.square(error))
|
||||
|
||||
# A toy dataset of points around 3 * x + 2
|
||||
NUM_EXAMPLES = 1000
|
||||
training_inputs = tf.random_normal([NUM_EXAMPLES])
|
||||
noise = tf.random_normal([NUM_EXAMPLES])
|
||||
training_outputs = training_inputs * 3 + 2 + noise
|
||||
|
||||
# Define:
|
||||
# 1. A model
|
||||
# 2. Derivatives of a loss function with respect to model parameters
|
||||
# 3. A strategy for updating the variables based on the derivatives
|
||||
model = Model()
|
||||
grad = tfe.implicit_gradients(loss)
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
|
||||
|
||||
# The training loop
|
||||
print("Initial loss: %f" %
|
||||
loss(model, training_inputs, training_outputs).numpy())
|
||||
for i in range(201):
|
||||
optimizer.apply_gradients(grad(model, training_inputs, training_outputs))
|
||||
if i % 20 == 0:
|
||||
print("Loss at step %d: %f" %
|
||||
(i, loss(model, training_inputs, training_outputs).numpy()))
|
||||
print("Final loss: %f" % loss(model, training_inputs, training_outputs).numpy())
|
||||
print("W, B = %s, %s" % (model.W.numpy(), model.B.numpy()))
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
Initial loss: 69.693184
|
||||
Loss at step 0: 66.987854
|
||||
Loss at step 20: 30.553387
|
||||
Loss at step 40: 14.250237
|
||||
Loss at step 60: 6.955020
|
||||
Loss at step 80: 3.690550
|
||||
Loss at step 100: 2.229739
|
||||
Loss at step 120: 1.576032
|
||||
Loss at step 140: 1.283496
|
||||
Loss at step 160: 1.152584
|
||||
Loss at step 180: 1.093999
|
||||
Final loss: 1.067780
|
||||
W, B = 3.0114281, 2.0865183
|
||||
```
|
||||
|
||||
Using `implicit_gradients` avoids the need to provide all the trainable
|
||||
parameters of the model as arguments to the `loss` function.
|
||||
|
||||
### Using Keras and the Layers API
|
||||
|
||||
[Keras](https://keras.io) is a popular API for defining model structures. The
|
||||
[`tf.keras.layers`](https://www.tensorflow.org/versions/master/api_docs/python/tf/keras/layers)
|
||||
module provides a set of building blocks for models and is implemented using the
|
||||
`tf.layers.Layer` subclasses in the
|
||||
[`tf.layers`](https://www.tensorflow.org/versions/master/api_docs/python/tf/layers)
|
||||
module. We encourage the use of these same building blocks when using
|
||||
TensorFlow's eager execution feature. For example, the very same linear
|
||||
regression model can be built using `tf.layers.Dense`:
|
||||
|
||||
```python
|
||||
class Model(object):
|
||||
def __init__(self):
|
||||
self.layer = tf.layers.Dense(1)
|
||||
|
||||
def predict(self, inputs):
|
||||
return self.layer(inputs)
|
||||
```
|
||||
|
||||
The `tf.layers` API makes it more convenient to define more sophisticated
|
||||
models. For example, the following will train an MNIST model:
|
||||
|
||||
```python
|
||||
class MNISTModel(object):
|
||||
def __init__(self, data_format):
|
||||
# 'channels_first' is typically faster on GPUs
|
||||
# while 'channels_last' is typically faster on CPUs.
|
||||
# See: https://www.tensorflow.org/performance/performance_guide#data_formats
|
||||
if data_format == 'channels_first':
|
||||
self._input_shape = [-1, 1, 28, 28]
|
||||
else:
|
||||
self._input_shape = [-1, 28, 28, 1]
|
||||
self.conv1 = tf.layers.Conv2D(32, 5,
|
||||
padding='same',
|
||||
activation=tf.nn.relu,
|
||||
data_format=data_format)
|
||||
self.max_pool2d = tf.layers.MaxPooling2D(
|
||||
(2, 2), (2, 2), padding='same', data_format=data_format)
|
||||
self.conv2 = tf.layers.Conv2D(64, 5,
|
||||
padding='same',
|
||||
activation=tf.nn.relu,
|
||||
data_format=data_format)
|
||||
self.dense1 = tf.layers.Dense(1024, activation=tf.nn.relu)
|
||||
self.dropout = tf.layers.Dropout(0.5)
|
||||
self.dense2 = tf.layers.Dense(10)
|
||||
|
||||
def predict(self, inputs):
|
||||
x = tf.reshape(inputs, self._input_shape)
|
||||
x = self.max_pool2d(self.conv1(x))
|
||||
x = self.max_pool2d(self.conv2(x))
|
||||
x = tf.layers.flatten(x)
|
||||
x = self.dropout(self.dense1(x))
|
||||
return self.dense2(x)
|
||||
|
||||
def loss(model, inputs, targets):
|
||||
return tf.reduce_mean(
|
||||
tf.nn.softmax_cross_entropy_with_logits(
|
||||
logits=model.predict(inputs), labels=targets))
|
||||
|
||||
|
||||
# Load the training and validation data
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
data = input_data.read_data_sets("./mnist_data", one_hot=True)
|
||||
|
||||
# Train
|
||||
device = "gpu:0" if tfe.num_gpus() else "cpu:0"
|
||||
model = MNISTModel('channels_first' if tfe.num_gpus() else 'channels_last')
|
||||
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
|
||||
grad = tfe.implicit_gradients(loss)
|
||||
for i in range(20001):
|
||||
with tf.device(device):
|
||||
(inputs, targets) = data.train.next_batch(50)
|
||||
optimizer.apply_gradients(grad(model, inputs, targets))
|
||||
if i % 100 == 0:
|
||||
print("Step %d: Loss on training set : %f" %
|
||||
(i, loss(model, inputs, targets).numpy()))
|
||||
print("Loss on test set: %f" % loss(model, data.test.images, data.test.labels).numpy())
|
||||
```
|
||||
|
||||
For a more complete example, see
|
||||
[`tensorflow/contrib/eager/python/examples/mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py)
|
||||
|
||||
### Checkpointing trained variables
|
||||
|
||||
TensorFlow Variables (`tfe.Variable`) provides a way to represent shared,
|
||||
persistent state of your model. The `tfe.Saver` class (which is a thin wrapper
|
||||
over the
|
||||
[`tf.train.Saver`](https://www.tensorflow.org/api_docs/python/tf/train/Saver)
|
||||
class) provides a means to save and restore variables to and from _checkpoints_.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
# Create variables.
|
||||
x = tfe.Variable(10., name='x')
|
||||
y = tfe.Variable(5., name='y')
|
||||
|
||||
# Create a Saver.
|
||||
saver = tfe.Saver([x, y])
|
||||
|
||||
# Assign new values to the variables and save.
|
||||
x.assign(2.)
|
||||
saver.save('/tmp/ckpt')
|
||||
|
||||
# Change the variable after saving.
|
||||
x.assign(11.)
|
||||
assert 16. == (x + y).numpy() # 11 + 5
|
||||
|
||||
# Restore the values in the checkpoint.
|
||||
saver.restore('/tmp/ckpt')
|
||||
|
||||
assert 7. == (x + y).numpy() # 2 + 5
|
||||
```
|
||||
|
||||
### `tfe.Network`
|
||||
|
||||
You may often want to organize your models using classes, like the `MNISTModel`
|
||||
class described above. We recommend inheriting from the `tfe.Network` class as
|
||||
it provides conveniences like keeping track of all model variables and methods
|
||||
to save and restore from checkpoints.
|
||||
|
||||
Sub-classes of `tfe.Network` may register `Layer`s (like classes in
|
||||
[`tf.layers`](https://www.tensorflow.org/versions/master/api_docs/python/tf/layers),
|
||||
or [Keras
|
||||
layers](https://www.tensorflow.org/versions/master/api_docs/python/tf/keras/layers))
|
||||
using a call to `self.track_layer()` and define the computation in an
|
||||
implementation of `call()`.
|
||||
|
||||
Note that `tf.layers.Layer` objects (like `tf.layers.Dense`) create variables
|
||||
lazily, when the first input is encountered.
|
||||
|
||||
For example, consider the following two-layer neural network:
|
||||
|
||||
```python
|
||||
class TwoLayerNet(tfe.Network):
|
||||
def __init__(self):
|
||||
super(TwoLayerNet, self).__init__()
|
||||
self.layer1 = self.track_layer(
|
||||
tf.layers.Dense(2, activation=tf.nn.relu, use_bias=False))
|
||||
self.layer2 = self.track_layer(tf.layers.Dense(3, use_bias=False))
|
||||
|
||||
def call(self, x):
|
||||
return self.layer2(self.layer1(x))
|
||||
|
||||
net = TwoLayerNet()
|
||||
|
||||
# No variables created yet
|
||||
assert 0 == len(net.variables)
|
||||
|
||||
# They are created on first input:
|
||||
inp = tf.constant([[1.]])
|
||||
|
||||
# Since input is a 1x1 matrix, net.l1 has 2 units and net.l2 has 3 units,
|
||||
# the output is the product of a 1x1 matrix with a 1x2 matrix with a 2x3
|
||||
# matrix.
|
||||
assert [1, 3] == net(inp).shape.as_list() # Invoke net; get output shape.
|
||||
assert 1 == len(net.layer1.variables)
|
||||
assert 1 == len(net.layer2.variables)
|
||||
assert 2 == len(net.variables) # weights for each layer.
|
||||
assert [1, 2] == net.variables[0].shape.as_list() # weights of layer1.
|
||||
assert [2, 3] == net.variables[1].shape.as_list() # weights of layer2.
|
||||
```
|
||||
|
||||
The `tfe.Network` class is itself a sub-class of `tf.layers.Layer`. This allows
|
||||
instances of `tfe.Network` to be embedded in other networks. For example:
|
||||
|
||||
```python
|
||||
class ThreeLayerNet(tfe.Network):
|
||||
def __init__(self):
|
||||
super(ThreeLayerNet, self).__init__()
|
||||
self.a = self.track_layer(TwoLayerNet())
|
||||
self.b = self.track_layer(tf.layers.Dense(4, use_bias=False))
|
||||
|
||||
def call(self, x):
|
||||
return self.b(self.a(x))
|
||||
|
||||
net = ThreeLayerNet()
|
||||
|
||||
assert [1, 4] == net(inp).shape.as_list()
|
||||
assert 3 == len(net.variables)
|
||||
assert [1, 2] == net.variables[0].shape.as_list()
|
||||
assert [2, 3] == net.variables[1].shape.as_list()
|
||||
assert [3, 4] == net.variables[2].shape.as_list()
|
||||
```
|
||||
|
||||
See more examples in
|
||||
[`tensorflow/contrib/eager/python/examples`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples).
|
||||
|
||||
`tfe.Saver` in combination with `tfe.restore_variables_on_create` provides a
|
||||
convenient way to save and load checkpoints without changing the program once
|
||||
the checkpoint has been created. For example, we can set an objective for the
|
||||
output of our network, choose an optimizer, and a location for the checkpoint:
|
||||
|
||||
```python
|
||||
objective = tf.constant([[2., 3., 4., 5.]])
|
||||
optimizer = tf.train.AdamOptimizer(0.01)
|
||||
checkpoint_directory = '/tmp/tfe_example'
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
|
||||
net = ThreeLayerNet()
|
||||
```
|
||||
|
||||
Note that variables have not been created yet. We want them to be restored from
|
||||
a checkpoint, if one exists, so we create them inside a
|
||||
`tfe.restore_variables_on_create` context manager. Then our training loop is the
|
||||
same whether starting training or resuming from a previous checkpoint:
|
||||
|
||||
```python
|
||||
with tfe.restore_variables_on_create(
|
||||
tf.train.latest_checkpoint(checkpoint_directory)):
|
||||
global_step = tf.train.get_or_create_global_step()
|
||||
for _ in range(100):
|
||||
loss_fn = lambda: tf.norm(net(inp) - objective)
|
||||
optimizer.minimize(loss_fn, global_step=global_step)
|
||||
if tf.equal(global_step % 20, 0):
|
||||
print("Step %d, output %s" % (global_step.numpy(),
|
||||
net(inp).numpy()))
|
||||
all_variables = (
|
||||
net.variables
|
||||
+ tfe.get_optimizer_variables(optimizer)
|
||||
+ [global_step])
|
||||
# Save the checkpoint.
|
||||
tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step)
|
||||
```
|
||||
|
||||
The first time it runs, `Network` variables are initialized randomly. Then the
|
||||
output is trained to match the objective we've set:
|
||||
|
||||
```
|
||||
Step 20, output [[ 0.03575622 0.29863232 0.03474367 0.24735749]]
|
||||
Step 40, output [[ 0.40646029 0.9856872 0.46851286 0.95358551]]
|
||||
Step 60, output [[ 1.74541104 2.800704 1.79055595 2.74783421]]
|
||||
Step 80, output [[ 2.14977384 3.44340849 3.96120024 5.16242075]]
|
||||
Step 100, output [[ 1.99943113 3.02364397 3.93500996 4.9610076 ]]
|
||||
```
|
||||
|
||||
In subsequent iterations, variables are initialized with the values read from
|
||||
the latest checkpoint. Running the same code again, we continue from where we
|
||||
left off:
|
||||
|
||||
```
|
||||
Step 120, output [[ 1.99234128 3.0271616 3.98732996 4.96401167]]
|
||||
Step 140, output [[ 2.00133467 3.01270437 4.00616646 5.00406504]]
|
||||
Step 160, output [[ 1.99647415 2.9956708 3.99064088 4.99632359]]
|
||||
Step 180, output [[ 2.00699997 3.00904822 4.00706148 5.01193142]]
|
||||
Step 200, output [[ 1.98334622 2.98249531 3.97375059 4.97123432]]
|
||||
```
|
||||
|
||||
|
||||
### Summaries, metrics and TensorBoard
|
||||
|
||||
[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
|
||||
is a popular tool for understanding, debugging and optimizing the model training
|
||||
process. To benefit from the visualizations offered by TensorBoard, summary
|
||||
events need to be written during the course of execution of your program. You
|
||||
might find many Tensorflow programs that include the
|
||||
[`tf.summary`](https://www.tensorflow.org/api_guides/python/summary) operations
|
||||
during graph construction.
|
||||
|
||||
`tf.summary` operations are *not* compatible with eager execution, but an
|
||||
equivalent alternative exists in
|
||||
[`tf.contrib.summary`](https://www.tensorflow.org/versions/master/api_guides/python/tf/contrib/summary/)
|
||||
that is compatible with both eager execution and graph construction.
|
||||
|
||||
During model construction simply insert summary operations like
|
||||
`tf.contrib.summary.scalar`. These operations do nothing by default, unless a
|
||||
summary writer is currently active and a writing policy is set.
|
||||
|
||||
For example, to record summaries once every 100 global steps, use:
|
||||
|
||||
```python
|
||||
tf.train.get_or_create_global_step() # Ensuring the global step variable exists
|
||||
writer = tf.contrib.summary.create_summary_file_writer(logdir)
|
||||
|
||||
for _ in range(iterations):
|
||||
with writer.as_default():
|
||||
with tf.contrib.summary.record_summaries_every_n_global_steps(100):
|
||||
# your model code goes here
|
||||
tf.contrib.summary.scalar('loss', loss)
|
||||
# ...
|
||||
```
|
||||
|
||||
See the full mnist example in
|
||||
[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist)
|
||||
for a full model using `tf.contrib.summary`.
|
||||
|
||||
Similarly to summaries, the metrics in `tf.metrics` are currently not compatible
|
||||
with eager execution. We instead provide object-oriented metrics in the
|
||||
`tfe.metrics` package, which are compatible with graph construction as well.
|
||||
|
||||
Metrics in the `tfe.metrics`, such as `tfe.metrics.Mean` and
|
||||
`tfe.Metrics.Accuracy`, all implement an intuitive object-oriented
|
||||
interface. Here's an example of how to use the `tfe.metrics.Mean` metric:
|
||||
|
||||
```python
|
||||
# Metrics are objects, which can be created and destroyed.
|
||||
my_mean = tfe.metrics.Mean(name='my_mean')
|
||||
# While a metric is active, you can call it as a function to accumulate into its
|
||||
# internal state.
|
||||
my_mean(0.0)
|
||||
my_mean(10.0)
|
||||
# Once you've finished updating the metric, you can get its result. In this case
|
||||
# a simple average over all the calls to it. If a summary writer is active the
|
||||
# metric will write the appropriate summaries using the metric name.
|
||||
assert 5.0 == my_mean.result().numpy()
|
||||
```
|
||||
|
||||
For a full example of a model using metrics for evaluation, see the mnist
|
||||
example in
|
||||
[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist).
|
||||
|
||||
### Input Pipelines
|
||||
|
||||
The discussion above has been centered around the computation executed by your
|
||||
model. The
|
||||
[`tf.data`](https://www.tensorflow.org/versions/master/api_docs/python/tf/data)
|
||||
module provides APIs to build complex input pipelines from simple, reusable
|
||||
pieces.
|
||||
|
||||
If you're familiar with constructing `tf.data.Dataset` objects when building
|
||||
TensorFlow graphs, the same API calls are used when eager execution is enabled.
|
||||
However, the process of iterating over elements of the dataset differs between
|
||||
eager execution and graph construction. When eager execution is enabled, the
|
||||
discussion on iterator creation using `make_one_shot_iterator()` and
|
||||
`get_next()` in the
|
||||
[Programmer's
|
||||
Guide](https://www.tensorflow.org/versions/master/programmers_guide/datasets) is
|
||||
*not* applicable. Instead, a more Pythonic `Iterator` class is available.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
# Create a source Dataset from in-memory numpy arrays.
|
||||
# For reading from files on disk, you may want to use other Dataset classes
|
||||
# like the TextLineDataset or the TFRecordDataset.
|
||||
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
|
||||
|
||||
# Apply transformations, shuffling, batching etc.
|
||||
dataset = dataset.map(tf.square).shuffle(2).batch(2)
|
||||
|
||||
# Use tfe.Iterator to iterate over the dataset.
|
||||
for x in tfe.Iterator(dataset):
|
||||
print(x)
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
tf.Tensor([4 9], shape=(2,), dtype=int32)
|
||||
tf.Tensor([16 25], shape=(2,), dtype=int32)
|
||||
tf.Tensor([36 1], shape=(2,), dtype=int32)
|
||||
```
|
||||
|
||||
## Interoperating with Graphs
|
||||
|
||||
Eager execution improves the process of model development in Python; however,
|
||||
because it is in its earliest stages, it does not yet support some features
|
||||
available to [TensorFlow
|
||||
graphs](https://www.tensorflow.org/get_started/get_started#the_computational_graph)
|
||||
that are desirable when deploying models in production. In particular, eager
|
||||
execution does not yet support distributed training, exporting models (to other
|
||||
[programming languages](https://www.tensorflow.org/api_docs/), [TensorFlow
|
||||
serving](https://www.tensorflow.org/serving/), and mobile applications), and
|
||||
various memory and computation optimizations that are applied to TensorFlow's
|
||||
dataflow graphs.
|
||||
|
||||
That said, the APIs used to build modes are exactly the same whether executing
|
||||
eagerly or constructing graphs. This means that you can iteratively develop your
|
||||
model with eager execution enabled and later, if needed, use the same code to
|
||||
reap the benefits of representing models as computational graphs.
|
||||
|
||||
For example,
|
||||
[`mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py)
|
||||
defines a model that is eagerly executed. That same code is used to construct
|
||||
and execute a graph in
|
||||
[`mnist_graph_test.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py).
|
||||
|
||||
Other models in the [examples
|
||||
directory](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/)
|
||||
demonstrate this as well.
|
||||
|
||||
Some differences worth noting:
|
||||
|
||||
- There is no notion of a `tf.placeholder` or a `tf.Session` when eager
|
||||
execution is enabled.
|
||||
- Many properties on the `tf.Tensor` object, like `tf.Tensor.name`,
|
||||
`tf.Tensor.op`, `tf.Tensor.inputs` are not meaningful when eager execution
|
||||
is enabled and their use will raise an `AttributeError`.
|
||||
- To use `tfe.implicit_gradients` in graph construction, variables must be
|
||||
created with [`use_resource=True`] provided to
|
||||
[`tf.get_variable()`](https://www.tensorflow.org/api_docs/python/tf/get_variable)
|
||||
or
|
||||
[`tf.variable_scope()`](https://www.tensorflow.org/api_docs/python/tf/variable_scope).
|
||||
- Some API calls (such as the functional-style `tf.layers.dense`,
|
||||
`tf.layers.conv2d`) are not compatible with eager execution. Use of such
|
||||
methods should raise an error indicating the alternative (e.g., the
|
||||
`tf.layers.Dense` and `tf.layers.Conv2D` classes).
|
||||
|
||||
## What next?
|
||||
|
||||
Please give eager execution a spin. This feature is in early stages and is
|
||||
evolving, so we welcome your feedback via issues on GitHub (see [known
|
||||
issues](https://github.com/tensorflow/tensorflow/labels/eager)).
|
||||
|
||||
You may want to browse through some sample code, including benchmarks for some:
|
||||
|
||||
- [Linear Regression](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/linear_regression)
|
||||
- [MNIST handwritten digit classifier](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist)
|
||||
- [ResNet50 image classification](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/resnet50)
|
||||
- [RNN to generate colors](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_colorbot)
|
||||
- [RNN language model](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_ptb)
|
||||
|
||||
|
|
@ -153,6 +153,7 @@ sh_binary(
|
|||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/boosted_trees:boosted_trees_pip",
|
||||
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
|
||||
"//tensorflow/contrib/eager/python/examples:examples_pip",
|
||||
"//tensorflow/contrib/gan:gan",
|
||||
"//tensorflow/contrib/graph_editor:graph_editor_pip",
|
||||
"//tensorflow/contrib/keras:keras",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user