eager: Documentation and example models.

- Updated README
- A preliminary "User's Guide"
- A few example models, some with benchmarks

PiperOrigin-RevId: 173996303
This commit is contained in:
Asim Shankar 2017-10-30 22:26:51 -07:00 committed by TensorFlower Gardener
parent cd81bc8e09
commit a6a6188439
31 changed files with 5362 additions and 17 deletions

View File

@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
@ -26,12 +25,8 @@ from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
_cudnn_rnn_ops_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so"))
CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION
CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION
CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM

View File

@ -0,0 +1,15 @@
TensorFlow has many kernels for doing (deep) learning and data manipulation.
There are typically assembled into computational graphs which can run
efficiently in a variety of environments.
We are exploring an alternative interaction, where kernels are invoked
immediately and call this "eager execution". We are hoping to retain the
benefits of graphs while improving usability with benefits like:
- Immediate error messages and easier debugging
- Flexibility to use Python datastructures and control flow
- Reduced boilerplate
Eager execution is under active development.
There are not many developer-facing materials yet, but stay tuned for updates
in this directory.

View File

@ -1,15 +1,78 @@
TensorFlow has many kernels for doing (deep) learning and data manipulation. # TensorFlow Eager Execution
There are typically assembled into computational graphs which can run
efficiently in a variety of environments.
We are exploring an alternative interaction, where kernels are invoked > *WARNING*: This is a preview/pre-alpha version. The API and performance
immediately and call this "eager execution". We are hoping to retain the > characteristics are subject to change.
benefits of graphs while improving usability with benefits like:
- Immediate error messages and easier debugging Eager execution is an experimental interface to TensorFlow that provides an
- Flexibility to use Python datastructures and control flow imperative programming style (à la [NumPy](http://www.numpy.org)). When you
- Reduced boilerplate enable eager execution, TensorFlow operations execute immediately; you do not
execute a pre-constructed graph with
[`Session.run()`](https://www.tensorflow.org/api_docs/python/tf/Session).
Eager execution is under active development. For example, consider a simple computation in TensorFlow:
There are not many developer-facing materials yet, but stay tuned for updates
in this directory. ```python
x = tf.placeholder(tf.float32, shape=[1, 1])
m = tf.matmul(x, x)
with tf.Session() as sess:
print(sess.run(m, feed_dict={x: [[2.]]}))
# Will print [[4.]]
```
Eager execution makes this much simpler:
```python
x = [[2.]]
m = tf.matmul(x, x)
print(m)
```
## Caveats
This feature is in early stages and work remains to be done in terms of smooth
support for distributed and multi-GPU training and CPU performance.
- [Known issues](https://github.com/tensorflow/tensorflow/issues?q=is%3Aissue%20is%3Aopen%20label%3Aproj%3Aeager)
- Feedback is welcome, please consider
[filing an issue](https://github.com/tensorflow/tensorflow/issues/new) to provide it.
## Installation
Since eager execution is not yet part of a TensorFlow release, using it requires
either [building from source](https://www.tensorflow.org/install/install_sources)
or the latest nightly builds. The nightly builds are available as:
- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and
- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images.
For example, to run the latest nightly docker image:
```sh
# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker
nvidia-docker pull tensorflow/tensorflow:nightly-gpu
nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu
# If you do not have a GPU, use the CPU-only image
docker pull tensorflow/tensorflow:nightly
docker run -it -p 8888:8888 tensorflow/tensorflow:nightly
```
And then visit http://localhost:8888 in your browser for a Jupyter notebook
environment. Try out the notebooks below.
## Documentation
For an introduction to eager execution in TensorFlow, see:
- [User Guide](python/g3doc/guide.md)
- Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb)
- Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb)
- Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb)
## Changelog
- 2017/10/31: Initial preview release.

View File

@ -0,0 +1,15 @@
# TensorFlow code for training gradient boosted trees.
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
py_library(
name = "examples_pip",
deps = [
"//tensorflow/contrib/eager/python/examples/linear_regression",
"//tensorflow/contrib/eager/python/examples/mnist",
"//tensorflow/contrib/eager/python/examples/resnet50",
"//tensorflow/contrib/eager/python/examples/rnn_colorbot",
"//tensorflow/contrib/eager/python/examples/rnn_ptb",
],
)

View File

@ -0,0 +1,25 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_binary(
name = "linear_regression",
srcs = ["linear_regression.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/eager/python:tfe",
],
)
cuda_py_test(
name = "linear_regression_test",
size = "small",
srcs = ["linear_regression_test.py"],
additional_deps = [
":linear_regression",
"//tensorflow:tensorflow_py",
],
)

View File

@ -0,0 +1,157 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""TensorFlow Eager Execution Example: Linear Regression.
This example shows how to use TensorFlow Eager Execution to fit a simple linear
regression model using some synthesized data. Specifically, it illustrates how
to define the forward path of the linear model and the loss function, as well
as how to obtain the gradients of the loss function with respect to the
variables and update the variables with the gradients.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
import tensorflow.contrib.eager as tfe
class LinearModel(tfe.Network):
"""A TensorFlow linear regression model.
Uses TensorFlow's eager execution.
For those familiar with TensorFlow graphs, notice the absence of
`tf.Session`. The `forward()` method here immediately executes and
returns output values. The `loss()` method immediately compares the
output of `forward()` with the target adn returns the MSE loss value.
The `fit()` performs gradient-descent training on the model's weights
and bias.
"""
def __init__(self):
"""Constructs a LinearModel object."""
super(LinearModel, self).__init__()
self._hidden_layer = self.track_layer(tf.layers.Dense(1))
def call(self, xs):
"""Invoke the linear model.
Args:
xs: input features, as a tensor of size [batch_size, ndims].
Returns:
ys: the predictions of the linear mode, as a tensor of size [batch_size]
"""
return self._hidden_layer(xs)
def fit(model, dataset, optimizer, verbose=False, logdir=None):
"""Fit the linear-regression model.
Args:
model: The LinearModel to fit.
dataset: The tf.data.Dataset to use for training data.
optimizer: The TensorFlow Optimizer object to be used.
verbose: If true, will print out loss values at every iteration.
logdir: The directory in which summaries will be written for TensorBoard
(optional).
"""
# The loss function to optimize.
def mean_square_loss(xs, ys):
return tf.reduce_mean(tf.square(model(xs) - ys))
loss_and_grads = tfe.implicit_value_and_gradients(mean_square_loss)
tf.train.get_or_create_global_step()
if logdir:
# Support for TensorBoard summaries. Once training has started, use:
# tensorboard --logdir=<logdir>
summary_writer = tf.contrib.summary.create_summary_file_writer(logdir)
# Training loop.
for i, (xs, ys) in enumerate(tfe.Iterator(dataset)):
loss, grads = loss_and_grads(xs, ys)
if verbose:
print("Iteration %d: loss = %s" % (i, loss.numpy()))
optimizer.apply_gradients(grads, global_step=tf.train.get_global_step())
if logdir:
with summary_writer.as_default():
with tf.contrib.summary.always_record_summaries():
tf.contrib.summary.scalar("loss", loss)
def synthetic_dataset(w, b, noise_level, batch_size, num_batches):
"""tf.data.Dataset that yields synthetic data for linear regression."""
# w is a matrix with shape [N, M]
# b is a vector with shape [M]
# So:
# - Generate x's as vectors with shape [batch_size N]
# - y = tf.matmul(x, W) + b + noise
def batch(_):
x = tf.random_normal([batch_size, tf.shape(w)[0]])
y = tf.matmul(x, w) + b + noise_level * tf.random_normal([])
return x, y
with tf.device("/device:CPU:0"):
return tf.data.Dataset.range(num_batches).map(batch)
def main(_):
tfe.enable_eager_execution()
# Ground-truth constants.
true_w = [[-2.0], [4.0], [1.0]]
true_b = [0.5]
noise_level = 0.01
# Training constants.
batch_size = 64
learning_rate = 0.1
print("True w: %s" % true_w)
print("True b: %s\n" % true_b)
model = LinearModel()
dataset = synthetic_dataset(true_w, true_b, noise_level, batch_size, 20)
device = "gpu:0" if tfe.num_gpus() else "cpu:0"
print("Using device: %s" % device)
with tf.device(device):
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
fit(model, dataset, optimizer, verbose=True, logdir=FLAGS.logdir)
print("\nAfter training: w = %s" % model.variables[0].numpy())
print("\nAfter training: b = %s" % model.variables[1].numpy())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--logdir",
type=str,
default=None,
help="logdir in which TensorBoard summaries will be written (optional).")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,119 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for linear regression example under TensorFlow eager execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os
import shutil
import tempfile
import time
import tensorflow as tf
import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.linear_regression import linear_regression
def device():
return "/device:GPU:0" if tfe.num_gpus() > 0 else "/device:CPU:0"
class LinearRegressionTest(tf.test.TestCase):
def setUp(self):
super(LinearRegressionTest, self).setUp()
self._tmp_logdir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self._tmp_logdir)
super(LinearRegressionTest, self).tearDown()
def testSyntheticDataset(self):
true_w = tf.random_uniform([3, 1])
true_b = [1.0]
batch_size = 10
num_batches = 2
noise_level = 0.
dataset = linear_regression.synthetic_dataset(true_w, true_b, noise_level,
batch_size, num_batches)
it = tfe.Iterator(dataset)
for _ in range(2):
(xs, ys) = it.next()
self.assertEqual((batch_size, 3), xs.shape)
self.assertEqual((batch_size, 1), ys.shape)
self.assertEqual(tf.float32, xs.dtype)
self.assertEqual(tf.float32, ys.dtype)
with self.assertRaises(StopIteration):
it.next()
def testLinearRegression(self):
true_w = [[1.0], [-0.5], [2.0]]
true_b = [1.0]
model = linear_regression.LinearModel()
dataset = linear_regression.synthetic_dataset(
true_w, true_b, noise_level=0., batch_size=64, num_batches=40)
with tf.device(device()):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
linear_regression.fit(model, dataset, optimizer, logdir=self._tmp_logdir)
self.assertAllClose(true_w, model.variables[0].numpy(), rtol=1e-2)
self.assertAllClose(true_b, model.variables[1].numpy(), rtol=1e-2)
self.assertTrue(glob.glob(os.path.join(self._tmp_logdir, "events.out.*")))
class EagerLinearRegressionBenchmark(tf.test.Benchmark):
def benchmarkEagerLinearRegression(self):
num_batches = 200
batch_size = 64
dataset = linear_regression.synthetic_dataset(
w=tf.random_uniform([3, 1]),
b=tf.random_uniform([1]),
noise_level=0.01,
batch_size=batch_size,
num_batches=num_batches)
burn_in_dataset = dataset.take(10)
model = linear_regression.LinearModel()
with tf.device(device()):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
# Perform burn-in.
linear_regression.fit(model, burn_in_dataset, optimizer)
start_time = time.time()
linear_regression.fit(model, dataset, optimizer)
wall_time = time.time() - start_time
examples_per_sec = num_batches * batch_size / wall_time
self.report_benchmark(
name="eager_train_%s" %
("gpu" if tfe.num_gpus() > 0 else "cpu"),
iters=num_batches,
extras={"examples_per_sec": examples_per_sec},
wall_time=wall_time)
if __name__ == "__main__":
tfe.enable_eager_execution()
tf.test.main()

View File

@ -0,0 +1,36 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_binary(
name = "mnist",
srcs = ["mnist.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow/examples/tutorials/mnist:input_data",
],
)
cuda_py_test(
name = "mnist_test",
srcs = ["mnist_test.py"],
additional_deps = [
":mnist",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow:tensorflow_py",
],
)
cuda_py_test(
name = "mnist_graph_test",
srcs = ["mnist_graph_test.py"],
additional_deps = [
":mnist",
"//third_party/py/numpy",
"//tensorflow:tensorflow_py",
],
)

View File

@ -0,0 +1,10 @@
Classification model for the MNIST dataset using eager execution.
To run:
```
python mnist.py
```
`mnist_graph_test.py` demonstrates that the same code that is executed eagerly
in `mnist.py` is used to construct a TensorFlow graph.

View File

@ -0,0 +1,270 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A deep MNIST classifier using convolutional layers.
Sample usage:
python mnist.py --help
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import functools
import os
import sys
import time
import tensorflow as tf
import tensorflow.contrib.eager as tfe
from tensorflow.examples.tutorials.mnist import input_data
FLAGS = None
class MNISTModel(tfe.Network):
"""MNIST Network.
Network structure is equivalent to:
https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/examples/tutorials/mnist/mnist_deep.py
and
https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
But written using the tf.layers API.
"""
def __init__(self, data_format):
"""Creates a model for classifying a hand-written digit.
Args:
data_format: Either 'channels_first' or 'channels_last'.
'channels_first' is typically faster on GPUs while 'channels_last' is
typically faster on CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats
"""
super(MNISTModel, self).__init__(name='')
if data_format == 'channels_first':
self._input_shape = [-1, 1, 28, 28]
else:
assert data_format == 'channels_last'
self._input_shape = [-1, 28, 28, 1]
self.conv1 = self.track_layer(
tf.layers.Conv2D(32, 5, data_format=data_format, activation=tf.nn.relu))
self.conv2 = self.track_layer(
tf.layers.Conv2D(64, 5, data_format=data_format, activation=tf.nn.relu))
self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.nn.relu))
self.fc2 = self.track_layer(tf.layers.Dense(10))
self.dropout = self.track_layer(tf.layers.Dropout(0.5))
self.max_pool2d = self.track_layer(
tf.layers.MaxPooling2D(
(2, 2), (2, 2), padding='SAME', data_format=data_format))
def call(self, inputs, training):
"""Computes labels from inputs.
Users should invoke __call__ to run the network, which delegates to this
method (and not call this method directly).
Args:
inputs: A batch of images as a Tensor with shape [batch_size, 784].
training: True if invoked in the context of training (causing dropout to
be applied). False otherwise.
Returns:
A Tensor with shape [batch_size, 10] containing the predicted logits
for each image in the batch, for each of the 10 classes.
"""
x = tf.reshape(inputs, self._input_shape)
x = self.conv1(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.max_pool2d(x)
x = tf.layers.flatten(x)
x = self.fc1(x)
if training:
x = self.dropout(x)
x = self.fc2(x)
return x
def loss(predictions, labels):
return tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(
logits=predictions, labels=labels))
def compute_accuracy(predictions, labels):
return tf.reduce_sum(
tf.cast(
tf.equal(
tf.argmax(predictions, axis=1,
output_type=tf.int64),
tf.argmax(labels, axis=1,
output_type=tf.int64)),
dtype=tf.float32)) / float(predictions.shape[0].value)
def train_one_epoch(model, optimizer, dataset, log_interval=None):
"""Trains model on `dataset` using `optimizer`."""
tf.train.get_or_create_global_step()
def model_loss(labels, images):
prediction = model(images, training=True)
loss_value = loss(prediction, labels)
tf.contrib.summary.scalar('loss', loss_value)
tf.contrib.summary.scalar('accuracy',
compute_accuracy(prediction, labels))
return loss_value
for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
with tf.contrib.summary.record_summaries_every_n_global_steps(10):
batch_model_loss = functools.partial(model_loss, labels, images)
optimizer.minimize(
batch_model_loss, global_step=tf.train.get_global_step())
if log_interval and batch % log_interval == 0:
print('Batch #%d\tLoss: %.6f' % (batch, batch_model_loss()))
def test(model, dataset):
"""Perform an evaluation of `model` on the examples from `dataset`."""
avg_loss = tfe.metrics.Mean('loss')
accuracy = tfe.metrics.Accuracy('accuracy')
for (images, labels) in tfe.Iterator(dataset):
predictions = model(images, training=False)
avg_loss(loss(predictions, labels))
accuracy(tf.argmax(predictions, axis=1, output_type=tf.int64),
tf.argmax(labels, axis=1, output_type=tf.int64))
print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' %
(avg_loss.result(), 100 * accuracy.result()))
with tf.contrib.summary.always_record_summaries():
tf.contrib.summary.scalar('loss', avg_loss.result())
tf.contrib.summary.scalar('accuracy', accuracy.result())
def load_data(data_dir):
"""Returns training and test tf.data.Dataset objects."""
data = input_data.read_data_sets(data_dir, one_hot=True)
train_ds = tf.data.Dataset.from_tensor_slices((data.train.images,
data.train.labels))
test_ds = tf.data.Dataset.from_tensors((data.test.images, data.test.labels))
return (train_ds, test_ds)
def main(_):
tfe.enable_eager_execution()
(device, data_format) = ('/gpu:0', 'channels_first')
if FLAGS.no_gpu or tfe.num_gpus() <= 0:
(device, data_format) = ('/cpu:0', 'channels_last')
print('Using device %s, and data format %s.' % (device, data_format))
# Load the datasets
(train_ds, test_ds) = load_data(FLAGS.data_dir)
train_ds = train_ds.shuffle(60000).batch(FLAGS.batch_size)
# Create the model and optimizer
model = MNISTModel(data_format)
optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)
if FLAGS.output_dir:
train_dir = os.path.join(FLAGS.output_dir, 'train')
test_dir = os.path.join(FLAGS.output_dir, 'eval')
tf.gfile.MakeDirs(FLAGS.output_dir)
else:
train_dir = None
test_dir = None
summary_writer = tf.contrib.summary.create_summary_file_writer(
train_dir, flush_secs=10)
test_summary_writer = tf.contrib.summary.create_summary_file_writer(
test_dir, flush_secs=10, name='test')
checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
with tf.device(device):
for epoch in range(1, 11):
with tfe.restore_variables_on_create(
tf.train.latest_checkpoint(FLAGS.checkpoint_dir)):
global_step = tf.train.get_or_create_global_step()
start = time.time()
with summary_writer.as_default():
train_one_epoch(model, optimizer, train_ds, FLAGS.log_interval)
end = time.time()
print('\nTrain time for epoch #%d (global step %d): %f' % (
epoch, global_step.numpy(), end - start))
with test_summary_writer.as_default():
test(model, test_ds)
all_variables = (
model.variables
+ tfe.get_optimizer_variables(optimizer)
+ [global_step])
tfe.Saver(all_variables).save(
checkpoint_prefix, global_step=global_step)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--data-dir',
type=str,
default='/tmp/tensorflow/mnist/input_data',
help='Directory for storing input data')
parser.add_argument(
'--batch-size',
type=int,
default=64,
metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument(
'--log-interval',
type=int,
default=10,
metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument(
'--output_dir',
type=str,
default=None,
metavar='N',
help='Directory to write TensorBoard summaries')
parser.add_argument(
'--checkpoint_dir',
type=str,
default='/tmp/tensorflow/mnist/checkpoints/',
metavar='N',
help='Directory to save checkpoints in (once per epoch)')
parser.add_argument(
'--lr',
type=float,
default=0.01,
metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument(
'--momentum',
type=float,
default=0.5,
metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument(
'--no-gpu',
action='store_true',
default=False,
help='disables GPU usage even if a GPU is available')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,65 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.mnist import mnist
def data_format():
return "channels_first" if tf.test.is_gpu_available() else "channels_last"
class MNISTGraphTest(tf.test.TestCase):
def testTrainGraph(self):
# The MNISTModel class can be executed eagerly (as in mnist.py and
# mnist_test.py) and also be used to construct a TensorFlow graph, which is
# then trained in a session.
with tf.Graph().as_default():
# Generate some random data.
batch_size = 64
images = np.random.randn(batch_size, 784).astype(np.float32)
digits = np.random.randint(low=0, high=10, size=batch_size)
labels = np.zeros((batch_size, 10))
labels[np.arange(batch_size), digits] = 1.
# Create a model, optimizer, and dataset as would be done
# for eager execution as well.
model = mnist.MNISTModel(data_format())
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
dataset = tf.data.Dataset.from_tensors((images, labels))
# Define the loss tensor (as opposed to a loss function when
# using eager execution).
(images, labels) = dataset.make_one_shot_iterator().get_next()
predictions = model(images, training=True)
loss = mnist.loss(predictions, labels)
train_op = optimizer.minimize(loss)
init = tf.global_variables_initializer()
with tf.Session() as sess:
# Variables have to be initialized in the session.
sess.run(init)
# Train using the optimizer.
sess.run(train_op)
if __name__ == "__main__":
tf.test.main()

View File

@ -0,0 +1,62 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.mnist import mnist
def device():
return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0"
def data_format():
return "channels_first" if tfe.num_gpus() else "channels_last"
def random_dataset():
batch_size = 64
images = tf.random_normal([batch_size, 784])
digits = tf.random_uniform([batch_size], minval=0, maxval=10, dtype=tf.int32)
labels = tf.one_hot(digits, 10)
return tf.data.Dataset.from_tensors((images, labels))
class MNISTTest(tf.test.TestCase):
def testTrainOneEpoch(self):
model = mnist.MNISTModel(data_format())
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
dataset = random_dataset()
with tf.device(device()):
tf.train.get_or_create_global_step()
mnist.train_one_epoch(model, optimizer, dataset)
def testTest(self):
model = mnist.MNISTModel(data_format())
dataset = random_dataset()
with tf.device(device()):
tf.train.get_or_create_global_step()
mnist.test(model, dataset)
if __name__ == "__main__":
tfe.enable_eager_execution()
tf.test.main()

View File

@ -0,0 +1,529 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "U9i2Dsh-ziXr"
},
"source": [
"# Eager Execution Tutorial: Basics\n",
"\n",
"This notebook introduces the basics of using TensorFlow's eager execution capabilities. It covers concepts such as:\n",
"\n",
"* Importing required packages\n",
"* Enabling eager execution\n",
"* Creating and using TensorFlow Tensors and Variables\n",
"* Using TensorFlow interactively\n",
"* Using GPUs with eager execution enabled\n",
"\n",
"This notebook does *not* cover modeling topics, such as gradients."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "z1JcS5iBXMRO"
},
"source": [
"# Step 1: Import Eager\n",
"\n",
"The key imports for eager execution are the following:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "RlIWhyeLoYnG"
},
"outputs": [],
"source": [
"# Import TensorFlow.\n",
"import tensorflow as tf\n",
"\n",
"# Import TensorFlow eager execution support (subject to future changes).\n",
"import tensorflow.contrib.eager as tfe"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "H9UySOPLXdaw"
},
"source": [
"# Step 2: Enable eager execution\n",
"\n",
"All future TensorFlow calls will execute the\n",
"underlying TensorFlow ops immediately:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "WPTUfGq6kJ5w"
},
"outputs": [],
"source": [
"tfe.enable_eager_execution()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "twBfWd5xyu_d"
},
"source": [
"# Step 3: Interactively Use TensorFlow!\n",
"\n",
"Now you can call TensorFlow functions and get results, immediately! No more `tf.Sessions`!\n",
"\n",
"TensorFlow will automatically wrap native Python types for you with operator overloading for TensorFlow Tensors."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "ngUe237Wt48W"
},
"outputs": [],
"source": [
"print(tf.add(1, 2))\n",
"print(tf.add([1, 2], [3, 4]))\n",
"print(tf.square(5))\n",
"print(tf.reduce_sum([1, 2, 3]))\n",
"print(tf.encode_base64(\"hello world\"))\n",
"print(\"\")\n",
"\n",
"x = tf.constant(2)\n",
"y = tf.constant(3)\n",
"print(x * y + 1)\n",
"\n",
"# Most TensorFlow ops are directly usable with eager execution, giving\n",
"# results immediately.\n",
"print(tf.contrib.signal.hamming_window(x * y + 1))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "IDY4WsYRhP81"
},
"source": [
"Numpy arrays are supported, too:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "lCUWzso6mbqR"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"ones = np.ones([3, 3])\n",
"\n",
"print(\"numpy 3x3 matrix of 1s:\")\n",
"print(ones)\n",
"print(\"\")\n",
"\n",
"print(\"Multiplied by 42:\")\n",
"print(tf.multiply(ones, 42))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PBNP8yTRfu_X"
},
"source": [
"# Step 4: Define and Print TensorFlow Variables\n",
"\n",
"To define TensorFlow variables, use the `get_variable()` function as follows:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "3Twf_Rw-gQFM"
},
"outputs": [],
"source": [
"x = tf.get_variable(name=\"x\", shape=[], dtype=tf.float32, initializer=tf.zeros_initializer)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "45G7094TxsMb"
},
"source": [
"## Printing TensorFlow Variables"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "UJBJeZ5XxuwA"
},
"outputs": [],
"source": [
"# This does NOT print the Variable's actual value:\n",
"print(\"Printing a TensorFlow Variable:\")\n",
"print(x)\n",
"print(\"\")\n",
"\n",
"# A TensorFlow variable represents a reference to a tensor.\n",
"# The `read_value()` method provides access to the current value of the\n",
"# variable. Tensorflow Variables are automatically initialized according to the\n",
"# semantics defined in tf.get_variable().\n",
"print(\"Printing a TensorFlow Variable's value using .read_value():\")\n",
"print(x.read_value())\n",
"print(\"\")\n",
"\n",
"print(\"Printing a TensorFlow Variable's value using .read_value().numpy():\")\n",
"print(x.read_value().numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "2njjWHcTpBEn"
},
"source": [
"## Changing a TensorFlow Variable's value\n",
"\n",
"To change a TensorFlow Variable's value, use its `.assign()` or `.assign_add()` method:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "v3wr6Erbo_hB"
},
"outputs": [],
"source": [
"x.assign(42)\n",
"print(x.read_value())\n",
"\n",
"x.assign_add(3)\n",
"print(x.read_value())"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "uhtynjHVpTB5"
},
"source": [
"## Use a Variable just like any other Tensor"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "7PbktdnHoehR"
},
"outputs": [],
"source": [
"print(x + 3)\n",
"\n",
"# This code will broadcast the value across the list of numbers:\n",
"print(x * [1, 2, 4])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "GVChqwlwy1SI"
},
"source": [
"# Step 5: Debug Errors with Instant Feedback\n",
"\n",
"TensorFlow's eager execution helps you identify and debug runtime issues through interactive exploration of code snippets.\n",
"\n",
"Below, we'll define a length-4 vector, and attempt two `tf.slice()` operations,\n",
"one being legal and the other being illegal, leading to a runtime error that is\n",
"raised immediately."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "23ap04N0v4k0"
},
"outputs": [],
"source": [
"vector = tf.constant([10.0, 20.0, 30.0, 40.0])"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "FCUMsIYxxRRa"
},
"outputs": [],
"source": [
"# Works, because the values of `begin` and `size` (the 2nd and 3rd input\n",
"# arguments) are within the bound of `vector`.\n",
"print(tf.slice(vector, [1], [3]))"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "T8me2oCNxpFp"
},
"outputs": [],
"source": [
"# The following does NOT work, because the value of `size` (the 3rd\n",
"# argument) causes the indices to go out of the bounds of `vector`. The\n",
"# error is raised immediately.\n",
"try:\n",
" print(tf.slice(vector, [1], [4]))\n",
"except tf.OpError as e:\n",
" print(\"Caught error: %s\" % e)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "irxJhAgar84v"
},
"source": [
"# Step 6: Using the GPU\n",
"\n",
"You can place Tensors on the GPU by calling a Tensor's `.gpu()` method.\n",
"\n",
"The first operation executing on the GPU may be slow as TensorFlow initializes. Subsequent uses will be much faster."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "7J4N9baqaKCL"
},
"outputs": [],
"source": [
"# The example code from here on will work only if your notebook\n",
"# is running on a machine with a functional CUDA GPU. The following\n",
"# line checks that.\n",
"is_gpu_available = tfe.num_gpus() \u003e 0\n",
"\n",
"# Create some Tensors\n",
"SIZE = 1000\n",
"cpu_tensor = tf.random_normal([SIZE, SIZE])\n",
"\n",
"if is_gpu_available:\n",
" gpu_tensor = cpu_tensor.gpu()"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "4E-2n7VbzY1n"
},
"outputs": [],
"source": [
"# Time a CPU-based matrix multiplication\n",
"\n",
"print(\"Time to conduct matmul on CPU:\")\n",
"%time tf.matmul(cpu_tensor, cpu_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "vbSFW-T5zhZF"
},
"outputs": [],
"source": [
"# Time GPU-based matrix multiplications.\n",
"\n",
"if is_gpu_available:\n",
" # First use of the GPU will be slow:\n",
" print(\"Time to conduct first matmul on GPU:\")\n",
" %time tf.matmul(gpu_tensor, gpu_tensor)\n",
" print()\n",
"\n",
" # Subsequent uses are much faster:\n",
" print(\"Time to conduct second matmul on GPU:\")\n",
" %time tf.matmul(gpu_tensor, gpu_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "E5pIOe3Rz7iW"
},
"outputs": [],
"source": [
"# Second timing demo for GPUs, after it has been used once:\n",
"\n",
"cpu_tensor = tf.random_normal([SIZE, SIZE])\n",
"print(\"Time to conduct CPU matmul:\")\n",
"%time tf.matmul(cpu_tensor, cpu_tensor)\n",
"print()\n",
"\n",
"if is_gpu_available:\n",
" gpu_tensor = cpu_tensor.gpu()\n",
" print(\"Time to conduct GPU matmul:\")\n",
" %time tf.matmul(gpu_tensor, gpu_tensor)"
]
}
],
"metadata": {
"colab": {
"default_view": {},
"name": "Eager Execution Tutorial: Basics",
"provenance": [
{
"file_id": "0B0kLcpwLFwKEVm9XNkFueGk4bTg",
"timestamp": 1504118841551
}
],
"version": "0.3.2",
"views": {}
}
},
"nbformat": 4,
"nbformat_minor": 0
}

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,218 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "U9i2Dsh-ziXr"
},
"source": [
"# Eager Execution Tutorial: Importing Data\n",
"\n",
"This notebook demonstrates the use of the [`tf.contrib.data.Dataset` API](https://www.tensorflow.org/programmers_guide/datasets) to build pipelines to feed data to your program. It covers:\n",
"\n",
"* Creating a `Dataset`.\n",
"* Iteration over a `Dataset` with eager execution enabled.\n",
"\n",
"We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n",
"\n",
"If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly different. You will use a Pythonic `Iterator()` class instead of using `make_one_shot_iterator()` and `get_next()`. As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "z1JcS5iBXMRO"
},
"source": [
"# Setup: Enable eager execution\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "RlIWhyeLoYnG"
},
"outputs": [],
"source": [
"# Import TensorFlow.\n",
"import tensorflow as tf\n",
"\n",
"# Import TensorFlow eager execution support (subject to future changes).\n",
"import tensorflow.contrib.eager as tfe\n",
"\n",
"# Enable eager execution\n",
"tfe.enable_eager_execution()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "H9UySOPLXdaw"
},
"source": [
"# Step 1: Create a source `Dataset`\n",
"\n",
"Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/TFRecordDataset). See the [Programmer's Guide](https://www.google.com/url?sa=D\u0026q=https%3A%2F%2Fwww.tensorflow.org%2Fprogrammers_guide%2Fdatasets%23reading_input_data) for more information."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "WPTUfGq6kJ5w"
},
"outputs": [],
"source": [
"ds_tensors = tf.contrib.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n",
"\n",
"# Create a CSV file\n",
"import tempfile\n",
"_, filename = tempfile.mkstemp()\n",
"with open(filename, 'w') as f:\n",
" f.write(\"\"\"Line 1\n",
"Line 2\n",
"Line 3\n",
" \"\"\")\n",
"ds_file = tf.contrib.data.TextLineDataset(filename)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "twBfWd5xyu_d"
},
"source": [
"# Step 2: Apply transformations\n",
"\n",
"Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.contrib.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset) for details."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "ngUe237Wt48W"
},
"outputs": [],
"source": [
"ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n",
"ds_file = ds_file.batch(2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "IDY4WsYRhP81"
},
"source": [
"# Step 3: Iterate\n",
"\n",
"Use `tfe.Iterator` on the `Dataset` object to get a Python iterator over the contents of the dataset.\n",
"\n",
"If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that this process of iteration is different. Here there are no calls to `Dataset.make_one_shot_iterator()` and no `get_next()` calls."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"height": 153,
"output_extras": [
{
"item_id": 1
}
]
},
"colab_type": "code",
"executionInfo": {
"elapsed": 201,
"status": "ok",
"timestamp": 1505952405928,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 420
},
"id": "lCUWzso6mbqR",
"outputId": "ec027d30-96c6-4ea4-9ee1-ef74ec1ae29a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Elements of ds_tensors:\n",
"tf.Tensor([4 9], shape=(2,), dtype=int32)\n",
"tf.Tensor([16 25], shape=(2,), dtype=int32)\n",
"tf.Tensor([36 1], shape=(2,), dtype=int32)\n",
"\n",
"Elements in ds_file:\n",
"tf.Tensor(['Line 1' 'Line 2'], shape=(2,), dtype=string)\n",
"tf.Tensor(['Line 3' ' '], shape=(2,), dtype=string)\n"
]
}
],
"source": [
"print('Elements of ds_tensors:')\n",
"for x in tfe.Iterator(ds_tensors):\n",
" print(x)\n",
"\n",
"print('\\nElements in ds_file:')\n",
"for x in tfe.Iterator(ds_file):\n",
" print(x)"
]
}
],
"metadata": {
"colab": {
"default_view": {},
"last_runtime": {
"build_target": "",
"kind": "local"
},
"name": "Eager Execution Tutorial: Importing Data",
"provenance": [],
"version": "0.3.2",
"views": {}
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@ -0,0 +1,43 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
name = "resnet50",
srcs = ["resnet50.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/eager/python:tfe",
],
)
cuda_py_test(
name = "resnet50_test",
size = "large",
srcs = ["resnet50_test.py"],
additional_deps = [
":resnet50",
"//tensorflow/contrib/summary:summary_test_util",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow:tensorflow_py",
],
)
cuda_py_test(
name = "resnet50_graph_test",
size = "large",
srcs = ["resnet50_graph_test.py"],
additional_deps = [
":resnet50",
"//tensorflow/contrib/summary:summary_test_util",
"//third_party/py/numpy",
"//tensorflow:tensorflow_py",
],
tags = [
"noasan",
"nomsan",
],
)

View File

@ -0,0 +1,34 @@
Image classification using the ResNet50 model described in
[Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).
Contents:
- `resnet50.py`: Model definition
- `resnet50_test.py`: Sanity unittests and benchmarks for using the model with
eager execution enabled.
- `resnet50_graph_test.py`: Sanity unittests and benchmarks when using the same
model code to construct a TensorFlow graph.
# Benchmarks
Using a synthetic data.
```
# Using eager execution
bazel run -c opt --config=cuda :resnet50_test -- --benchmarks=.
# Using graph execution
bazel run -c opt --config=cuda :resnet50_graph_test -- --benchmarks=.
```
(Or remove the `--config=cuda` flag for running on CPU instead of GPU).
On October 31, 2017, the benchmarks demostrated comparable performance
for eager and graph execution of this particular model when using
a single NVIDIA Titan X (Pascal) GPU on a host with an
Intel Xeon E5-1650 CPU @ 3.50GHz and a batch size of 32.
| Benchmark name | batch size | images/second |
| --------------------------------------- | ------------- | ------------- |
| eager_train_gpu_batch_32_channels_first | 32 | 171 |
| graph_train_gpu_batch_32_channels_first | 32 | 172 |

View File

@ -0,0 +1,324 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet50 model definition compatible with TensorFlow's eager execution.
Reference [Deep Residual Learning for Image
Recognition](https://arxiv.org/abs/1512.03385)
Adapted from tf.keras.applications.ResNet50. A notable difference is that the
model here outputs logits while the Keras model outputs probability.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
import tensorflow.contrib.eager as tfe
class _IdentityBlock(tfe.Network):
"""_IdentityBlock is the block that has no conv layer at shortcut.
Args:
kernel_size: the kernel size of middle conv layer at main path
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
data_format: data_format for the input ('channels_first' or
'channels_last').
"""
def __init__(self, kernel_size, filters, stage, block, data_format):
super(_IdentityBlock, self).__init__(name='')
filters1, filters2, filters3 = filters
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
bn_axis = 1 if data_format == 'channels_first' else 3
self.conv2a = self.track_layer(
tf.layers.Conv2D(
filters1, (1, 1),
name=conv_name_base + '2a',
data_format=data_format))
self.bn2a = self.track_layer(
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a'))
self.conv2b = self.track_layer(
tf.layers.Conv2D(
filters2,
kernel_size,
padding='same',
data_format=data_format,
name=conv_name_base + '2b'))
self.bn2b = self.track_layer(
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b'))
self.conv2c = self.track_layer(
tf.layers.Conv2D(
filters3, (1, 1),
name=conv_name_base + '2c',
data_format=data_format))
self.bn2c = self.track_layer(
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c'))
def call(self, input_tensor, training=False):
x = self.conv2a(input_tensor)
x = self.bn2a(x, training=training)
x = tf.nn.relu(x)
x = self.conv2b(x)
x = self.bn2b(x, training=training)
x = tf.nn.relu(x)
x = self.conv2c(x)
x = self.bn2c(x, training=training)
x += input_tensor
return tf.nn.relu(x)
class _ConvBlock(tfe.Network):
"""_ConvBlock is the block that has a conv layer at shortcut.
Args:
kernel_size: the kernel size of middle conv layer at main path
filters: list of integers, the filterss of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
data_format: data_format for the input ('channels_first' or
'channels_last').
strides: strides for the convolution. Note that from stage 3, the first
conv layer at main path is with strides=(2,2), and the shortcut should
have strides=(2,2) as well.
"""
def __init__(self,
kernel_size,
filters,
stage,
block,
data_format,
strides=(2, 2)):
super(_ConvBlock, self).__init__(name='')
filters1, filters2, filters3 = filters
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
bn_axis = 1 if data_format == 'channels_first' else 3
self.conv2a = self.track_layer(
tf.layers.Conv2D(
filters1, (1, 1),
strides=strides,
name=conv_name_base + '2a',
data_format=data_format))
self.bn2a = self.track_layer(
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a'))
self.conv2b = self.track_layer(
tf.layers.Conv2D(
filters2,
kernel_size,
padding='same',
name=conv_name_base + '2b',
data_format=data_format))
self.bn2b = self.track_layer(
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b'))
self.conv2c = self.track_layer(
tf.layers.Conv2D(
filters3, (1, 1),
name=conv_name_base + '2c',
data_format=data_format))
self.bn2c = self.track_layer(
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c'))
self.conv_shortcut = self.track_layer(
tf.layers.Conv2D(
filters3, (1, 1),
strides=strides,
name=conv_name_base + '1',
data_format=data_format))
self.bn_shortcut = self.track_layer(
tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '1'))
def call(self, input_tensor, training=False):
x = self.conv2a(input_tensor)
x = self.bn2a(x, training=training)
x = tf.nn.relu(x)
x = self.conv2b(x)
x = self.bn2b(x, training=training)
x = tf.nn.relu(x)
x = self.conv2c(x)
x = self.bn2c(x, training=training)
shortcut = self.conv_shortcut(input_tensor)
shortcut = self.bn_shortcut(shortcut, training=training)
x += shortcut
return tf.nn.relu(x)
class ResNet50(tfe.Network):
"""Instantiates the ResNet50 architecture.
Args:
data_format: format for the image. Either 'channels_first' or
'channels_last'. 'channels_first' is typically faster on GPUs while
'channels_last' is typically faster on CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats
name: Prefix applied to names of variables created in the model.
trainable: Is the model trainable? If true, performs backward
and optimization after call() method.
include_top: whether to include the fully-connected layer at the top of the
network.
pooling: Optional pooling mode for feature extraction when `include_top`
is `False`.
- `None` means that the output of the model will be the 4D tensor
output of the last convolutional layer.
- `avg` means that global average pooling will be applied to the output of
the last convolutional layer, and thus the output of the model will be
a 2D tensor.
- `max` means that global max pooling will be applied.
classes: optional number of classes to classify images into, only to be
specified if `include_top` is True.
Raises:
ValueError: in case of invalid argument for data_format.
"""
def __init__(self,
data_format,
name=None,
trainable=True,
include_top=True,
pooling=None,
classes=1000):
super(ResNet50, self).__init__(name='')
valid_channel_values = ('channels_first', 'channels_last')
if data_format not in valid_channel_values:
raise ValueError('Unknown data_format: %s. Valid values: %s' %
(data_format, valid_channel_values))
self.include_top = include_top
def conv_block(filters, stage, block, strides=(2, 2)):
l = _ConvBlock(
3,
filters,
stage=stage,
block=block,
data_format=data_format,
strides=strides)
return self.track_layer(l)
def id_block(filters, stage, block):
l = _IdentityBlock(
3, filters, stage=stage, block=block, data_format=data_format)
return self.track_layer(l)
self.conv1 = self.track_layer(
tf.layers.Conv2D(
64, (7, 7),
strides=(2, 2),
data_format=data_format,
padding='same',
name='conv1'))
bn_axis = 1 if data_format == 'channels_first' else 3
self.bn_conv1 = self.track_layer(
tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1'))
self.max_pool = self.track_layer(
tf.layers.MaxPooling2D((3, 3), strides=(2, 2), data_format=data_format))
self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1))
self.l2b = id_block([64, 64, 256], stage=2, block='b')
self.l2c = id_block([64, 64, 256], stage=2, block='c')
self.l3a = conv_block([128, 128, 512], stage=3, block='a')
self.l3b = id_block([128, 128, 512], stage=3, block='b')
self.l3c = id_block([128, 128, 512], stage=3, block='c')
self.l3d = id_block([128, 128, 512], stage=3, block='d')
self.l4a = conv_block([256, 256, 1024], stage=4, block='a')
self.l4b = id_block([256, 256, 1024], stage=4, block='b')
self.l4c = id_block([256, 256, 1024], stage=4, block='c')
self.l4d = id_block([256, 256, 1024], stage=4, block='d')
self.l4e = id_block([256, 256, 1024], stage=4, block='e')
self.l4f = id_block([256, 256, 1024], stage=4, block='f')
self.l5a = conv_block([512, 512, 2048], stage=5, block='a')
self.l5b = id_block([512, 512, 2048], stage=5, block='b')
self.l5c = id_block([512, 512, 2048], stage=5, block='c')
self.avg_pool = self.track_layer(
tf.layers.AveragePooling2D(
(7, 7), strides=(7, 7), data_format=data_format))
if self.include_top:
self.fc1000 = self.track_layer(
tf.layers.Dense(classes, name='fc1000'))
else:
reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3]
reduction_indices = tf.constant(reduction_indices)
if pooling == 'avg':
self.global_pooling = functools.partial(
tf.reduce_mean,
reduction_indices=reduction_indices,
keep_dims=False)
elif pooling == 'max':
self.global_pooling = functools.partial(
tf.reduce_max, reduction_indices=reduction_indices, keep_dims=False)
else:
self.global_pooling = None
def call(self, input_tensor, training=False):
x = self.conv1(input_tensor)
x = self.bn_conv1(x, training=training)
x = tf.nn.relu(x)
x = self.max_pool(x)
x = self.l2a(x, training=training)
x = self.l2b(x, training=training)
x = self.l2c(x, training=training)
x = self.l3a(x, training=training)
x = self.l3b(x, training=training)
x = self.l3c(x, training=training)
x = self.l3d(x, training=training)
x = self.l4a(x, training=training)
x = self.l4b(x, training=training)
x = self.l4c(x, training=training)
x = self.l4d(x, training=training)
x = self.l4e(x, training=training)
x = self.l4f(x, training=training)
x = self.l5a(x, training=training)
x = self.l5b(x, training=training)
x = self.l5c(x, training=training)
x = self.avg_pool(x)
if self.include_top:
return self.fc1000(tf.layers.flatten(x))
elif self.global_pooling:
return self.global_pooling(x)
else:
return x

View File

@ -0,0 +1,163 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests and benchmarks for ResNet50 under graph execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.resnet50 import resnet50
from tensorflow.contrib.summary import summary_test_util
def data_format():
return 'channels_first' if tf.test.is_gpu_available() else 'channels_last'
def image_shape(batch_size):
if data_format() == 'channels_first':
return [batch_size, 3, 224, 224]
return [batch_size, 224, 224, 3]
def random_batch(batch_size):
images = np.random.rand(*image_shape(batch_size)).astype(np.float32)
num_classes = 1000
labels = np.random.randint(
low=0, high=num_classes, size=[batch_size]).astype(np.int32)
one_hot = np.zeros((batch_size, num_classes)).astype(np.float32)
one_hot[np.arange(batch_size), labels] = 1.
return images, one_hot
class ResNet50GraphTest(tf.test.TestCase):
def testApply(self):
batch_size = 64
with tf.Graph().as_default():
images = tf.placeholder(tf.float32, image_shape(None))
model = resnet50.ResNet50(data_format())
predictions = model(images)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
np_images, _ = random_batch(batch_size)
out = sess.run(predictions, feed_dict={images: np_images})
self.assertAllEqual([64, 1000], out.shape)
def testTrainWithSummary(self):
with tf.Graph().as_default():
images = tf.placeholder(tf.float32, image_shape(None), name='images')
labels = tf.placeholder(tf.float32, [None, 1000], name='labels')
tf.train.get_or_create_global_step()
logdir = tempfile.mkdtemp()
with tf.contrib.summary.always_record_summaries():
with tf.contrib.summary.create_summary_file_writer(
logdir, max_queue=0,
name='t0').as_default():
model = resnet50.ResNet50(data_format())
logits = model(images, training=True)
loss = tf.losses.softmax_cross_entropy(
logits=logits, onehot_labels=labels)
tf.contrib.summary.scalar(name='loss', tensor=loss)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss)
init = tf.global_variables_initializer()
self.assertEqual(321, len(tf.global_variables()))
batch_size = 32
with tf.Session() as sess:
sess.run(init)
sess.run(tf.contrib.summary.summary_writer_initializer_op())
np_images, np_labels = random_batch(batch_size)
sess.run([train_op, tf.contrib.summary.all_summary_ops()],
feed_dict={images: np_images, labels: np_labels})
events = summary_test_util.events_from_file(logdir)
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'loss')
class ResNet50Benchmarks(tf.test.Benchmark):
def _report(self, label, start, num_iters, batch_size):
avg_time = (time.time() - start) / num_iters
dev = 'gpu' if tf.test.is_gpu_available() else 'cpu'
name = 'graph_%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format())
extras = {'examples_per_sec': batch_size / avg_time}
self.report_benchmark(
iters=num_iters, wall_time=avg_time, name=name, extras=extras)
def benchmark_graph_apply(self):
with tf.Graph().as_default():
images = tf.placeholder(tf.float32, image_shape(None))
model = resnet50.ResNet50(data_format())
predictions = model(images)
init = tf.global_variables_initializer()
batch_size = 64
with tf.Session() as sess:
sess.run(init)
np_images, _ = random_batch(batch_size)
num_burn, num_iters = (3, 30)
for _ in range(num_burn):
sess.run(predictions, feed_dict={images: np_images})
start = time.time()
for _ in range(num_iters):
# Comparison with the eager execution benchmark in resnet50_test.py
# isn't entirely fair as the time here includes the cost of copying
# the feeds from CPU memory to GPU.
sess.run(predictions, feed_dict={images: np_images})
self._report('apply', start, num_iters, batch_size)
def benchmark_graph_train(self):
for batch_size in [16, 32, 64]:
with tf.Graph().as_default():
np_images, np_labels = random_batch(batch_size)
dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat()
(images, labels) = dataset.make_one_shot_iterator().get_next()
model = resnet50.ResNet50(data_format())
logits = model(images, training=True)
loss = tf.losses.softmax_cross_entropy(
logits=logits, onehot_labels=labels)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
train_op = optimizer.minimize(loss)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
(num_burn, num_iters) = (5, 10)
for _ in range(num_burn):
sess.run(train_op)
start = time.time()
for _ in range(num_iters):
sess.run(train_op)
self._report('train', start, num_iters, batch_size)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,234 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests and benchmarks for the ResNet50 model, executed eagerly."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gc
import tempfile
import time
import tensorflow as tf
import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.resnet50 import resnet50
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.client import device_lib
def device_and_data_format():
return ('/gpu:0', 'channels_first') if tfe.num_gpus() else ('/cpu:0',
'channels_last')
def random_batch(batch_size):
_, data_format = device_and_data_format()
shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3)
shape = (batch_size,) + shape
num_classes = 1000
images = tf.random_uniform(shape)
labels = tf.random_uniform(
[batch_size], minval=0, maxval=num_classes, dtype=tf.int32)
one_hot = tf.one_hot(labels, num_classes)
return images, one_hot
def train_one_step(model, images, labels, optimizer):
def model_loss():
logits = model(images, training=True)
loss = tf.losses.softmax_cross_entropy(
logits=logits, onehot_labels=labels)
tf.contrib.summary.scalar(name='loss', tensor=loss)
return loss
optimizer.minimize(model_loss)
class ResNet50Test(tf.test.TestCase):
def test_apply(self):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format)
with tf.device(device):
images, _ = random_batch(2)
output = model(images)
self.assertEqual((2, 1000), output.shape)
def test_apply_no_top(self):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False)
with tf.device(device):
images, _ = random_batch(2)
output = model(images)
output_shape = ((2, 2048, 1, 1)
if data_format == 'channels_first' else (2, 1, 1, 2048))
self.assertEqual(output_shape, output.shape)
def test_apply_with_pooling(self):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False, pooling='avg')
with tf.device(device):
images, _ = random_batch(2)
output = model(images)
self.assertEqual((2, 2048), output.shape)
def test_train(self):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format)
tf.train.get_or_create_global_step()
logdir = tempfile.mkdtemp()
with tf.contrib.summary.create_summary_file_writer(
logdir, max_queue=0,
name='t0').as_default(), tf.contrib.summary.always_record_summaries():
with tf.device(device):
optimizer = tf.train.GradientDescentOptimizer(0.1)
images, labels = random_batch(2)
train_one_step(model, images, labels, optimizer)
self.assertEqual(320, len(model.variables))
events = summary_test_util.events_from_file(logdir)
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'loss')
def test_no_garbage(self):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format)
optimizer = tf.train.GradientDescentOptimizer(0.1)
with tf.device(device):
images, labels = random_batch(2)
gc.disable()
# Warm up. Note that this first run does create significant amounts of
# garbage to be collected. The hope is that this is a build-only effect,
# and a subsequent training loop will create nothing which needs to be
# collected.
train_one_step(model, images, labels, optimizer)
gc.collect()
previous_gc_debug_flags = gc.get_debug()
gc.set_debug(gc.DEBUG_SAVEALL)
for _ in range(2):
# Run twice to ensure that garbage that is created on the first
# iteration is no longer accessible.
train_one_step(model, images, labels, optimizer)
gc.collect()
# There should be no garbage requiring collection.
self.assertEqual(0, len(gc.garbage))
gc.set_debug(previous_gc_debug_flags)
gc.enable()
class MockIterator(object):
def __init__(self, tensors):
self._tensors = [tf.identity(x) for x in tensors]
def next(self):
return self._tensors
class ResNet50Benchmarks(tf.test.Benchmark):
def _train_batch_sizes(self):
"""Choose batch sizes based on GPU capability."""
for device in device_lib.list_local_devices():
if 'GPU:0' in device.name:
# Avoid OOM errors with larger batch sizes, which seem to cause errors
# later on even if caught.
#
# TODO(allenl): Base this on device memory; memory limit information
# during the test seems to exclude the amount TensorFlow has allocated,
# which isn't useful.
if 'K20' in device.physical_device_desc:
return (16,)
if 'P100' in device.physical_device_desc:
return (16, 32, 64)
return (16, 32)
def _report(self, label, start, num_iters, device, batch_size, data_format):
avg_time = (time.time() - start) / num_iters
dev = 'cpu' if 'cpu' in device else 'gpu'
name = '%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format)
extras = {'examples_per_sec': batch_size / avg_time}
self.report_benchmark(
iters=num_iters, wall_time=avg_time, name=name, extras=extras)
def _force_gpu_sync(self):
# If this function is called in the context of a GPU device
# (e.g., inside a 'with tf.device("/gpu:0")' block)
# then this will force a copy from CPU->GPU->CPU, which forces
# a sync. This is a roundabout way, yes.
tf.constant(1.).cpu()
def benchmark_eager_apply(self):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format)
batch_size = 64
num_burn = 5
num_iters = 30
with tf.device(device):
images, _ = random_batch(batch_size)
for _ in xrange(num_burn):
model(images).cpu()
gc.collect()
start = time.time()
for _ in xrange(num_iters):
model(images).cpu()
self._report('eager_apply', start, num_iters, device, batch_size,
data_format)
def _benchmark_eager_train(self, label, make_iterator):
device, data_format = device_and_data_format()
for batch_size in self._train_batch_sizes():
(images, labels) = random_batch(batch_size)
num_burn = 3
num_iters = 10
model = resnet50.ResNet50(data_format)
optimizer = tf.train.GradientDescentOptimizer(0.1)
with tf.device(device):
iterator = make_iterator((images, labels))
for _ in xrange(num_burn):
(images, labels) = iterator.next()
train_one_step(model, images, labels, optimizer)
self._force_gpu_sync()
gc.collect()
start = time.time()
for _ in xrange(num_iters):
(images, labels) = iterator.next()
train_one_step(model, images, labels, optimizer)
self._force_gpu_sync()
self._report(label, start, num_iters, device, batch_size, data_format)
def benchmark_eager_train(self):
self._benchmark_eager_train('eager_train', MockIterator)
def benchmark_eager_train_datasets(self):
def make_iterator(tensors):
with tf.device('/device:CPU:0'):
ds = tf.data.Dataset.from_tensors(tensors).repeat()
return tfe.Iterator(ds)
self._benchmark_eager_train('eager_train_dataset', make_iterator)
if __name__ == '__main__':
tfe.enable_eager_execution()
tf.test.main()

