K-FAC: Multi-tower ConvNet example.
PiperOrigin-RevId: 173982527
parent 2ba5298565
commit bb7ed1c889
@@ -83,7 +83,7 @@ def conv_layer(layer_id, inputs, kernel_size, out_channels):
   activations = tf.nn.relu(preactivations)
 
   # layer.weights is a list. This converts it to a (hashable) tuple.
-  return preactivations, activations, tuple(layer.weights)
+  return preactivations, activations, (layer.kernel, layer.bias)
 
 
 def max_pool_layer(layer_id, inputs, kernel_size, stride):
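
Why the new return value: `layer.weights` is a Python list, so the old code wrapped it in `tuple(...)` to get something hashable; returning `(layer.kernel, layer.bias)` stays hashable while making the (kernel, bias) ordering explicit for the `register_conv2d` calls later in the file. A minimal sketch of the pattern, assuming TF 1.x's `tf.layers` API (the input shape here is illustrative, not from the diff):

import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, 28, 28, 1])
layer = tf.layers.Conv2D(16, kernel_size=[5, 5], padding="SAME", name="conv_0")
preactivations = layer(inputs)
# Explicit (kernel, bias) ordering; hashable, unlike the list layer.weights.
params = (layer.kernel, layer.bias)
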
@@ -128,7 +128,7 @@ def linear_layer(layer_id, inputs, output_size):
   return pre, params
 
 
-def build_model(examples, labels, num_labels, num_ps_tasks=0):
+def build_model(examples, labels, num_labels, layer_collection):
   """Builds a ConvNet classification model.
 
   Args:
@@ -137,65 +137,64 @@ def build_model(examples, labels, num_labels, num_ps_tasks=0):
     labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
       by softmax for each example.
     num_labels: int. Number of distinct values 'labels' can take on.
-    num_ps_tasks: int. Number of parameter servers. If zero, variables
-      will be placed locally.
+    layer_collection: LayerCollection instance. Layers will be registered here.
 
   Returns:
     loss: 0-D Tensor representing loss to be minimized.
-    statistics: dict mapping strings to Tensors. Additional model evaluation
-      statistics.
-    layer_collection: LayerCollection instance describing model architecture.
+    accuracy: 0-D Tensor representing model's accuracy.
   """
-  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
-    # Build a ConvNet. For each layer with parameters, we'll keep track of the
-    # preactivations, activations, weights, and bias.
-    tf.logging.info("Building model.")
-    pre0, act0, params0 = conv_layer(
-        layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
-    act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
-    pre2, act2, params2 = conv_layer(
-        layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
-    act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
-    flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
-    logits, params4 = linear_layer(
-        layer_id=4, inputs=flat_act3, output_size=num_labels)
-    loss = tf.reduce_mean(
-        tf.nn.sparse_softmax_cross_entropy_with_logits(
-            labels=labels, logits=logits))
-    accuracy = tf.reduce_mean(
-        tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
+  # Build a ConvNet. For each layer with parameters, we'll keep track of the
+  # preactivations, activations, weights, and bias.
+  tf.logging.info("Building model.")
+  pre0, act0, params0 = conv_layer(
+      layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
+  act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
+  pre2, act2, params2 = conv_layer(
+      layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
+  act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
+  flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
+  logits, params4 = linear_layer(
+      layer_id=4, inputs=flat_act3, output_size=num_labels)
+  loss = tf.reduce_mean(
+      tf.nn.sparse_softmax_cross_entropy_with_logits(
+          labels=labels, logits=logits))
+  accuracy = tf.reduce_mean(
+      tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
 
-    tf.summary.scalar("loss", loss)
-    tf.summary.scalar("accuracy", accuracy)
+  tf.summary.scalar("loss", loss)
+  tf.summary.scalar("accuracy", accuracy)
 
-    # Register parameters. K-FAC needs to know about the inputs, outputs, and
-    # parameters of each conv/fully connected layer and the logits powering the
-    # posterior probability over classes.
-    tf.logging.info("Building KFAC Optimizer.")
-    layer_collection = lc.LayerCollection()
-    layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
-                                     pre0)
-    layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
-    layer_collection.register_fully_connected(params4, flat_act3, logits)
-    layer_collection.register_categorical_predictive_distribution(logits)
+  # Register parameters. K-FAC needs to know about the inputs, outputs, and
+  # parameters of each conv/fully connected layer and the logits powering the
+  # posterior probability over classes.
+  tf.logging.info("Building LayerCollection.")
+  layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
+                                   pre0)
+  layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
+  layer_collection.register_fully_connected(params4, flat_act3, logits)
+  layer_collection.register_categorical_predictive_distribution(
+      logits, name="logits")
 
-  return loss, {"accuracy": accuracy}, layer_collection
+  return loss, accuracy
 
 
-def minimize_loss_single_machine(loss, statistics, layer_collection):
+def minimize_loss_single_machine(loss,
+                                 accuracy,
+                                 layer_collection,
+                                 session_config=None):
   """Minimize loss with K-FAC on a single machine.
 
   A single Session is responsible for running all of K-FAC's ops.
 
   Args:
     loss: 0-D Tensor. Loss to be minimized.
-    statistics: dict mapping strings to 0-D Tensors. Additional statistics to
-      run with each step.
+    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
     layer_collection: LayerCollection instance describing model architecture.
       Used by K-FAC to construct preconditioner.
+    session_config: None or tf.ConfigProto. Configuration for tf.Session().
 
   Returns:
-    final value for 'statistics'.
+    final value for 'accuracy'.
   """
   # Train with K-FAC.
   global_step = tf.train.get_or_create_global_step()
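
The signature change above inverts ownership of the LayerCollection: `build_model` no longer creates and returns one; the caller supplies it, so several towers can register their layers into a single shared collection. A minimal sketch of the new calling convention, assuming the surrounding file's imports, the `build_model` defined above, and the file's `lc` alias for `tf.contrib.kfac.layer_collection` (placeholder shapes are illustrative):

examples = tf.placeholder(tf.float32, [None, 28, 28, 1])
labels = tf.placeholder(tf.int64, [None])

layer_collection = lc.LayerCollection()  # the caller owns the collection now
loss, accuracy = build_model(
    examples, labels, num_labels=10, layer_collection=layer_collection)
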
@@ -208,19 +207,19 @@ def minimize_loss_single_machine(loss, statistics, layer_collection):
   train_op = optimizer.minimize(loss, global_step=global_step)
 
   tf.logging.info("Starting training.")
-  with tf.train.MonitoredTrainingSession() as sess:
+  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
     while not sess.should_stop():
-      global_step_, loss_, statistics_, _, _ = sess.run(
-          [global_step, loss, statistics, train_op, optimizer.cov_update_op])
+      global_step_, loss_, accuracy_, _, _ = sess.run(
+          [global_step, loss, accuracy, train_op, optimizer.cov_update_op])
 
       if global_step_ % 100 == 0:
         sess.run(optimizer.inv_update_op)
 
       if global_step_ % 100 == 0:
-        tf.logging.info("global_step: %d | loss: %f | %s", global_step_, loss_,
-                        statistics_)
+        tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
+                        global_step_, loss_, accuracy_)
 
-  return statistics_
+  return accuracy_
 
 
 def _is_gradient_task(task_id, num_tasks):
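
For orientation, the loop above amortizes K-FAC's cost: the cheap covariance-statistics update (`cov_update_op`) runs on every step alongside the gradient step, while the expensive Fisher-inverse refresh (`inv_update_op`) runs only every 100 steps. A condensed sketch of the same schedule, assuming the file's `opt` alias for the K-FAC optimizer module (hyperparameter values are illustrative assumptions, not taken from this diff):

global_step = tf.train.get_or_create_global_step()
optimizer = opt.KfacOptimizer(
    learning_rate=0.0001,
    cov_ema_decay=0.95,
    damping=0.001,
    layer_collection=layer_collection,
    momentum=0.9)
train_op = optimizer.minimize(loss, global_step=global_step)

with tf.train.MonitoredTrainingSession(config=session_config) as sess:
  while not sess.should_stop():
    # Gradient step plus cheap covariance update, every iteration.
    global_step_, _, _ = sess.run(
        [global_step, train_op, optimizer.cov_update_op])
    if global_step_ % 100 == 0:
      # Expensive inverse refresh, only periodically.
      sess.run(optimizer.inv_update_op)
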
@@ -252,8 +251,7 @@ def _num_gradient_tasks(num_tasks):
 
 
 def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
-                              checkpoint_dir, loss, statistics,
-                              layer_collection):
+                              checkpoint_dir, loss, accuracy, layer_collection):
   """Minimize loss with a synchronous implementation of K-FAC.
 
   Different tasks are responsible for different parts of K-FAC's ops. The first
@@ -269,13 +267,13 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
       string to run locally.
     checkpoint_dir: string or None. Path to store checkpoints under.
     loss: 0-D Tensor. Loss to be minimized.
-    statistics: dict mapping strings to 0-D Tensors. Additional statistics to
-      run with each step.
+    accuracy: 0-D Tensor. Accuracy of classifier on current
+      minibatch.
     layer_collection: LayerCollection instance describing model architecture.
       Used by K-FAC to construct preconditioner.
 
   Returns:
-    final value for 'statistics'.
+    final value for 'accuracy'.
 
   Raises:
     ValueError: if task_id >= num_worker_tasks.
@@ -318,12 +316,12 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
       else:
         raise ValueError("Which op should task %d do?" % task_id)
 
-      global_step_, loss_, statistics_, _ = sess.run(
-          [global_step, loss, statistics, learning_op])
-      tf.logging.info("global_step: %d | loss: %f | %s", global_step_, loss_,
-                      statistics_)
+      global_step_, loss_, accuracy_, _ = sess.run(
+          [global_step, loss, accuracy, learning_op])
+      tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
+                      loss_, accuracy_)
 
-  return statistics_
+  return accuracy_
 
 
 def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
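
Each worker runs exactly one op family per iteration, selected by task id. Only `_is_gradient_task` appears in this diff; a sketch of the dispatch feeding the `else` branch above, where the other two predicate names are assumptions following the same naming pattern:

if _is_gradient_task(task_id, num_worker_tasks):
  learning_op = train_op
elif _is_cov_update_task(task_id, num_worker_tasks):  # assumed helper
  learning_op = optimizer.cov_update_op
elif _is_inv_update_task(task_id, num_worker_tasks):  # assumed helper
  learning_op = optimizer.inv_update_op
else:
  raise ValueError("Which op should task %d do?" % task_id)
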
@@ -347,11 +345,69 @@ def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
       flatten_images=False)
 
   # Build a ConvNet.
-  loss, statistics, layer_collection = build_model(
-      examples, labels, num_labels=10)
+  layer_collection = lc.LayerCollection()
+  loss, accuracy = build_model(
+      examples, labels, num_labels=10, layer_collection=layer_collection)
 
   # Fit model.
-  return minimize_loss_single_machine(loss, statistics, layer_collection)
+  return minimize_loss_single_machine(loss, accuracy, layer_collection)
+
+
+def train_mnist_multitower(data_dir, num_epochs, num_towers,
+                           use_fake_data=True):
+  """Train a ConvNet on MNIST.
+
+  Args:
+    data_dir: string. Directory to read MNIST examples from.
+    num_epochs: int. Number of passes to make over the training set.
+    num_towers: int. Number of CPUs to split inference across.
+    use_fake_data: bool. If True, generate a synthetic dataset.
+
+  Returns:
+    accuracy of model on the final minibatch of training data.
+  """
+  # Load a dataset.
+  tower_batch_size = 128
+  batch_size = tower_batch_size * num_towers
+  tf.logging.info(
+      ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
+       "tower batch size.") % (batch_size, num_towers, tower_batch_size))
+  examples, labels = mnist.load_mnist(
+      data_dir,
+      num_epochs=num_epochs,
+      batch_size=batch_size,
+      use_fake_data=use_fake_data,
+      flatten_images=False)
+
+  # Split minibatch across towers.
+  examples = tf.split(examples, num_towers)
+  labels = tf.split(labels, num_towers)
+
+  # Build a ConvNet. Each tower's layers will be added to the LayerCollection.
+  layer_collection = lc.LayerCollection()
+  tower_results = []
+  for tower_id in range(num_towers):
+    with tf.device("/cpu:%d" % tower_id):
+      with tf.name_scope("tower%d" % tower_id):
+        with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
+          tf.logging.info("Building tower %d." % tower_id)
+          tower_results.append(
+              build_model(examples[tower_id], labels[tower_id], 10,
+                          layer_collection))
+  losses, accuracies = zip(*tower_results)
+
+  # Average across towers.
+  loss = tf.reduce_mean(losses)
+  accuracy = tf.reduce_mean(accuracies)
+
+  # Fit model.
+  session_config = tf.ConfigProto(
+      allow_soft_placement=False, device_count={
+          "CPU": num_towers
+      })
+  return minimize_loss_single_machine(
+      loss, accuracy, layer_collection, session_config=session_config)
 
 
 def train_mnist_distributed(task_id,
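
Two details in the new multi-tower path are worth noting. First, every tower reuses one set of variables via `reuse=(tower_id > 0)` while registering its own activations with the shared LayerCollection, so K-FAC accumulates statistics from all towers. Second, TensorFlow exposes only a single `/cpu:0` device by default, so placing towers on `/cpu:1` and beyond requires raising `device_count`; a minimal sketch of that session config:

session_config = tf.ConfigProto(
    allow_soft_placement=False,        # fail loudly on a bad device placement
    device_count={"CPU": num_towers})  # expose one CPU device per tower
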
@@ -385,13 +441,15 @@ def train_mnist_distributed(task_id,
       flatten_images=False)
 
   # Build a ConvNet.
-  loss, statistics, layer_collection = build_model(
-      examples, labels, num_labels=10, num_ps_tasks=num_ps_tasks)
+  layer_collection = lc.LayerCollection()
+  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
+    loss, accuracy = build_model(
+        examples, labels, num_labels=10, layer_collection=layer_collection)
 
   # Fit model.
   checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
   return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks,
-                                   master, checkpoint_dir, loss, statistics,
+                                   master, checkpoint_dir, loss, accuracy,
                                    layer_collection)
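
Device placement also moved out of `build_model`: the distributed trainer now wraps the call in `tf.train.replica_device_setter` itself, so variables land on parameter servers while the single-machine and multi-tower callers build the same graph with no device constraints. A sketch of the distributed call site, restating the hunk above with the placement made explicit in comments:

layer_collection = lc.LayerCollection()
with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
  # Variables go to the parameter servers; ops stay on this worker.
  loss, accuracy = build_model(
      examples, labels, num_labels=10, layer_collection=layer_collection)
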
@@ -33,7 +33,12 @@ FLAGS = None
 
 def main(argv):
   _ = argv
-  convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
+  if FLAGS.num_towers > 1:
+    convnet.train_mnist_multitower(
+        FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
+  else:
+    convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
 
 
 if __name__ == "__main__":
@@ -43,5 +48,10 @@ if __name__ == "__main__":
       type=str,
       default="/tmp/mnist",
       help="Directory to store dataset in.")
+  parser.add_argument(
+      "--num_towers",
+      type=int,
+      default=1,
+      help="Number of CPUs to split minibatch across.")
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
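
The new flag follows the entry point's existing argparse pattern: known flags populate the global `FLAGS`, and `tf.app.run` re-injects the unparsed remainder. A self-contained sketch of how the dispatch in `main` consumes it (flag values here are illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_towers", type=int, default=1,
                    help="Number of CPUs to split minibatch across.")
FLAGS, unparsed = parser.parse_known_args(["--num_towers", "2"])
assert FLAGS.num_towers == 2  # would route main() to train_mnist_multitower
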
@@ -66,8 +66,9 @@ class ConvNetTest(tf.test.TestCase):
     with tf.Graph().as_default():
       x = tf.placeholder(tf.float32, [None, 6, 6, 3])
       y = tf.placeholder(tf.int64, [None])
-      loss, statistics, layer_collection = convnet.build_model(
-          x, y, num_labels=5)
+      layer_collection = lc.LayerCollection()
+      loss, accuracy = convnet.build_model(
+          x, y, num_labels=5, layer_collection=layer_collection)
 
       # Ensure layers and logits were registered.
       self.assertEqual(len(layer_collection.fisher_blocks), 3)
@@ -80,7 +81,7 @@ class ConvNetTest(tf.test.TestCase):
           x: np.random.randn(10, 6, 6, 3).astype(np.float32),
           y: np.random.randint(5, size=10).astype(np.int64),
       }
-      sess.run([loss, statistics], feed_dict=feed_dict)
+      sess.run([loss, accuracy], feed_dict=feed_dict)
 
   def _build_toy_problem(self):
     """Construct a toy linear regression problem.
@@ -90,8 +91,7 @@ class ConvNetTest(tf.test.TestCase):
 
     Returns:
       loss: 0-D Tensor representing loss to be minimized.
-      statistics: dict mapping strings to Tensors. Additional model evaluation
-        statistics.
+      accuracy: 0-D Tensor representing model accuracy.
       layer_collection: LayerCollection instance describing model architecture.
     """
     x = np.asarray([[1.], [2.]]).astype(np.float32)
@@ -101,34 +101,34 @@ class ConvNetTest(tf.test.TestCase):
     w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
     y_hat = tf.matmul(x, w)
     loss = tf.reduce_mean(0.5 * tf.square(y_hat - y))
-    statistics = {"loss": loss}
+    accuracy = loss
 
     layer_collection = lc.LayerCollection()
     layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat)
     layer_collection.register_normal_predictive_distribution(y_hat)
 
-    return loss, statistics, layer_collection
+    return loss, accuracy, layer_collection
 
   def testMinimizeLossSingleMachine(self):
     with tf.Graph().as_default():
-      loss, statistics, layer_collection = self._build_toy_problem()
-      statistics_ = convnet.minimize_loss_single_machine(
-          loss, statistics, layer_collection)
-      self.assertLess(statistics_["loss"], 1.0)
+      loss, accuracy, layer_collection = self._build_toy_problem()
+      accuracy_ = convnet.minimize_loss_single_machine(loss, accuracy,
+                                                       layer_collection)
+      self.assertLess(accuracy_, 1.0)
 
   def testMinimizeLossDistributed(self):
     with tf.Graph().as_default():
-      loss, statistics, layer_collection = self._build_toy_problem()
-      statistics_ = convnet.minimize_loss_distributed(
+      loss, accuracy, layer_collection = self._build_toy_problem()
+      accuracy_ = convnet.minimize_loss_distributed(
           task_id=0,
          num_worker_tasks=1,
          num_ps_tasks=0,
          master="",
          checkpoint_dir=None,
          loss=loss,
-          statistics=statistics,
+          accuracy=accuracy,
          layer_collection=layer_collection)
-      self.assertLess(statistics_["loss"], 1.0)
+      self.assertLess(accuracy_, 1.0)
 
   def testTrainMnistSingleMachine(self):
     with tf.Graph().as_default():
@@ -140,6 +140,12 @@ class ConvNetTest(tf.test.TestCase):
       convnet.train_mnist_single_machine(
          data_dir=None, num_epochs=1, use_fake_data=True)
 
+  def testTrainMnistMultitower(self):
+    with tf.Graph().as_default():
+      # Ensure model training doesn't crash.
+      convnet.train_mnist_multitower(
+          data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
+
   def testTrainMnistDistributed(self):
     with tf.Graph().as_default():
       # Ensure model training doesn't crash.
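
Note how the toy problem keeps the old assertions meaningful: it is a regression task with no classification accuracy, so the test aliases `accuracy` to the loss and `assertLess(accuracy_, 1.0)` still checks convergence. A self-contained sketch of the toy graph it builds (the target values are assumptions; only `x` appears in this diff):

import numpy as np
import tensorflow as tf

x = np.asarray([[1.], [2.]]).astype(np.float32)
y = np.asarray([[3.], [4.]]).astype(np.float32)  # assumed targets
w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
y_hat = tf.matmul(x, w)
loss = tf.reduce_mean(0.5 * tf.square(y_hat - y))
accuracy = loss  # stand-in "accuracy" for a regression problem
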