K-FAC: Multi-tower ConvNet example.

PiperOrigin-RevId: 173982527
Authored by A. Unique TensorFlower on 2017-10-30 18:44:06 -07:00; committed by TensorFlower Gardener.
parent 2ba5298565
commit bb7ed1c889
3 changed files with 153 additions and 79 deletions

File 1 of 3

@@ -83,7 +83,7 @@ def conv_layer(layer_id, inputs, kernel_size, out_channels):
   activations = tf.nn.relu(preactivations)
   # layer.weights is a list. This converts it to a (hashable) tuple.
-  return preactivations, activations, tuple(layer.weights)
+  return preactivations, activations, (layer.kernel, layer.bias)


 def max_pool_layer(layer_id, inputs, kernel_size, stride):
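The (kernel, bias) tuple returned here is exactly what LayerCollection.register_conv2d consumes later in this file. A minimal sketch of that pairing, assuming the example's `lc` alias for tensorflow.contrib.kfac's layer_collection module and an illustrative input shape:

import tensorflow as tf
from tensorflow.contrib.kfac import layer_collection as lc

inputs = tf.random_normal([128, 28, 28, 1])  # hypothetical MNIST-sized batch
layer = tf.layers.Conv2D(16, kernel_size=[5, 5], padding="SAME")
preactivations = layer(inputs)

layers = lc.LayerCollection()
# register_conv2d(params, strides, padding, inputs, outputs), where params
# is the (kernel, bias) pair that conv_layer now returns.
layers.register_conv2d((layer.kernel, layer.bias), (1, 1, 1, 1), "SAME",
                       inputs, preactivations)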
@@ -128,7 +128,7 @@ def linear_layer(layer_id, inputs, output_size):
   return pre, params


-def build_model(examples, labels, num_labels, num_ps_tasks=0):
+def build_model(examples, labels, num_labels, layer_collection):
   """Builds a ConvNet classification model.

   Args:
@@ -137,65 +137,64 @@ def build_model(examples, labels, num_labels, num_ps_tasks=0):
     labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
       by softmax for each example.
     num_labels: int. Number of distinct values 'labels' can take on.
-    num_ps_tasks: int. Number of parameter servers. If zero, variables
-      will be placed locally.
+    layer_collection: LayerCollection instance. Layers will be registered here.

   Returns:
     loss: 0-D Tensor representing loss to be minimized.
-    statistics: dict mapping strings to Tensors. Additional model evaluation
-      statistics.
-    layer_collection: LayerCollection instance describing model architecture.
+    accuracy: 0-D Tensor representing model's accuracy.
   """
-  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
-    # Build a ConvNet. For each layer with parameters, we'll keep track of the
-    # preactivations, activations, weights, and bias.
-    tf.logging.info("Building model.")
-    pre0, act0, params0 = conv_layer(
-        layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
-    act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
-    pre2, act2, params2 = conv_layer(
-        layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
-    act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
-    flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
-    logits, params4 = linear_layer(
-        layer_id=4, inputs=flat_act3, output_size=num_labels)
-    loss = tf.reduce_mean(
-        tf.nn.sparse_softmax_cross_entropy_with_logits(
-            labels=labels, logits=logits))
-    accuracy = tf.reduce_mean(
-        tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
+  # Build a ConvNet. For each layer with parameters, we'll keep track of the
+  # preactivations, activations, weights, and bias.
+  tf.logging.info("Building model.")
+  pre0, act0, params0 = conv_layer(
+      layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
+  act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
+  pre2, act2, params2 = conv_layer(
+      layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
+  act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
+  flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
+  logits, params4 = linear_layer(
+      layer_id=4, inputs=flat_act3, output_size=num_labels)
+  loss = tf.reduce_mean(
+      tf.nn.sparse_softmax_cross_entropy_with_logits(
+          labels=labels, logits=logits))
+  accuracy = tf.reduce_mean(
+      tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))

-    tf.summary.scalar("loss", loss)
-    tf.summary.scalar("accuracy", accuracy)
+  tf.summary.scalar("loss", loss)
+  tf.summary.scalar("accuracy", accuracy)

-    # Register parameters. K-FAC needs to know about the inputs, outputs, and
-    # parameters of each conv/fully connected layer and the logits powering the
-    # posterior probability over classes.
-    tf.logging.info("Building KFAC Optimizer.")
-    layer_collection = lc.LayerCollection()
-    layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
-                                     pre0)
-    layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
-    layer_collection.register_fully_connected(params4, flat_act3, logits)
-    layer_collection.register_categorical_predictive_distribution(logits)
+  # Register parameters. K-FAC needs to know about the inputs, outputs, and
+  # parameters of each conv/fully connected layer and the logits powering the
+  # posterior probability over classes.
+  tf.logging.info("Building LayerCollection.")
+  layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
+                                   pre0)
+  layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
+  layer_collection.register_fully_connected(params4, flat_act3, logits)
+  layer_collection.register_categorical_predictive_distribution(
+      logits, name="logits")

-  return loss, {"accuracy": accuracy}, layer_collection
+  return loss, accuracy


-def minimize_loss_single_machine(loss, statistics, layer_collection):
+def minimize_loss_single_machine(loss,
+                                 accuracy,
+                                 layer_collection,
+                                 session_config=None):
   """Minimize loss with K-FAC on a single machine.

   A single Session is responsible for running all of K-FAC's ops.

   Args:
     loss: 0-D Tensor. Loss to be minimized.
-    statistics: dict mapping strings to 0-D Tensors. Additional statistics to
-      run with each step.
+    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
     layer_collection: LayerCollection instance describing model architecture.
       Used by K-FAC to construct preconditioner.
+    session_config: None or tf.ConfigProto. Configuration for tf.Session().

   Returns:
-    final value for 'statistics'.
+    final value for 'accuracy'.
   """
   # Train with K-FAC.
   global_step = tf.train.get_or_create_global_step()
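With this change the caller owns the LayerCollection and threads it through build_model into the K-FAC optimizer. A sketch of that flow; the `opt` alias and the KfacOptimizer constructor arguments are assumptions based on tf.contrib.kfac, and the hyperparameter values are illustrative, not taken from this diff:

from tensorflow.contrib.kfac import optimizer as opt

layer_collection = lc.LayerCollection()
loss, accuracy = build_model(
    examples, labels, num_labels=10, layer_collection=layer_collection)

optimizer = opt.KfacOptimizer(
    learning_rate=0.0001,
    cov_ema_decay=0.95,  # moving-average decay for covariance statistics
    damping=0.001,       # Tikhonov damping added to the Fisher blocks
    layer_collection=layer_collection,
    momentum=0.9)
train_op = optimizer.minimize(
    loss, global_step=tf.train.get_or_create_global_step())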
@@ -208,19 +207,19 @@ def minimize_loss_single_machine(loss, statistics, layer_collection):
   train_op = optimizer.minimize(loss, global_step=global_step)

   tf.logging.info("Starting training.")
-  with tf.train.MonitoredTrainingSession() as sess:
+  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
     while not sess.should_stop():
-      global_step_, loss_, statistics_, _, _ = sess.run(
-          [global_step, loss, statistics, train_op, optimizer.cov_update_op])
+      global_step_, loss_, accuracy_, _, _ = sess.run(
+          [global_step, loss, accuracy, train_op, optimizer.cov_update_op])

       if global_step_ % 100 == 0:
         sess.run(optimizer.inv_update_op)

       if global_step_ % 100 == 0:
-        tf.logging.info("global_step: %d | loss: %f | %s", global_step_, loss_,
-                        statistics_)
+        tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
+                        global_step_, loss_, accuracy_)

-  return statistics_
+  return accuracy_


 def _is_gradient_task(task_id, num_tasks):
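The loop above runs the covariance update alongside every gradient step but refreshes the matrix inversions only every 100 steps, since inverting the factored Fisher blocks is the expensive part of the preconditioner. Distilled into a sketch that reuses the names defined in the function:

with tf.train.MonitoredTrainingSession(config=session_config) as sess:
  while not sess.should_stop():
    # Cheap ops every step: gradient step plus covariance statistics.
    step, _, _ = sess.run([global_step, train_op, optimizer.cov_update_op])
    # Expensive op, amortized: recompute the inverses of the Fisher factors.
    if step % 100 == 0:
      sess.run(optimizer.inv_update_op)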
@@ -252,8 +251,7 @@ def _num_gradient_tasks(num_tasks):
 def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
-                              checkpoint_dir, loss, statistics,
-                              layer_collection):
+                              checkpoint_dir, loss, accuracy, layer_collection):
   """Minimize loss with a synchronous implementation of K-FAC.

   Different tasks are responsible for different parts of K-FAC's Ops. The first
@@ -269,13 +267,13 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
       string to run locally.
     checkpoint_dir: string or None. Path to store checkpoints under.
     loss: 0-D Tensor. Loss to be minimized.
-    statistics: dict mapping strings to 0-D Tensors. Additional statistics to
-      run with each step.
+    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
     layer_collection: LayerCollection instance describing model architecture.
       Used by K-FAC to construct preconditioner.

   Returns:
-    final value for 'statistics'.
+    final value for 'accuracy'.

   Raises:
     ValueError: if task_id >= num_worker_tasks.
@@ -318,12 +316,12 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
     else:
       raise ValueError("Which op should task %d do?" % task_id)

-    global_step_, loss_, statistics_, _ = sess.run(
-        [global_step, loss, statistics, learning_op])
-    tf.logging.info("global_step: %d | loss: %f | %s", global_step_, loss_,
-                    statistics_)
+    global_step_, loss_, accuracy_, _ = sess.run(
+        [global_step, loss, accuracy, learning_op])
+    tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
+                    loss_, accuracy_)

-  return statistics_
+  return accuracy_


 def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
@@ -347,11 +345,69 @@ def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
       flatten_images=False)

   # Build a ConvNet.
-  loss, statistics, layer_collection = build_model(
-      examples, labels, num_labels=10)
+  layer_collection = lc.LayerCollection()
+  loss, accuracy = build_model(
+      examples, labels, num_labels=10, layer_collection=layer_collection)

   # Fit model.
-  return minimize_loss_single_machine(loss, statistics, layer_collection)
+  return minimize_loss_single_machine(loss, accuracy, layer_collection)
+
+
+def train_mnist_multitower(data_dir, num_epochs, num_towers,
+                           use_fake_data=True):
+  """Train a ConvNet on MNIST.
+
+  Args:
+    data_dir: string. Directory to read MNIST examples from.
+    num_epochs: int. Number of passes to make over the training set.
+    num_towers: int. Number of CPUs to split inference across.
+    use_fake_data: bool. If True, generate a synthetic dataset.
+
+  Returns:
+    accuracy of model on the final minibatch of training data.
+  """
+  # Load a dataset.
+  tower_batch_size = 128
+  batch_size = tower_batch_size * num_towers
+  tf.logging.info(
+      ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
+       "tower batch size.") % (batch_size, num_towers, tower_batch_size))
+  examples, labels = mnist.load_mnist(
+      data_dir,
+      num_epochs=num_epochs,
+      batch_size=batch_size,
+      use_fake_data=use_fake_data,
+      flatten_images=False)
+
+  # Split minibatch across towers.
+  examples = tf.split(examples, num_towers)
+  labels = tf.split(labels, num_towers)
+
+  # Build a ConvNet. Each tower's layers will be added to the LayerCollection.
+  layer_collection = lc.LayerCollection()
+  tower_results = []
+  for tower_id in range(num_towers):
+    with tf.device("/cpu:%d" % tower_id):
+      with tf.name_scope("tower%d" % tower_id):
+        with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
+          tf.logging.info("Building tower %d." % tower_id)
+          tower_results.append(
+              build_model(examples[tower_id], labels[tower_id], 10,
+                          layer_collection))
+  losses, accuracies = zip(*tower_results)
+
+  # Average across towers.
+  loss = tf.reduce_mean(losses)
+  accuracy = tf.reduce_mean(accuracies)
+
+  # Fit model.
+  session_config = tf.ConfigProto(
+      allow_soft_placement=False, device_count={
+          "CPU": num_towers
+      })
+  return minimize_loss_single_machine(
+      loss, accuracy, layer_collection, session_config=session_config)
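The tower loop above is the standard in-graph replication pattern: one set of variables shared by every tower, one graph replica per CPU device. A sketch of just that pattern; `model_fn` is a hypothetical stand-in for build_model:

def build_towers(examples_split, labels_split, num_towers, model_fn):
  tower_results = []
  for tower_id in range(num_towers):
    with tf.device("/cpu:%d" % tower_id):
      with tf.name_scope("tower%d" % tower_id):
        # Reuse, rather than recreate, variables on every tower after the
        # first, so all towers share one set of parameters.
        with tf.variable_scope(tf.get_variable_scope(),
                               reuse=(tower_id > 0)):
          tower_results.append(
              model_fn(examples_split[tower_id], labels_split[tower_id]))
  return tower_results

The session config matters here: device_count={"CPU": num_towers} makes TensorFlow expose that many CPU devices, and allow_soft_placement=False makes placement fail loudly rather than silently migrating ops off their assigned tower.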
@@ -385,13 +441,15 @@ def train_mnist_distributed(task_id,
       flatten_images=False)

   # Build a ConvNet.
-  loss, statistics, layer_collection = build_model(
-      examples, labels, num_labels=10, num_ps_tasks=num_ps_tasks)
+  layer_collection = lc.LayerCollection()
+  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
+    loss, accuracy = build_model(
+        examples, labels, num_labels=10, layer_collection=layer_collection)

   # Fit model.
   checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
   return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks,
-                                   master, checkpoint_dir, loss, statistics,
+                                   master, checkpoint_dir, loss, accuracy,
                                    layer_collection)
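replica_device_setter is what restores the old num_ps_tasks behavior at the call site: variables created under it are assigned round-robin to parameter-server tasks, and with zero tasks they stay local, as the removed docstring noted. A sketch using the standard tf.train API, not code from this commit:

examples = tf.placeholder(tf.float32, [None, 784])
with tf.device(tf.train.replica_device_setter(2)):  # 2 parameter-server tasks
  w = tf.get_variable("w", shape=[784, 10])  # placed on a PS task
  logits = tf.matmul(examples, w)            # stays on the worker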

File 2 of 3

@@ -33,7 +33,12 @@ FLAGS = None

 def main(argv):
   _ = argv
-  convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
+  if FLAGS.num_towers > 1:
+    convnet.train_mnist_multitower(
+        FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
+  else:
+    convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)


 if __name__ == "__main__":
@@ -43,5 +48,10 @@ if __name__ == "__main__":
       type=str,
       default="/tmp/mnist",
       help="Directory to store dataset in.")
+  parser.add_argument(
+      "--num_towers",
+      type=int,
+      default=1,
+      help="Number of CPUs to split minibatch across.")
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

File 3 of 3

@@ -66,8 +66,9 @@ class ConvNetTest(tf.test.TestCase):
     with tf.Graph().as_default():
       x = tf.placeholder(tf.float32, [None, 6, 6, 3])
       y = tf.placeholder(tf.int64, [None])
-      loss, statistics, layer_collection = convnet.build_model(
-          x, y, num_labels=5)
+      layer_collection = lc.LayerCollection()
+      loss, accuracy = convnet.build_model(
+          x, y, num_labels=5, layer_collection=layer_collection)

       # Ensure layers and logits were registered.
       self.assertEqual(len(layer_collection.fisher_blocks), 3)
@@ -80,7 +81,7 @@ class ConvNetTest(tf.test.TestCase):
             x: np.random.randn(10, 6, 6, 3).astype(np.float32),
             y: np.random.randint(5, size=10).astype(np.int64),
         }
-        sess.run([loss, statistics], feed_dict=feed_dict)
+        sess.run([loss, accuracy], feed_dict=feed_dict)

   def _build_toy_problem(self):
     """Construct a toy linear regression problem.
@@ -90,8 +91,7 @@ class ConvNetTest(tf.test.TestCase):

     Returns:
       loss: 0-D Tensor representing loss to be minimized.
-      statistics: dict mapping strings to Tensors. Additional model evaluation
-        statistics.
+      accuracy: 0-D Tensor representing model accuracy.
       layer_collection: LayerCollection instance describing model architecture.
     """
     x = np.asarray([[1.], [2.]]).astype(np.float32)
@@ -101,34 +101,34 @@ class ConvNetTest(tf.test.TestCase):
     w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
     y_hat = tf.matmul(x, w)
     loss = tf.reduce_mean(0.5 * tf.square(y_hat - y))
-    statistics = {"loss": loss}
+    accuracy = loss

     layer_collection = lc.LayerCollection()
     layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat)
     layer_collection.register_normal_predictive_distribution(y_hat)

-    return loss, statistics, layer_collection
+    return loss, accuracy, layer_collection

   def testMinimizeLossSingleMachine(self):
     with tf.Graph().as_default():
-      loss, statistics, layer_collection = self._build_toy_problem()
-      statistics_ = convnet.minimize_loss_single_machine(
-          loss, statistics, layer_collection)
-      self.assertLess(statistics_["loss"], 1.0)
+      loss, accuracy, layer_collection = self._build_toy_problem()
+      accuracy_ = convnet.minimize_loss_single_machine(loss, accuracy,
+                                                       layer_collection)
+      self.assertLess(accuracy_, 1.0)

   def testMinimizeLossDistributed(self):
     with tf.Graph().as_default():
-      loss, statistics, layer_collection = self._build_toy_problem()
-      statistics_ = convnet.minimize_loss_distributed(
+      loss, accuracy, layer_collection = self._build_toy_problem()
+      accuracy_ = convnet.minimize_loss_distributed(
           task_id=0,
           num_worker_tasks=1,
           num_ps_tasks=0,
           master="",
           checkpoint_dir=None,
           loss=loss,
-          statistics=statistics,
+          accuracy=accuracy,
           layer_collection=layer_collection)
-      self.assertLess(statistics_["loss"], 1.0)
+      self.assertLess(accuracy_, 1.0)

   def testTrainMnistSingleMachine(self):
     with tf.Graph().as_default():
@@ -140,6 +140,12 @@ class ConvNetTest(tf.test.TestCase):
       convnet.train_mnist_single_machine(
           data_dir=None, num_epochs=1, use_fake_data=True)

+  def testTrainMnistMultitower(self):
+    with tf.Graph().as_default():
+      # Ensure model training doesn't crash.
+      convnet.train_mnist_multitower(
+          data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
+
   def testTrainMnistDistributed(self):
     with tf.Graph().as_default():
       # Ensure model training doesn't crash.