View File

@ -0,0 +1,26 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_binary(
name = "rnn_colorbot",
srcs = ["rnn_colorbot.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/eager/python:tfe",
"@six_archive//:six",
],
)
cuda_py_test(
name = "rnn_colorbot_test",
srcs = ["rnn_colorbot_test.py"],
additional_deps = [
":rnn_colorbot",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow:tensorflow_py",
],
)

View File

@ -0,0 +1,26 @@
RNN Colorbot: An RNN that predicts colors using eager execution.
To train and generate colors, run:
```
python rnn_colorbot.py
```
This example shows how to:
1. read, process, (one-hot) encode, and pad text data via the
Datasets API;
2. build a trainable model;
3. implement a multi-layer RNN using Python control flow
constructs (e.g., a for loop);
4. train a model using an iterative gradient-based method; and
5. log training and evaluation loss for consumption by TensorBoard
(to view summaries, use: tensorboard --log_dir=<dir>/summaries).
The data used in this example is licensed under the Creative Commons
Attribution-ShareAlike License and is available at
https://en.wikipedia.org/wiki/List_of_colors:_A-F
https://en.wikipedia.org/wiki/List_of_colors:_G-M
https://en.wikipedia.org/wiki/List_of_colors:_N-Z
This example was adapted from
https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot

View File

@ -0,0 +1,338 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""TensorFlow Eager Execution Example: RNN Colorbot.
This example builds, trains, and evaluates a multi-layer RNN that can be
run with eager execution enabled. The RNN is trained to map color names to
their RGB values: it takes as input a one-hot encoded character sequence and
outputs a three-tuple (R, G, B) (scaled by 1/255).
For example, say we'd like the RNN Colorbot to generate the RGB values for the
color white. To represent our query in a form that the Colorbot could
understand, we would create a sequence of five 256-long vectors encoding the
ASCII values of the characters in "white". The first vector in our sequence
would be 0 everywhere except for the ord("w")-th position, where it would be
1, the second vector would be 0 everywhere except for the
ord("h")-th position, where it would be 1, and similarly for the remaining three
vectors. We refer to such indicator vectors as "one-hot encodings" of
characters. After consuming these vectors, a well-trained Colorbot would output
the three tuple (1, 1, 1), since the RGB values for white are (255, 255, 255).
We are of course free to ask the colorbot to generate colors for any string we'd
like, such as "steel gray," "tensorflow orange," or "green apple," though
your mileage may vary as your queries increase in creativity.
This example shows how to:
1. read, process, (one-hot) encode, and pad text data via the
Datasets API;
2. build a trainable model;
3. implement a multi-layer RNN using Python control flow
constructs (e.g., a for loop);
4. train a model using an iterative gradient-based method; and
The data used in this example is licensed under the Creative Commons
Attribution-ShareAlike License and is available at
https://en.wikipedia.org/wiki/List_of_colors:_A-F
https://en.wikipedia.org/wiki/List_of_colors:_G-M
https://en.wikipedia.org/wiki/List_of_colors:_N-Z
This example was adapted from
https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import functools
import os
import sys
import time
import six
import tensorflow as tf
from tensorflow.contrib.eager.python import tfe
from tensorflow.python.eager import context
try:
import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top
HAS_MATPLOTLIB = True
except ImportError:
HAS_MATPLOTLIB = False
def parse(line):
"""Parse a line from the colors dataset."""
# Each line of the dataset is comma-separated and formatted as
# color_name, r, g, b
# so `items` is a list [color_name, r, g, b].
items = tf.string_split([line], ",").values
rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.
# Represent the color name as a one-hot encoded character sequence.
color_name = items[0]
chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)
# The sequence length is needed by our RNN.
length = tf.cast(tf.shape(chars)[0], dtype=tf.int64)
return rgb, chars, length
def load_dataset(data_dir, url, batch_size):
"""Loads the colors data at path into a PaddedDataset."""
# Downloads data at url into data_dir/basename(url). The dataset has a header
# row (color_name, r, g, b) followed by comma-separated lines.
path = tf.contrib.learn.datasets.base.maybe_download(
os.path.basename(url), data_dir, url)
# This chain of commands loads our data by:
# 1. skipping the header; (.skip(1))
# 2. parsing the subsequent lines; (.map(parse))
# 3. shuffling the data; (.shuffle(...))
# 3. grouping the data into padded batches (.padded_batch(...)).
dataset = tf.data.TextLineDataset(path).skip(1).map(parse).shuffle(
buffer_size=10000).padded_batch(
batch_size, padded_shapes=([None], [None, None], []))
return dataset
# pylint: disable=not-callable
class RNNColorbot(tfe.Network):
"""Multi-layer (LSTM) RNN that regresses on real-valued vector labels.
"""
def __init__(self, rnn_cell_sizes, label_dimension, keep_prob):
"""Constructs an RNNColorbot.
Args:
rnn_cell_sizes: list of integers denoting the size of each LSTM cell in
the RNN; rnn_cell_sizes[i] is the size of the i-th layer cell
label_dimension: the length of the labels on which to regress
keep_prob: (1 - dropout probability); dropout is applied to the outputs of
each LSTM layer
"""
super(RNNColorbot, self).__init__(name="")
self.label_dimension = label_dimension
self.keep_prob = keep_prob
# Note the calls to `track_layer` below; these calls register the layers as
# network components that house trainable variables.
self.cells = [
self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(size))
for size in rnn_cell_sizes
]
self.relu = self.track_layer(
tf.layers.Dense(label_dimension, activation=tf.nn.relu, name="relu"))
def call(self, chars, sequence_length, training=False):
"""Implements the RNN logic and prediction generation.
Args:
chars: a Tensor of dimension [batch_size, time_steps, 256] holding a
batch of one-hot encoded color names
sequence_length: a Tensor of dimension [batch_size] holding the length
of each character sequence (i.e., color name)
training: whether the invocation is happening during training
Returns:
A tensor of dimension [batch_size, label_dimension] that is produced by
passing chars through a multi-layer RNN and applying a ReLU to the final
hidden state.
"""
# Transpose the first and second dimensions so that chars is of shape
# [time_steps, batch_size, dimension].
chars = tf.transpose(chars, [1, 0, 2])
# The outer loop cycles through the layers of the RNN; the inner loop
# executes the time steps for a particular layer.
batch_size = int(chars.shape[1])
for l in range(len(self.cells)):
cell = self.cells[l]
outputs = []
state = cell.zero_state(batch_size, tf.float32)
# Unstack the inputs to obtain a list of batches, one for each time step.
chars = tf.unstack(chars, axis=0)
for ch in chars:
output, state = cell(ch, state)
outputs.append(output)
# The outputs of this layer are the inputs of the subsequent layer.
chars = tf.stack(outputs, axis=0)
if training:
chars = tf.nn.dropout(chars, self.keep_prob)
# Extract the correct output (i.e., hidden state) for each example. All the
# character sequences in this batch were padded to the same fixed length so
# that they could be easily fed through the above RNN loop. The
# `sequence_length` vector tells us the true lengths of the character
# sequences, letting us obtain for each sequence the hidden state that was
# generated by its non-padding characters.
batch_range = [i for i in range(batch_size)]
indices = tf.stack([sequence_length - 1, batch_range], axis=1)
hidden_states = tf.gather_nd(chars, indices)
return self.relu(hidden_states)
def loss(labels, predictions):
"""Computes mean squared loss."""
return tf.reduce_mean(tf.square(predictions - labels))
def test(model, eval_data):
"""Computes the average loss on eval_data, which should be a Dataset."""
avg_loss = tfe.metrics.Mean("loss")
for (labels, chars, sequence_length) in tfe.Iterator(eval_data):
predictions = model(chars, sequence_length, training=False)
avg_loss(loss(labels, predictions))
print("eval/loss: %.6f\n" % avg_loss.result())
with tf.contrib.summary.always_record_summaries():
tf.contrib.summary.scalar("loss", avg_loss.result())
def train_one_epoch(model, optimizer, train_data, log_interval=10):
"""Trains model on train_data using optimizer."""
tf.train.get_or_create_global_step()
def model_loss(labels, chars, sequence_length):
predictions = model(chars, sequence_length, training=True)
loss_value = loss(labels, predictions)
tf.contrib.summary.scalar("loss", loss_value)
return loss_value
for (batch, (labels, chars, sequence_length)) in enumerate(
tfe.Iterator(train_data)):
with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval):
batch_model_loss = functools.partial(model_loss, labels, chars,
sequence_length)
optimizer.minimize(
batch_model_loss, global_step=tf.train.get_global_step())
if log_interval and batch % log_interval == 0:
print("train/batch #%d\tloss: %.6f" % (batch, batch_model_loss()))
SOURCE_TRAIN_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv"
SOURCE_TEST_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv"
def main(_):
data_dir = os.path.join(FLAGS.dir, "data")
train_data = load_dataset(
data_dir=data_dir, url=SOURCE_TRAIN_URL, batch_size=FLAGS.batch_size)
eval_data = load_dataset(
data_dir=data_dir, url=SOURCE_TEST_URL, batch_size=FLAGS.batch_size)
model = RNNColorbot(
rnn_cell_sizes=FLAGS.rnn_cell_sizes,
label_dimension=3,
keep_prob=FLAGS.keep_probability)
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
if FLAGS.no_gpu or tfe.num_gpus() <= 0:
print(tfe.num_gpus())
device = "/cpu:0"
else:
device = "/gpu:0"
print("Using device %s." % device)
log_dir = os.path.join(FLAGS.dir, "summaries")
tf.gfile.MakeDirs(log_dir)
train_summary_writer = tf.contrib.summary.create_summary_file_writer(
os.path.join(log_dir, "train"), flush_secs=10)
test_summary_writer = tf.contrib.summary.create_summary_file_writer(
os.path.join(log_dir, "eval"), flush_secs=10, name="eval")
with tf.device(device):
for epoch in range(FLAGS.num_epochs):
start = time.time()
with train_summary_writer.as_default():
train_one_epoch(model, optimizer, train_data, FLAGS.log_interval)
end = time.time()
print("train/time for epoch #%d: %.2f" % (epoch, end - start))
with test_summary_writer.as_default():
test(model, eval_data)
print("Colorbot is ready to generate colors!")
while True:
try:
color_name = six.moves.input(
"Give me a color name (or press enter to exit): ")
except EOFError:
return
if not color_name:
return
_, chars, length = parse(color_name)
with tf.device(device):
(chars, length) = (tf.identity(chars), tf.identity(length))
chars = tf.expand_dims(chars, 0)
length = tf.expand_dims(length, 0)
preds = tf.unstack(model(chars, length, training=False)[0])
# Predictions cannot be negative, as they are generated by a ReLU layer;
# they may, however, be greater than 1.
clipped_preds = tuple(min(float(p), 1.0) for p in preds)
rgb = tuple(int(p * 255) for p in clipped_preds)
print("rgb:", rgb)
data = [[clipped_preds]]
if HAS_MATPLOTLIB:
plt.imshow(data)
plt.title(color_name)
plt.show()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dir",
type=str,
default="/tmp/rnn_colorbot/",
help="Directory to download data files and save logs.")
parser.add_argument(
"--log_interval",
type=int,
default=10,
metavar="N",
help="Log training loss every log_interval batches.")
parser.add_argument(
"--num_epochs", type=int, default=20, help="Number of epochs to train.")
parser.add_argument(
"--rnn_cell_sizes",
type=int,
nargs="+",
default=[256, 128],
help="List of sizes for each layer of the RNN.")
parser.add_argument(
"--batch_size",
type=int,
default=64,
help="Batch size for training and eval.")
parser.add_argument(
"--keep_probability",
type=float,
default=0.5,
help="Keep probability for dropout between layers.")
parser.add_argument(
"--learning_rate",
type=float,
default=0.01,
help="Learning rate to be used during training.")
parser.add_argument(
"--no_gpu",
action="store_true",
default=False,
help="Disables GPU usage even if a GPU is available.")
FLAGS, unparsed = parser.parse_known_args()
tfe.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,71 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.eager.python import tfe
from tensorflow.contrib.eager.python.examples.rnn_colorbot import rnn_colorbot
LABEL_DIMENSION = 5
def device():
return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0"
def random_dataset():
batch_size = 64
time_steps = 10
alphabet = 50
chars = tf.one_hot(
tf.random_uniform(
[batch_size, time_steps], minval=0, maxval=alphabet, dtype=tf.int32),
alphabet)
sequence_length = tf.constant(
[time_steps for _ in range(batch_size)], dtype=tf.int64)
labels = tf.random_normal([batch_size, LABEL_DIMENSION])
return tf.data.Dataset.from_tensors((labels, chars, sequence_length))
class RNNColorbotTest(tf.test.TestCase):
def testTrainOneEpoch(self):
model = rnn_colorbot.RNNColorbot(
rnn_cell_sizes=[256, 128, 64],
label_dimension=LABEL_DIMENSION,
keep_prob=1.0)
optimizer = tf.train.AdamOptimizer(learning_rate=.01)
dataset = random_dataset()
with tf.device(device()):
rnn_colorbot.train_one_epoch(model, optimizer, dataset)
def testTest(self):
model = rnn_colorbot.RNNColorbot(
rnn_cell_sizes=[256],
label_dimension=LABEL_DIMENSION,
keep_prob=1.0)
dataset = random_dataset()
with tf.device(device()):
rnn_colorbot.test(model, dataset)
if __name__ == "__main__":
tfe.enable_eager_execution()
tf.test.main()

View File

@ -0,0 +1,35 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_binary(
name = "rnn_ptb",
srcs = ["rnn_ptb.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/eager/python:tfe",
],
)
cuda_py_test(
name = "rnn_ptb_test",
srcs = ["rnn_ptb_test.py"],
additional_deps = [
":rnn_ptb",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow:tensorflow_py",
],
)
cuda_py_test(
name = "rnn_ptb_graph_test",
srcs = ["rnn_ptb_graph_test.py"],
additional_deps = [
":rnn_ptb",
"//third_party/py/numpy",
"//tensorflow:tensorflow_py",
],
)

View File

@ -0,0 +1,42 @@
Recurrent Neural Network model.
Implements a language modeling network described in
https://www.tensorflow.org/tutorials/recurrent
that is compatible with (and idiomatic for) eager execution.
To run:
- Download and extract the Penn Treebank dataset from
http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
```sh
tar xvzf simple-examples.tgz -C /tmp
```
- Run: `python rnn_ptb.py --data-dir=/tmp/simple-examples/data`
Benchmarks (using synthetic data):
```
# Using eager execution
bazel run -c opt --config=cuda :rnn_ptb_test -- --benchmarks=.
# Using graph execution
bazel run -c opt --config=cuda :rnn_ptb_graph_test -- --benchmarks=.
```
(Or remove the `--config=cuda` flag for running on CPU instead of GPU).
On October 31, 2017, the benchmarks demostrated slightly better performance
(3-6%) for graph execution over eager execution for this particular model when
using a single NVIDIA Titan X (Pascal) GPU on a host with an Intel Xeon E5-1650
CPU @ 3.50GHz and a batch size of 32.
| Benchmark name | examples/second |
| ------------------------------------ | --------------- |
| eager_cudnn_train_large_gpu_batch_20 | 938 |
| graph_cudnn_train_large_gpu_batch_20 | 971 |
| eager_cudnn_train_small_gpu_batch_20 | 2433 |
| graph_cudnn_train_small_gpu_batch_20 | 2585 |

View File

@ -0,0 +1,348 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Penn Treebank RNN model definition compatible with eager execution.
Model similar to
https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb
Usage: python ./rnn_ptb.py --data-path=<path_to_dataset>
Penn Treebank (PTB) dataset from:
http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
"""
import argparse
import os
import sys
import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
from tensorflow.contrib.eager.python import tfe
class RNN(tfe.Network):
"""A static RNN.
Similar to tf.nn.static_rnn, implemented as a tf.layer.Layer.
"""
def __init__(self, hidden_dim, num_layers, keep_ratio):
super(RNN, self).__init__()
self.keep_ratio = keep_ratio
for _ in range(num_layers):
self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim))
def call(self, input_seq, training):
batch_size = int(input_seq.shape[1])
for c in self.layers:
state = c.zero_state(batch_size, tf.float32)
outputs = []
input_seq = tf.unstack(input_seq, num=int(input_seq.shape[0]), axis=0)
for inp in input_seq:
output, state = c(inp, state)
outputs.append(output)
input_seq = tf.stack(outputs, axis=0)
if training:
input_seq = tf.nn.dropout(input_seq, self.keep_ratio)
return input_seq, None
class Embedding(tf.layers.Layer):
"""An Embedding layer."""
def __init__(self, vocab_size, embedding_dim, **kwargs):
super(Embedding, self).__init__(**kwargs)
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
def build(self, _):
self.embedding = self.add_variable(
"embedding_kernel",
shape=[self.vocab_size, self.embedding_dim],
dtype=tf.float32,
initializer=tf.random_uniform_initializer(-0.1, 0.1),
trainable=True)
def call(self, x):
return tf.nn.embedding_lookup(self.embedding, x)
class PTBModel(tfe.Network):
"""LSTM for word language modelling.
Model described in:
(Zaremba, et. al.) Recurrent Neural Network Regularization
http://arxiv.org/abs/1409.2329
See also:
https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb
"""
def __init__(self,
vocab_size,
embedding_dim,
hidden_dim,
num_layers,
dropout_ratio,
use_cudnn_rnn=True):
super(PTBModel, self).__init__()
self.keep_ratio = 1 - dropout_ratio
self.use_cudnn_rnn = use_cudnn_rnn
self.embedding = self.track_layer(Embedding(vocab_size, embedding_dim))
if self.use_cudnn_rnn:
self.rnn = cudnn_rnn.CudnnLSTM(
num_layers, hidden_dim, dropout=dropout_ratio)
else:
self.rnn = RNN(hidden_dim, num_layers, self.keep_ratio)
self.track_layer(self.rnn)
self.linear = self.track_layer(
tf.layers.Dense(
vocab_size,
kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1)))
self._output_shape = [-1, embedding_dim]
def call(self, input_seq, training):
"""Run the forward pass of PTBModel.
Args:
input_seq: [length, batch] shape int64 tensor.
training: Is this a training call.
Returns:
outputs tensors of inference.
"""
y = self.embedding(input_seq)
if training:
y = tf.nn.dropout(y, self.keep_ratio)
y, _ = self.rnn(y, training=training)
return self.linear(tf.reshape(y, self._output_shape))
def clip_gradients(grads_and_vars, clip_ratio):
gradients, variables = zip(*grads_and_vars)
clipped, _ = tf.clip_by_global_norm(gradients, clip_ratio)
return zip(clipped, variables)
def loss_fn(model, inputs, targets, training):
labels = tf.reshape(targets, [-1])
outputs = model(inputs, training)
return tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=outputs))
def _divide_into_batches(data, batch_size):
"""Convert a sequence to a batch of sequences."""
nbatch = data.shape[0] // batch_size
data = data[:nbatch * batch_size]
data = data.reshape(batch_size, -1).transpose()
return data
def _get_batch(data, i, seq_len):
slen = min(seq_len, data.shape[0] - 1 - i)
inputs = data[i:i + slen, :]
target = data[i + 1:i + 1 + slen, :]
return tf.constant(inputs), tf.constant(target)
def evaluate(model, data):
"""evaluate an epoch."""
total_loss = 0.0
total_batches = 0
start = time.time()
for _, i in enumerate(range(0, data.shape[0] - 1, FLAGS.seq_len)):
inp, target = _get_batch(data, i, FLAGS.seq_len)
loss = loss_fn(model, inp, target, training=False)
total_loss += loss.numpy()
total_batches += 1
time_in_ms = (time.time() - start) * 1000
sys.stderr.write("eval loss %.2f (eval took %d ms)\n" %
(total_loss / total_batches, time_in_ms))
return total_loss
def train(model, optimizer, train_data, sequence_length, clip_ratio):
"""training an epoch."""
def model_loss(inputs, targets):
return loss_fn(model, inputs, targets, training=True)
grads = tfe.implicit_gradients(model_loss)
total_time = 0
for batch, i in enumerate(range(0, train_data.shape[0] - 1, sequence_length)):
train_seq, train_target = _get_batch(train_data, i, sequence_length)
start = time.time()
optimizer.apply_gradients(
clip_gradients(grads(train_seq, train_target), clip_ratio))
total_time += (time.time() - start)
if batch % 10 == 0:
time_in_ms = (total_time * 1000) / (batch + 1)
sys.stderr.write("batch %d: training loss %.2f, avg step time %d ms\n" %
(batch, model_loss(train_seq, train_target).numpy(),
time_in_ms))
class Datasets(object):
"""Processed form of the Penn Treebank dataset."""
def __init__(self, path):
"""Load the Penn Treebank dataset.
Args:
path: Path to the data/ directory of the dataset from from Tomas Mikolov's
webpage - http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
"""
self.word2idx = {} # string -> integer id
self.idx2word = [] # integer id -> word string
# Files represented as a list of integer ids (as opposed to list of string
# words).
self.train = self.tokenize(os.path.join(path, "ptb.train.txt"))
self.valid = self.tokenize(os.path.join(path, "ptb.valid.txt"))
def vocab_size(self):
return len(self.idx2word)
def add(self, word):
if word not in self.word2idx:
self.idx2word.append(word)
self.word2idx[word] = len(self.idx2word) - 1
def tokenize(self, path):
"""Read text file in path and return a list of integer token ids."""
tokens = 0
with tf.gfile.Open(path, "r") as f:
for line in f:
words = line.split() + ["<eos>"]
tokens += len(words)
for word in words:
self.add(word)
# Tokenize file content
with tf.gfile.Open(path, "r") as f:
ids = np.zeros(tokens).astype(np.int64)
token = 0
for line in f:
words = line.split() + ["<eos>"]
for word in words:
ids[token] = self.word2idx[word]
token += 1
return ids
def small_model(use_cudnn_rnn):
"""Returns a PTBModel with a 'small' configuration."""
return PTBModel(
vocab_size=10000,
embedding_dim=200,
hidden_dim=200,
num_layers=2,
dropout_ratio=0.,
use_cudnn_rnn=use_cudnn_rnn)
def large_model(use_cudnn_rnn):
"""Returns a PTBModel with a 'large' configuration."""
return PTBModel(
vocab_size=10000,
embedding_dim=650,
hidden_dim=650,
num_layers=2,
dropout_ratio=0.5,
use_cudnn_rnn=use_cudnn_rnn)
def main(_):
tfe.enable_eager_execution()
if not FLAGS.data_path:
raise ValueError("Must specify --data-path")
corpus = Datasets(FLAGS.data_path)
train_data = _divide_into_batches(corpus.train, FLAGS.batch_size)
eval_data = _divide_into_batches(corpus.valid, 10)
have_gpu = tfe.num_gpus() > 0
use_cudnn_rnn = not FLAGS.no_use_cudnn_rnn and have_gpu
with tfe.restore_variables_on_create(
tf.train.latest_checkpoint(FLAGS.logdir)):
with tf.device("/device:GPU:0" if have_gpu else None):
# Make learning_rate a Variable so it can be included in the checkpoint
# and we can resume training with the last saved learning_rate.
learning_rate = tfe.Variable(20.0, name="learning_rate")
sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy())
model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim,
FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout,
use_cudnn_rnn)
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
best_loss = None
for _ in range(FLAGS.epoch):
train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip)
eval_loss = evaluate(model, eval_data)
if not best_loss or eval_loss < best_loss:
if FLAGS.logdir:
tfe.Saver(model.trainable_weights + [learning_rate]).save(
os.path.join(FLAGS.logdir, "ckpt"))
best_loss = eval_loss
else:
learning_rate.assign(learning_rate / 4.0)
sys.stderr.write("eval_loss did not reduce in this epoch, "
"changing learning rate to %f for the next epoch\n" %
learning_rate.numpy())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data-path",
type=str,
default="",
help="Data directory of the Penn Treebank dataset from "
"http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz")
parser.add_argument(
"--logdir", type=str, default="", help="Directory for checkpoint.")
parser.add_argument(
"--epoch", type=int, default=20, help="Number of epoches.")
parser.add_argument("--batch-size", type=int, default=20, help="Batch size.")
parser.add_argument(
"--seq-len", type=int, default=35, help="Sequence length.")
parser.add_argument(
"--embedding-dim", type=int, default=200, help="Embedding dimension.")
parser.add_argument(
"--hidden-dim", type=int, default=200, help="Hidden layer dimension.")
parser.add_argument(
"--num-layers", type=int, default=2, help="Number of RNN layers.")
parser.add_argument(
"--dropout", type=float, default=0.2, help="Drop out ratio.")
parser.add_argument(
"--clip", type=float, default=0.25, help="Gradient clipping ratio.")
parser.add_argument(
"--no-use-cudnn-rnn",
action="store_true",
default=False,
help="Disable the fast CuDNN RNN (when no gpu)")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,164 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for PTBModel used for graph construction."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gc
import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.rnn_ptb import rnn_ptb
class PTBTest(tf.test.TestCase):
def testTrain(self):
batch_size = 20
sequence_length = 35
with tf.Graph().as_default(), tf.device(tf.test.gpu_device_name()):
inputs_ph = tf.placeholder(tf.int64, [sequence_length, batch_size],
"inputs")
labels_ph = tf.placeholder(tf.int64, [sequence_length, batch_size],
"labels")
inputs = np.ones(inputs_ph.shape.as_list(), dtype=np.int64)
labels = np.ones(labels_ph.shape.as_list(), dtype=np.int64)
model = rnn_ptb.small_model(tf.test.is_gpu_available())
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
loss = rnn_ptb.loss_fn(model, inputs_ph, labels_ph, training=True)
grads = rnn_ptb.clip_gradients(optimizer.compute_gradients(loss), 0.25)
train_op = optimizer.apply_gradients(grads)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(train_op, feed_dict={inputs_ph: inputs, labels_ph: labels})
sess.run(
[train_op, loss], feed_dict={
inputs_ph: inputs,
labels_ph: labels
})
class PTBBenchmark(tf.test.Benchmark):
BATCH_SIZE = 20
SEQ_LEN = 35
def _report(self, label, start, num_iters, device, batch_size):
wall_time = (time.time() - start) / num_iters
dev = "cpu" if "cpu" in device.lower() else "gpu"
name = "%s_%s_batch_%d" % (label, dev, batch_size)
examples_per_sec = batch_size / wall_time
self.report_benchmark(
iters=num_iters,
wall_time=wall_time,
name=name,
extras={
"examples_per_sec": examples_per_sec
})
def _benchmark_apply(self, label, model):
num_iters = 100
num_warmup = 10
dataset = tf.data.Dataset.from_tensors(
tf.ones(
[PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE],
dtype=tf.int64)).repeat(num_iters + num_warmup)
inputs = dataset.make_one_shot_iterator().get_next()
with tf.device(tf.test.gpu_device_name()):
outputs = model(inputs, training=True)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in range(num_warmup):
sess.run(outputs)
gc.collect()
start = time.time()
for _ in range(num_iters):
sess.run(outputs)
self._report(label, start, num_iters,
tf.test.gpu_device_name(), PTBBenchmark.BATCH_SIZE)
def benchmark_apply_small(self):
self._benchmark_apply("graph_apply_small", rnn_ptb.small_model(False))
def benchmark_apply_large(self):
self._benchmark_apply("graph_apply_large", rnn_ptb.large_model(False))
def benchmark_cudnn_apply_small(self):
if not tf.test.is_gpu_available():
return
self._benchmark_apply("graph_cudnn_apply_small", rnn_ptb.small_model(True))
def benchmark_cudnn_apply_large(self):
if not tf.test.is_gpu_available():
return
self._benchmark_apply("graph_cudnn_apply_large", rnn_ptb.large_model(True))
def _benchmark_train(self, label, model):
num_iters = 100
num_warmup = 10
dataset = tf.data.Dataset.from_tensors(
tf.ones(
[PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE],
dtype=tf.int64)).repeat(num_iters + num_warmup)
# inputs and labels have the same shape
dataset = tf.data.Dataset.zip((dataset, dataset))
(inputs, labels) = dataset.make_one_shot_iterator().get_next()
with tf.device(tf.test.gpu_device_name()):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
loss = rnn_ptb.loss_fn(model, inputs, labels, training=True)
grads = rnn_ptb.clip_gradients(optimizer.compute_gradients(loss), 0.25)
train_op = optimizer.apply_gradients(grads)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in range(num_warmup):
sess.run(train_op)
gc.collect()
start = time.time()
for _ in range(num_iters):
sess.run(train_op)
self._report(label, start, num_iters,
tf.test.gpu_device_name(), PTBBenchmark.BATCH_SIZE)
def benchmark_train_small(self):
self._benchmark_train("graph_train_small", rnn_ptb.small_model(False))
def benchmark_train_large(self):
self._benchmark_train("graph_train_large", rnn_ptb.large_model(False))
def benchmark_cudnn_train_small(self):
if not tf.test.is_gpu_available():
return
self._benchmark_train("graph_cudnn_train_small", rnn_ptb.small_model(True))
def benchmark_cudnn_train_large(self):
if not tf.test.is_gpu_available():
return
self._benchmark_train("graph_cudnn_train_large", rnn_ptb.large_model(True))
if __name__ == "__main__":
tf.test.main()

View File

@ -0,0 +1,154 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for PTBModel with eager execution enabled."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gc
import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib.eager.python import tfe
from tensorflow.contrib.eager.python.examples.rnn_ptb import rnn_ptb
def device():
return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0"
class PTBTest(tf.test.TestCase):
def testTrain(self):
model = rnn_ptb.small_model(tfe.num_gpus() > 0)
sequence_length = 35
data = np.ones([4 * sequence_length, 20], dtype=np.int64)
with tf.device(device()):
optimizer = tf.train.GradientDescentOptimizer(1.0)
# Train two epochs
rnn_ptb.train(model, optimizer, data, sequence_length, 0.25)
rnn_ptb.train(model, optimizer, data, sequence_length, 0.25)
def testApply(self):
model = rnn_ptb.small_model(tfe.num_gpus() > 0)
with tf.device(device()):
model(tf.ones([35, 20], dtype=tf.int64), training=False)
def force_gpu_sync():
if tfe.num_gpus():
tf.constant(1).gpu().cpu()
class PTBBenchmark(tf.test.Benchmark):
BATCH_SIZE = 20
SEQ_LEN = 35
def _report(self, label, start, num_iters, dev, batch_size):
wall_time = (time.time() - start) / num_iters
dev = "cpu" if "cpu" in dev.lower() else "gpu"
name = "%s_%s_batch_%d" % (label, dev, batch_size)
examples_per_sec = batch_size / wall_time
self.report_benchmark(
iters=num_iters,
wall_time=wall_time,
name=name,
extras={
"examples_per_sec": examples_per_sec
})
def _benchmark_apply(self, label, model):
with tf.device(device()):
sequence_batch = tf.ones(
[PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], dtype=tf.int64)
for _ in range(10): # Warmup
model(sequence_batch, training=False).cpu()
gc.collect()
start = time.time()
iters = 100
for _ in range(iters):
model(sequence_batch, training=False).cpu()
self._report(label, start, iters, device(), int(sequence_batch.shape[1]))
def benchmark_apply_small(self):
self._benchmark_apply("eager_apply_small", rnn_ptb.small_model(False))
def benchmark_apply_large(self):
self._benchmark_apply("eager_apply_large", rnn_ptb.large_model(False))
def benchmark_cudnn_apply_small(self):
if not tfe.num_gpus():
return
self._benchmark_apply("eager_cudnn_apply_small", rnn_ptb.small_model(True))
def benchmark_cudnn_apply_large(self):
if not tfe.num_gpus():
return
self._benchmark_apply("eager_cudnn_apply_large", rnn_ptb.large_model(True))
def _benchmark_train(self, label, model):
with tf.device(device()):
optimizer = tf.train.GradientDescentOptimizer(1.)
def model_loss(inputs, targets):
return rnn_ptb.loss_fn(model, inputs, targets, training=True)
grads = tfe.implicit_gradients(model_loss)
sequence_batch = tf.ones(
[PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], dtype=tf.int64)
def step():
optimizer.apply_gradients(
rnn_ptb.clip_gradients(grads(sequence_batch, sequence_batch), 0.25))
for _ in range(10): # Warmup
step()
force_gpu_sync()
gc.collect()
start = time.time()
iters = 100
for _ in range(iters):
step()
force_gpu_sync()
self._report(label, start, iters, device(), int(sequence_batch.shape[1]))
def benchmark_train_small(self):
self._benchmark_train("eager_train_small", rnn_ptb.small_model(False))
def benchmark_train_large(self):
self._benchmark_train("eager_train_large", rnn_ptb.large_model(False))
def benchmark_cudnn_train_small(self):
if not tfe.num_gpus():
return
self._benchmark_train("eager_cudnn_train_small", rnn_ptb.small_model(True))
def benchmark_cudnn_train_large(self):
if not tfe.num_gpus():
return
self._benchmark_train("eager_cudnn_train_large", rnn_ptb.large_model(True))
if __name__ == "__main__":
tfe.enable_eager_execution()
tf.test.main()

View File

@ -0,0 +1,899 @@
# TensorFlow Eager Execution
## What is this?
Eager execution is a feature that makes TensorFlow execute operations
immediately: concrete values are returned, instead of a computational graph to
be executed later.
As a result, enabling eager execution provides:
- A [NumPy](http://www.numpy.org/)-like library for numerical computation with
support for GPU acceleration and automatic differentiation.
- A flexible platform for machine learning research and experimentation.
Eager execution is under active development. This guide walks through an
alpha/preview release. In particular, not all TensorFlow APIs currently work
with eager execution enabled, and some models may be slow to execute, compared
to models defined without using eager execution.
## Installation
Eager execution is **not** included in the latest release (version 1.4) of
TensorFlow. To use it, you will need to [build TensorFlow from
source](https://www.tensorflow.org/install/install_sources) or install the
nightly builds.
For example, the nightly builds can be installed using `pip`:
- `pip install tf-nightly` (for CPU-only TensorFlow)
- `pip install tf-nightly-gpu` (for GPU-enabled TensorFlow)
Or using `docker`, with [Jupyter Notebook](http://jupyter.org/) support:
```sh
# For CPU-only TensorFlow
docker pull tensorflow/tensorflow:nightly
docker run -it -p 8888:8888 tensorflow/tensorflow:nightly
# For GPU-enabled TensorFlow:
# (Requires https://github.com/NVIDIA/nvidia-docker)
nvidia-docker pull tensorflow/tensorflow:nightly-gpu
nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu
```
## Getting Started
With TensorFlow installed, eager execution is enabled via a single call:
```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()
```
Enabling eager execution changes how TensorFlow functions behave (in particular,
`Tensor` objects will reference concrete values instead of being symbolic
handles to nodes in a computational graph). As a result, eager execution should
be enabled at the beginning of a program and cannot be disabled afterwards in
the same program.
Code examples in the rest of this guide assume that eager execution has been
enabled.
## A library for numerical computation
A significant fraction of the [TensorFlow
API](https://www.tensorflow.org/api_docs/python/) consists of numerical
operations:
[arithmetic operations](https://www.tensorflow.org/api_docs/python/tf/matmul),
[matrix operations](https://www.tensorflow.org/api_docs/python/tf/matmul),
[linear algebra operations](https://www.tensorflow.org/api_docs/python/tf/linalg),
etc.
With eager execution enabled, these operations consume and return
multi-dimensional arrays as `Tensor` objects, similar to NumPy
[`ndarray`s](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ndarray.html).
For example:
```python
# Multiply two 2x2 matrices
x = tf.matmul([[1, 2],
[3, 4]],
[[4, 5],
[6, 7]])
# Add one to each element
# (tf.add supports broadcasting)
y = tf.add(x, 1)
# Create a random random 5x3 matrix
z = tf.random_uniform([5, 3])
print(x)
print(y)
print(z)
```
Output:
```
tf.Tensor(
[[16 19]
[36 43]], shape=(2, 2), dtype=int32)
tf.Tensor(
[[17 20]
[37 44]], shape=(2, 2), dtype=int32)
tf.Tensor(
[[ 0.25058532 0.0929395 0.54113817]
[ 0.3108716 0.93350542 0.84909797]
[ 0.53081679 0.12788558 0.01767385]
[ 0.29725885 0.33540785 0.83588314]
[ 0.38877153 0.39720535 0.78914213]], shape=(5, 3), dtype=float32)
```
For convenience, these operations can also be triggered via operator overloading
of the `Tensor` object. For example, the `+` operator is equivalent to `tf.add`,
`-` to `tf.subtract`, `*` to `tf.multiply`, etc.:
```python
x = (tf.ones([1], dtype=tf.float32) + 1) * 2 - 1
print(x)
```
Output:
```
tf.Tensor([ 3.], shape=(1,), dtype=float32)
```
### Converting to and from NumPy
The operations above automatically convert Python objects (like lists of
numbers) and NumPy arrays to `Tensor` objects. `Tensor` objects can also be used
as NumPy arrays by numpy operations.
```python
import numpy as np
x = tf.add(1, 1) # tf.Tensor with a value of 2
y = tf.add(np.array(1), np.array(1)) # tf.Tensor with a value of 2
z = np.multiply(x, y) # numpy.int64 with a value of 4
```
Alternatively, they can be explicitly converted using
[`tf.constant`](https://www.tensorflow.org/api_docs/python/tf/constant), as
shown in the next example.
Conversely, you can call the `numpy()` method of a `Tensor` object' to obtain
its NumPy `ndarray` value. For example:
```python
import numpy as np
np_x = np.array(2., dtype=np.float32)
x = tf.constant(np_x)
py_y = 3.
y = tf.constant(py_y)
z = x + y + 1
print(z)
print(z.numpy())
```
Output:
```
tf.Tensor(6.0, shape=(), dtype=float32)
6.0
```
### GPU acceleration
Many TensorFlow operations support GPU acceleration. With eager execution
enabled, [computation is *not* automatically
offloaded](https://www.tensorflow.org/tutorials/using_gpu) to GPUs. Instead, you
must explicitly specify when GPUs should be used.
The simplest way to do this is to enclose your computation in a `with
tf.device('/gpu:0')` block. Also of interest is the `tfe.num_gpus()` function,
which returns the number of available GPUs.
For example, consider this snippet to measure the time to multiply two 1000x1000
matrices on CPU:
```python
import time
def measure(x):
# The very first time a GPU is used by TensorFlow, it is initialized.
# So exclude the first run from timing.
tf.matmul(x, x)
start = time.time()
for i in range(10):
tf.matmul(x, x)
end = time.time()
return "Took %s seconds to multiply a %s matrix by itself 10 times" % (end - start, x.shape)
# Run on CPU:
with tf.device("/cpu:0"):
print("CPU: %s" % measure(tf.random_normal([1000, 1000])))
# If a GPU is available, run on GPU:
if tfe.num_gpus() > 0:
with tf.device("/gpu:0"):
print("GPU: %s" % measure(tf.random_normal([1000, 1000])))
```
Output (exact numbers will depend on the characteristics of the hardware):
```python
CPU: Took 0.145531892776 seconds to multiply a (1000, 1000) matrix by itself 10 times
GPU: Took 0.000458955764771 seconds to multiply a (1000, 1000) matrix by itself 10 times
```
Alternatively, methods on the `Tensor` object can be used to explicitly copy the
`Tensor` to a different device. Operations are typically executed on the device
on which the inputs are placed. For example:
```python
x = tf.random_normal([10, 10])
x_gpu0 = x.gpu()
x_cpu = x.cpu()
_ = tf.matmul(x_cpu, x_cpu) # Runs on CPU
_ = tf.matmul(x_gpu0, x_gpu0) # Runs on GPU:0
if tfe.num_gpus() > 1:
x_gpu1 = x.gpu(1)
_ = tf.matmul(x_gpu1, x_gpu1) # Runs on GPU:1
```
### Automatic Differentiation
[Automatic
differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) is
very useful when implementing many machine learning algorithms (e.g.,
[backpropagation](https://en.wikipedia.org/wiki/Backpropagation) for training
neural networks). For this purpose, TensorFlow eager execution provides an
[autograd](https://github.com/HIPS/autograd)-style API for automatic
differentiation. Specifically, the functions:
- `tfe.gradients_function(f)`: Returns a Python function that computes the
derivatives of the Python function `f` with respect to its arguments. `f`
must return a scalar value. When the returned function is invoked, it
returns a list of `Tensor` objects (one element for each argument of `f`).
- `tfe.value_and_gradients_function(f)`: Similar to `tfe.gradients_function`,
except that when the returned function is invoked, it returns the value of
`f` in addition to the list of derivatives of `f` with respect to its
arguments.
These functions naturally apply to higher order differentiation as well. For
example:
```python
def f(x):
return tf.multiply(x, x) # Or x * x
assert 9 == f(3.).numpy()
df = tfe.gradients_function(f)
assert 6 == df(3.)[0].numpy()
# Second order deriviative.
d2f = tfe.gradients_function(lambda x: df(x)[0])
assert 2 == d2f(3.)[0].numpy()
# Third order derivative.
d3f = tfe.gradients_function(lambda x : d2f(x)[0])
assert 0 == d3f(3.)[0].numpy()
```
These functions can be used to train models. For example, consider the following
simple linear regression model:
```python
def prediction(input, weight, bias):
return input * weight + bias
# A toy dataset of points around 3 * x + 2
NUM_EXAMPLES = 1000
training_inputs = tf.random_normal([NUM_EXAMPLES])
noise = tf.random_normal([NUM_EXAMPLES])
training_outputs = training_inputs * 3 + 2 + noise
# A loss function: Mean-squared error
def loss(weight, bias):
error = prediction(training_inputs, weight, bias) - training_outputs
return tf.reduce_mean(tf.square(error))
# Function that returns the the derivative of loss with respect to
# weight and bias
grad = tfe.gradients_function(loss)
# Train for 200 steps (starting from some random choice for W and B, on the same
# batch of data).
W = 5.
B = 10.
learning_rate = 0.01
print("Initial loss: %f" % loss(W, B).numpy())
for i in range(200):
(dW, dB) = grad(W, B)
W -= dW * learning_rate
B -= dB * learning_rate
if i % 20 == 0:
print("Loss at step %d: %f" % (i, loss(W, B).numpy()))
print("Final loss: %f" % loss(W, B).numpy())
print("W, B = %f, %f" % (W.numpy(), B.numpy()))
```
Output: (the exact numbers may vary depending on the randomness in noise)
```
Initial loss: 66.730003
Loss at step 0: 64.200096
Loss at step 20: 29.872814
Loss at step 40: 14.233772
Loss at step 60: 7.090570
Loss at step 80: 3.819887
Loss at step 100: 2.318821
Loss at step 120: 1.628385
Loss at step 140: 1.310142
Loss at step 160: 1.163167
Loss at step 180: 1.095162
Final loss: 1.064711
W, B = 3.094944, 2.161383
```
To utilize the GPU, place the code above within a `with tf.device("/gpu:0"):`
block. (However, this particular model, with only two floating point parameters,
is unlikely to benefit from GPU acceleration.)
### Customizing gradients
One may want to define custom gradients for an operation, or for a function.
This may be useful for multiple reasons, including providing a more efficient
or more [numerically stable](https://en.wikipedia.org/wiki/Numerical_stability)
gradient for a sequence of operations.
For example, consider the function `log(1 + e^x)`, which commonly occurs in the
computation of cross entropy and log likelihoods.
```python
def log1pexp(x):
 return tf.log(1 + tf.exp(x))
grad_log1pexp = tfe.gradients_function(log1pexp)
# Works fine at x = 0.
assert 0.5 == float(grad_log1pexp(0.)[0])
# Returns a `nan` at x = 100 due to numerical instability.
import math
assert math.isnan(float(grad_log1pexp(100.)[0]))
```
We can define a custom gradient for the above function that analytically
simplifies the gradient expression.
```python
@tfe.custom_gradient
def log1pexp(x):
 e = tf.exp(x)
 def grad(dy):
   return dy * (1 - 1 / (1 + e))
 return tf.log(1 + e), grad
grad_log1pexp = tfe.gradients_function(log1pexp)
# Works as before at x = 0.
assert 0.5 == float(grad_log1pexp(0.)[0])
# But now works at x = 100 as well.
assert 1.0 == float(grad_log1pexp(100.)[0])
```
Also notice how the gradient function implementation reuses an expression
(`tf.exp(x)`) computed during the forward pass, hence making the gradient
computation more efficient by avoiding redundant computation.
## Building and training models
In practice, your computation may have many parameters to be optimized (by
computing derivatives). Encapsulating them into re-usable classes/objects
makes the code easier to follow than writing a single top-level function with
many arguments.
In fact, eager execution encourages use of the [Keras](https://keras.io)-style
"Layer" classes in the
[`tf.layers`](https://www.tensorflow.org/versions/master/api_docs/python/tf/layers)
module.
Furthermore, you may want to apply more sophisticated techniques to compute
parameter updates, such as those in
[`tf.train.Optimizer`](https://www.tensorflow.org/api_guides/python/train#Optimizers)
implementations.
This next section walks through using the same `Optimizer` and `Layer` APIs used
to build trainable TensorFlow graphs in an environment where eager execution is
enabled.
### Variables and Optimizers
`tfe.Variable` objects store mutable `Tensor` values that can be accessed during
training, making automatic differentiation easier. In particular, parameters of
a model can be encapsulated in Python classes as variables.
`tfe.gradients_function(f)` introduced earlier computes the derivatives of `f`
with respect to its arguments. However, it requires all parameters of interest
to be arguments of `f`, which becomes cumbersome when `f` depends on a large
number of trainable parameters.
`tfe.implicit_gradients` is an alternative function with some useful properties:
- It computes the derivatives of `f` with respect to all the `tfe.Variable`s
used by `f`.
- When the returned function is invoked, it returns a list of
(gradient value, Variable object) tuples.
Representing model parameters as `Variable` objects, along with the use of
`tfe.implicit_gradients`, typically results in better encapsulation. For
example, the linear regression model described above can be written into a
class:
```python
class Model(object):
def __init__(self):
self.W = tfe.Variable(5., name='weight')
self.B = tfe.Variable(10., name='bias')
def predict(self, inputs):
return inputs * self.W + self.B
# The loss function to be optimized
def loss(model, inputs, targets):
error = model.predict(inputs) - targets
return tf.reduce_mean(tf.square(error))
# A toy dataset of points around 3 * x + 2
NUM_EXAMPLES = 1000
training_inputs = tf.random_normal([NUM_EXAMPLES])
noise = tf.random_normal([NUM_EXAMPLES])
training_outputs = training_inputs * 3 + 2 + noise
# Define:
# 1. A model
# 2. Derivatives of a loss function with respect to model parameters
# 3. A strategy for updating the variables based on the derivatives
model = Model()
grad = tfe.implicit_gradients(loss)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
# The training loop
print("Initial loss: %f" %
loss(model, training_inputs, training_outputs).numpy())
for i in range(201):
optimizer.apply_gradients(grad(model, training_inputs, training_outputs))
if i % 20 == 0:
print("Loss at step %d: %f" %
(i, loss(model, training_inputs, training_outputs).numpy()))
print("Final loss: %f" % loss(model, training_inputs, training_outputs).numpy())
print("W, B = %s, %s" % (model.W.numpy(), model.B.numpy()))
```
Output:
```
Initial loss: 69.693184
Loss at step 0: 66.987854
Loss at step 20: 30.553387
Loss at step 40: 14.250237
Loss at step 60: 6.955020
Loss at step 80: 3.690550
Loss at step 100: 2.229739
Loss at step 120: 1.576032
Loss at step 140: 1.283496
Loss at step 160: 1.152584
Loss at step 180: 1.093999
Final loss: 1.067780
W, B = 3.0114281, 2.0865183
```
Using `implicit_gradients` avoids the need to provide all the trainable
parameters of the model as arguments to the `loss` function.
### Using Keras and the Layers API
[Keras](https://keras.io) is a popular API for defining model structures. The
[`tf.keras.layers`](https://www.tensorflow.org/versions/master/api_docs/python/tf/keras/layers)
module provides a set of building blocks for models and is implemented using the
`tf.layers.Layer` subclasses in the
[`tf.layers`](https://www.tensorflow.org/versions/master/api_docs/python/tf/layers)
module. We encourage the use of these same building blocks when using
TensorFlow's eager execution feature. For example, the very same linear
regression model can be built using `tf.layers.Dense`:
```python
class Model(object):
def __init__(self):
self.layer = tf.layers.Dense(1)
def predict(self, inputs):
return self.layer(inputs)
```
The `tf.layers` API makes it more convenient to define more sophisticated
models. For example, the following will train an MNIST model:
```python
class MNISTModel(object):
def __init__(self, data_format):
# 'channels_first' is typically faster on GPUs
# while 'channels_last' is typically faster on CPUs.
# See: https://www.tensorflow.org/performance/performance_guide#data_formats
if data_format == 'channels_first':
self._input_shape = [-1, 1, 28, 28]
else:
self._input_shape = [-1, 28, 28, 1]
self.conv1 = tf.layers.Conv2D(32, 5,
padding='same',
activation=tf.nn.relu,
data_format=data_format)
self.max_pool2d = tf.layers.MaxPooling2D(
(2, 2), (2, 2), padding='same', data_format=data_format)
self.conv2 = tf.layers.Conv2D(64, 5,
padding='same',
activation=tf.nn.relu,
data_format=data_format)
self.dense1 = tf.layers.Dense(1024, activation=tf.nn.relu)
self.dropout = tf.layers.Dropout(0.5)
self.dense2 = tf.layers.Dense(10)
def predict(self, inputs):
x = tf.reshape(inputs, self._input_shape)
x = self.max_pool2d(self.conv1(x))
x = self.max_pool2d(self.conv2(x))
x = tf.layers.flatten(x)
x = self.dropout(self.dense1(x))
return self.dense2(x)
def loss(model, inputs, targets):
return tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(
logits=model.predict(inputs), labels=targets))
# Load the training and validation data
from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets("./mnist_data", one_hot=True)
# Train
device = "gpu:0" if tfe.num_gpus() else "cpu:0"
model = MNISTModel('channels_first' if tfe.num_gpus() else 'channels_last')
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
grad = tfe.implicit_gradients(loss)
for i in range(20001):
with tf.device(device):
(inputs, targets) = data.train.next_batch(50)
optimizer.apply_gradients(grad(model, inputs, targets))
if i % 100 == 0:
print("Step %d: Loss on training set : %f" %
(i, loss(model, inputs, targets).numpy()))
print("Loss on test set: %f" % loss(model, data.test.images, data.test.labels).numpy())
```
For a more complete example, see
[`tensorflow/contrib/eager/python/examples/mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py)
### Checkpointing trained variables
TensorFlow Variables (`tfe.Variable`) provides a way to represent shared,
persistent state of your model. The `tfe.Saver` class (which is a thin wrapper
over the
[`tf.train.Saver`](https://www.tensorflow.org/api_docs/python/tf/train/Saver)
class) provides a means to save and restore variables to and from _checkpoints_.
For example:
```python
# Create variables.
x = tfe.Variable(10., name='x')
y = tfe.Variable(5., name='y')
# Create a Saver.
saver = tfe.Saver([x, y])
# Assign new values to the variables and save.
x.assign(2.)
saver.save('/tmp/ckpt')
# Change the variable after saving.
x.assign(11.)
assert 16. == (x + y).numpy() # 11 + 5
# Restore the values in the checkpoint.
saver.restore('/tmp/ckpt')
assert 7. == (x + y).numpy() # 2 + 5
```
### `tfe.Network`
You may often want to organize your models using classes, like the `MNISTModel`
class described above. We recommend inheriting from the `tfe.Network` class as
it provides conveniences like keeping track of all model variables and methods
to save and restore from checkpoints.
Sub-classes of `tfe.Network` may register `Layer`s (like classes in
[`tf.layers`](https://www.tensorflow.org/versions/master/api_docs/python/tf/layers),
or [Keras
layers](https://www.tensorflow.org/versions/master/api_docs/python/tf/keras/layers))
using a call to `self.track_layer()` and define the computation in an
implementation of `call()`.
Note that `tf.layers.Layer` objects (like `tf.layers.Dense`) create variables
lazily, when the first input is encountered.
For example, consider the following two-layer neural network:
```python
class TwoLayerNet(tfe.Network):
def __init__(self):
super(TwoLayerNet, self).__init__()
self.layer1 = self.track_layer(
tf.layers.Dense(2, activation=tf.nn.relu, use_bias=False))
self.layer2 = self.track_layer(tf.layers.Dense(3, use_bias=False))
def call(self, x):
return self.layer2(self.layer1(x))
net = TwoLayerNet()
# No variables created yet
assert 0 == len(net.variables)
# They are created on first input:
inp = tf.constant([[1.]])
# Since input is a 1x1 matrix, net.l1 has 2 units and net.l2 has 3 units,
# the output is the product of a 1x1 matrix with a 1x2 matrix with a 2x3
# matrix.
assert [1, 3] == net(inp).shape.as_list() # Invoke net; get output shape.
assert 1 == len(net.layer1.variables)
assert 1 == len(net.layer2.variables)
assert 2 == len(net.variables) # weights for each layer.
assert [1, 2] == net.variables[0].shape.as_list() # weights of layer1.
assert [2, 3] == net.variables[1].shape.as_list() # weights of layer2.
```
The `tfe.Network` class is itself a sub-class of `tf.layers.Layer`. This allows
instances of `tfe.Network` to be embedded in other networks. For example:
```python
class ThreeLayerNet(tfe.Network):
def __init__(self):
super(ThreeLayerNet, self).__init__()
self.a = self.track_layer(TwoLayerNet())
self.b = self.track_layer(tf.layers.Dense(4, use_bias=False))
def call(self, x):
return self.b(self.a(x))
net = ThreeLayerNet()
assert [1, 4] == net(inp).shape.as_list()
assert 3 == len(net.variables)
assert [1, 2] == net.variables[0].shape.as_list()
assert [2, 3] == net.variables[1].shape.as_list()
assert [3, 4] == net.variables[2].shape.as_list()
```
See more examples in
[`tensorflow/contrib/eager/python/examples`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples).
`tfe.Saver` in combination with `tfe.restore_variables_on_create` provides a
convenient way to save and load checkpoints without changing the program once
the checkpoint has been created. For example, we can set an objective for the
output of our network, choose an optimizer, and a location for the checkpoint:
```python
objective = tf.constant([[2., 3., 4., 5.]])
optimizer = tf.train.AdamOptimizer(0.01)
checkpoint_directory = '/tmp/tfe_example'
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
net = ThreeLayerNet()
```
Note that variables have not been created yet. We want them to be restored from
a checkpoint, if one exists, so we create them inside a
`tfe.restore_variables_on_create` context manager. Then our training loop is the
same whether starting training or resuming from a previous checkpoint:
```python
with tfe.restore_variables_on_create(
tf.train.latest_checkpoint(checkpoint_directory)):
global_step = tf.train.get_or_create_global_step()
for _ in range(100):
loss_fn = lambda: tf.norm(net(inp) - objective)
optimizer.minimize(loss_fn, global_step=global_step)
if tf.equal(global_step % 20, 0):
print("Step %d, output %s" % (global_step.numpy(),
net(inp).numpy()))
all_variables = (
net.variables
+ tfe.get_optimizer_variables(optimizer)
+ [global_step])
# Save the checkpoint.
tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step)
```
The first time it runs, `Network` variables are initialized randomly. Then the
output is trained to match the objective we've set:
```
Step 20, output [[ 0.03575622 0.29863232 0.03474367 0.24735749]]
Step 40, output [[ 0.40646029 0.9856872 0.46851286 0.95358551]]
Step 60, output [[ 1.74541104 2.800704 1.79055595 2.74783421]]
Step 80, output [[ 2.14977384 3.44340849 3.96120024 5.16242075]]
Step 100, output [[ 1.99943113 3.02364397 3.93500996 4.9610076 ]]
```
In subsequent iterations, variables are initialized with the values read from
the latest checkpoint. Running the same code again, we continue from where we
left off:
```
Step 120, output [[ 1.99234128 3.0271616 3.98732996 4.96401167]]
Step 140, output [[ 2.00133467 3.01270437 4.00616646 5.00406504]]
Step 160, output [[ 1.99647415 2.9956708 3.99064088 4.99632359]]
Step 180, output [[ 2.00699997 3.00904822 4.00706148 5.01193142]]
Step 200, output [[ 1.98334622 2.98249531 3.97375059 4.97123432]]
```
### Summaries, metrics and TensorBoard
[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
is a popular tool for understanding, debugging and optimizing the model training
process. To benefit from the visualizations offered by TensorBoard, summary
events need to be written during the course of execution of your program. You
might find many Tensorflow programs that include the
[`tf.summary`](https://www.tensorflow.org/api_guides/python/summary) operations
during graph construction.
`tf.summary` operations are *not* compatible with eager execution, but an
equivalent alternative exists in
[`tf.contrib.summary`](https://www.tensorflow.org/versions/master/api_guides/python/tf/contrib/summary/)
that is compatible with both eager execution and graph construction.
During model construction simply insert summary operations like
`tf.contrib.summary.scalar`. These operations do nothing by default, unless a
summary writer is currently active and a writing policy is set.
For example, to record summaries once every 100 global steps, use:
```python
tf.train.get_or_create_global_step() # Ensuring the global step variable exists
writer = tf.contrib.summary.create_summary_file_writer(logdir)
for _ in range(iterations):
with writer.as_default():
with tf.contrib.summary.record_summaries_every_n_global_steps(100):
# your model code goes here
tf.contrib.summary.scalar('loss', loss)
# ...
```
See the full mnist example in
[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist)
for a full model using `tf.contrib.summary`.
Similarly to summaries, the metrics in `tf.metrics` are currently not compatible
with eager execution. We instead provide object-oriented metrics in the
`tfe.metrics` package, which are compatible with graph construction as well.
Metrics in the `tfe.metrics`, such as `tfe.metrics.Mean` and
`tfe.Metrics.Accuracy`, all implement an intuitive object-oriented
interface. Here's an example of how to use the `tfe.metrics.Mean` metric:
```python
# Metrics are objects, which can be created and destroyed.
my_mean = tfe.metrics.Mean(name='my_mean')
# While a metric is active, you can call it as a function to accumulate into its
# internal state.
my_mean(0.0)
my_mean(10.0)
# Once you've finished updating the metric, you can get its result. In this case
# a simple average over all the calls to it. If a summary writer is active the
# metric will write the appropriate summaries using the metric name.
assert 5.0 == my_mean.result().numpy()
```
For a full example of a model using metrics for evaluation, see the mnist
example in
[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist).
### Input Pipelines
The discussion above has been centered around the computation executed by your
model. The
[`tf.data`](https://www.tensorflow.org/versions/master/api_docs/python/tf/data)
module provides APIs to build complex input pipelines from simple, reusable
pieces.
If you're familiar with constructing `tf.data.Dataset` objects when building
TensorFlow graphs, the same API calls are used when eager execution is enabled.
However, the process of iterating over elements of the dataset differs between
eager execution and graph construction. When eager execution is enabled, the
discussion on iterator creation using `make_one_shot_iterator()` and
`get_next()` in the
[Programmer's
Guide](https://www.tensorflow.org/versions/master/programmers_guide/datasets) is
*not* applicable. Instead, a more Pythonic `Iterator` class is available.
For example:
```python
# Create a source Dataset from in-memory numpy arrays.
# For reading from files on disk, you may want to use other Dataset classes
# like the TextLineDataset or the TFRecordDataset.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
# Apply transformations, shuffling, batching etc.
dataset = dataset.map(tf.square).shuffle(2).batch(2)
# Use tfe.Iterator to iterate over the dataset.
for x in tfe.Iterator(dataset):
print(x)
```
Output:
```
tf.Tensor([4 9], shape=(2,), dtype=int32)
tf.Tensor([16 25], shape=(2,), dtype=int32)
tf.Tensor([36 1], shape=(2,), dtype=int32)
```
## Interoperating with Graphs
Eager execution improves the process of model development in Python; however,
because it is in its earliest stages, it does not yet support some features
available to [TensorFlow
graphs](https://www.tensorflow.org/get_started/get_started#the_computational_graph)
that are desirable when deploying models in production. In particular, eager
execution does not yet support distributed training, exporting models (to other
[programming languages](https://www.tensorflow.org/api_docs/), [TensorFlow
serving](https://www.tensorflow.org/serving/), and mobile applications), and
various memory and computation optimizations that are applied to TensorFlow's
dataflow graphs.
That said, the APIs used to build modes are exactly the same whether executing
eagerly or constructing graphs. This means that you can iteratively develop your
model with eager execution enabled and later, if needed, use the same code to
reap the benefits of representing models as computational graphs.
For example,
[`mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py)
defines a model that is eagerly executed. That same code is used to construct
and execute a graph in
[`mnist_graph_test.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py).
Other models in the [examples
directory](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/)
demonstrate this as well.
Some differences worth noting:
- There is no notion of a `tf.placeholder` or a `tf.Session` when eager
execution is enabled.
- Many properties on the `tf.Tensor` object, like `tf.Tensor.name`,
`tf.Tensor.op`, `tf.Tensor.inputs` are not meaningful when eager execution
is enabled and their use will raise an `AttributeError`.
- To use `tfe.implicit_gradients` in graph construction, variables must be
created with [`use_resource=True`] provided to
[`tf.get_variable()`](https://www.tensorflow.org/api_docs/python/tf/get_variable)
or
[`tf.variable_scope()`](https://www.tensorflow.org/api_docs/python/tf/variable_scope).
- Some API calls (such as the functional-style `tf.layers.dense`,
`tf.layers.conv2d`) are not compatible with eager execution. Use of such
methods should raise an error indicating the alternative (e.g., the
`tf.layers.Dense` and `tf.layers.Conv2D` classes).
## What next?
Please give eager execution a spin. This feature is in early stages and is
evolving, so we welcome your feedback via issues on GitHub (see [known
issues](https://github.com/tensorflow/tensorflow/labels/eager)).
You may want to browse through some sample code, including benchmarks for some:
- [Linear Regression](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/linear_regression)
- [MNIST handwritten digit classifier](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist)
- [ResNet50 image classification](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/resnet50)
- [RNN to generate colors](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_colorbot)
- [RNN language model](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_ptb)

View File

@ -153,6 +153,7 @@ sh_binary(
"//tensorflow:tensorflow_py", "//tensorflow:tensorflow_py",
"//tensorflow/contrib/boosted_trees:boosted_trees_pip", "//tensorflow/contrib/boosted_trees:boosted_trees_pip",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/eager/python/examples:examples_pip",
"//tensorflow/contrib/gan:gan", "//tensorflow/contrib/gan:gan",
"//tensorflow/contrib/graph_editor:graph_editor_pip", "//tensorflow/contrib/graph_editor:graph_editor_pip",
"//tensorflow/contrib/keras:keras", "//tensorflow/contrib/keras:keras",