K-FAC: Multi-tower ConvNet example.

PiperOrigin-RevId: 173982527

commit bb7ed1c889
parent 2ba5298565
@@ -83,7 +83,7 @@ def conv_layer(layer_id, inputs, kernel_size, out_channels):
   activations = tf.nn.relu(preactivations)
 
   # layer.weights is a list. This converts it to a (hashable) tuple.
-  return preactivations, activations, tuple(layer.weights)
+  return preactivations, activations, (layer.kernel, layer.bias)
 
 
 def max_pool_layer(layer_id, inputs, kernel_size, stride):
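Note: the hunk above swaps tuple(layer.weights) for an explicit (layer.kernel, layer.bias) pair. A minimal sketch of the distinction, assuming `layer` is the tf.layers.Conv2D object built earlier in conv_layer (that construction sits outside this hunk, so treat the details as illustrative):

    import tensorflow as tf

    layer = tf.layers.Conv2D(filters=16, kernel_size=5, padding="same")
    preactivations = layer(tf.zeros([2, 28, 28, 1]))  # builds the variables

    params_by_order = tuple(layer.weights)       # ordering is implicit
    params_by_name = (layer.kernel, layer.bias)  # explicit (kernel, bias) pair

Naming the variables directly guarantees the (kernel, bias) ordering that the register_conv2d calls later in this file rely on.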
@@ -128,7 +128,7 @@ def linear_layer(layer_id, inputs, output_size):
   return pre, params
 
 
-def build_model(examples, labels, num_labels, num_ps_tasks=0):
+def build_model(examples, labels, num_labels, layer_collection):
   """Builds a ConvNet classification model.
 
   Args:
@@ -137,65 +137,64 @@ def build_model(examples, labels, num_labels, num_ps_tasks=0):
     labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
       by softmax for each example.
     num_labels: int. Number of distinct values 'labels' can take on.
-    num_ps_tasks: int. Number of parameter servers. If zero, variables
-      will be placed locally.
+    layer_collection: LayerCollection instance. Layers will be registered here.
 
   Returns:
     loss: 0-D Tensor representing loss to be minimized.
-    statistics: dict mapping strings to Tensors. Additional model evaluation
-      statistics.
-    layer_collection: LayerCollection instance describing model architecture.
+    accuracy: 0-D Tensor representing model's accuracy.
   """
-  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
-    # Build a ConvNet. For each layer with parameters, we'll keep track of the
-    # preactivations, activations, weights, and bias.
-    tf.logging.info("Building model.")
-    pre0, act0, params0 = conv_layer(
-        layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
-    act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
-    pre2, act2, params2 = conv_layer(
-        layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
-    act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
-    flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
-    logits, params4 = linear_layer(
-        layer_id=4, inputs=flat_act3, output_size=num_labels)
-    loss = tf.reduce_mean(
-        tf.nn.sparse_softmax_cross_entropy_with_logits(
-            labels=labels, logits=logits))
-    accuracy = tf.reduce_mean(
-        tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
-
-    tf.summary.scalar("loss", loss)
-    tf.summary.scalar("accuracy", accuracy)
-
-    # Register parameters. K-FAC needs to know about the inputs, outputs, and
-    # parameters of each conv/fully connected layer and the logits powering the
-    # posterior probability over classes.
-    tf.logging.info("Building KFAC Optimizer.")
-    layer_collection = lc.LayerCollection()
-    layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
-                                     pre0)
-    layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
-    layer_collection.register_fully_connected(params4, flat_act3, logits)
-    layer_collection.register_categorical_predictive_distribution(logits)
-
-    return loss, {"accuracy": accuracy}, layer_collection
+  # Build a ConvNet. For each layer with parameters, we'll keep track of the
+  # preactivations, activations, weights, and bias.
+  tf.logging.info("Building model.")
+  pre0, act0, params0 = conv_layer(
+      layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
+  act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
+  pre2, act2, params2 = conv_layer(
+      layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
+  act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
+  flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
+  logits, params4 = linear_layer(
+      layer_id=4, inputs=flat_act3, output_size=num_labels)
+  loss = tf.reduce_mean(
+      tf.nn.sparse_softmax_cross_entropy_with_logits(
+          labels=labels, logits=logits))
+  accuracy = tf.reduce_mean(
+      tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
+
+  tf.summary.scalar("loss", loss)
+  tf.summary.scalar("accuracy", accuracy)
+
+  # Register parameters. K-FAC needs to know about the inputs, outputs, and
+  # parameters of each conv/fully connected layer and the logits powering the
+  # posterior probability over classes.
+  tf.logging.info("Building LayerCollection.")
+  layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
+                                   pre0)
+  layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
+  layer_collection.register_fully_connected(params4, flat_act3, logits)
+  layer_collection.register_categorical_predictive_distribution(
+      logits, name="logits")
+
+  return loss, accuracy
 
 
-def minimize_loss_single_machine(loss, statistics, layer_collection):
+def minimize_loss_single_machine(loss,
+                                 accuracy,
+                                 layer_collection,
+                                 session_config=None):
   """Minimize loss with K-FAC on a single machine.
 
   A single Session is responsible for running all of K-FAC's ops.
 
   Args:
     loss: 0-D Tensor. Loss to be minimized.
-    statistics: dict mapping strings to 0-D Tensors. Additional statistics to
-      run with each step.
+    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
     layer_collection: LayerCollection instance describing model architecture.
       Used by K-FAC to construct preconditioner.
+    session_config: None or tf.ConfigProto. Configuration for tf.Session().
 
   Returns:
-    final value for 'statistics'.
+    final value for 'accuracy'.
   """
   # Train with K-FAC.
   global_step = tf.train.get_or_create_global_step()
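Note: the net effect of this hunk is an inversion of ownership. build_model no longer creates its own LayerCollection (or wraps itself in a replica_device_setter); the caller supplies the collection, so several towers can register their layers into one place, and registering the predictive distribution under a fixed name lets each tower's logits attach to the same loss function instead of colliding. A minimal sketch of the new calling convention (placeholder shapes are illustrative; the `lc` import path is an assumption based on the contrib layout of this era):

    import tensorflow as tf
    from tensorflow.contrib.kfac.python.ops import layer_collection as lc

    examples = tf.placeholder(tf.float32, [None, 28, 28, 1])
    labels = tf.placeholder(tf.int64, [None])

    layer_collection = lc.LayerCollection()  # the caller owns the collection now
    loss, accuracy = build_model(
        examples, labels, num_labels=10, layer_collection=layer_collection)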
@@ -208,19 +207,19 @@ def minimize_loss_single_machine(loss, statistics, layer_collection):
   train_op = optimizer.minimize(loss, global_step=global_step)
 
   tf.logging.info("Starting training.")
-  with tf.train.MonitoredTrainingSession() as sess:
+  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
     while not sess.should_stop():
-      global_step_, loss_, statistics_, _, _ = sess.run(
-          [global_step, loss, statistics, train_op, optimizer.cov_update_op])
+      global_step_, loss_, accuracy_, _, _ = sess.run(
+          [global_step, loss, accuracy, train_op, optimizer.cov_update_op])
 
       if global_step_ % 100 == 0:
         sess.run(optimizer.inv_update_op)
 
       if global_step_ % 100 == 0:
-        tf.logging.info("global_step: %d | loss: %f | %s", global_step_, loss_,
-                        statistics_)
+        tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
+                        global_step_, loss_, accuracy_)
 
-  return statistics_
+  return accuracy_
 
 
 def _is_gradient_task(task_id, num_tasks):
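Note: the loop above runs the cheap covariance update on every step and amortizes the expensive inverse recomputation over every 100th step; the new session_config argument exists so the multi-tower caller below can shape the Session. A sketch of the kind of tf.ConfigProto it is meant to carry (values here are illustrative, not from this diff):

    import tensorflow as tf

    # Expose one CPU device per tower; with soft placement disabled, a missing
    # /cpu:N device fails loudly instead of silently moving the op.
    session_config = tf.ConfigProto(
        allow_soft_placement=False,
        device_count={"CPU": 4})  # e.g. four towers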
@@ -252,8 +251,7 @@ def _num_gradient_tasks(num_tasks):
 
 
 def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
-                              checkpoint_dir, loss, statistics,
-                              layer_collection):
+                              checkpoint_dir, loss, accuracy, layer_collection):
   """Minimize loss with a synchronous implementation of K-FAC.
 
   Different tasks are responsible for different parts of K-FAC's Ops. The first
@@ -269,13 +267,13 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
       string to run locally.
     checkpoint_dir: string or None. Path to store checkpoints under.
     loss: 0-D Tensor. Loss to be minimized.
-    statistics: dict mapping strings to 0-D Tensors. Additional statistics to
-      run with each step.
+    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
     layer_collection: LayerCollection instance describing model architecture.
       Used by K-FAC to construct preconditioner.
 
   Returns:
-    final value for 'statistics'.
+    final value for 'accuracy'.
 
   Raises:
     ValueError: if task_id >= num_worker_tasks.
@@ -318,12 +316,12 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
       else:
         raise ValueError("Which op should task %d do?" % task_id)
 
-      global_step_, loss_, statistics_, _ = sess.run(
-          [global_step, loss, statistics, learning_op])
-      tf.logging.info("global_step: %d | loss: %f | %s", global_step_, loss_,
-                      statistics_)
+      global_step_, loss_, accuracy_, _ = sess.run(
+          [global_step, loss, accuracy, learning_op])
+      tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
+                      loss_, accuracy_)
 
-  return statistics_
+  return accuracy_
 
 
 def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
@@ -347,11 +345,69 @@ def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
       flatten_images=False)
 
   # Build a ConvNet.
-  loss, statistics, layer_collection = build_model(
-      examples, labels, num_labels=10)
+  layer_collection = lc.LayerCollection()
+  loss, accuracy = build_model(
+      examples, labels, num_labels=10, layer_collection=layer_collection)
 
   # Fit model.
-  return minimize_loss_single_machine(loss, statistics, layer_collection)
+  return minimize_loss_single_machine(loss, accuracy, layer_collection)
+
+
+def train_mnist_multitower(data_dir, num_epochs, num_towers,
+                           use_fake_data=True):
+  """Train a ConvNet on MNIST.
+
+  Args:
+    data_dir: string. Directory to read MNIST examples from.
+    num_epochs: int. Number of passes to make over the training set.
+    num_towers: int. Number of CPUs to split inference across.
+    use_fake_data: bool. If True, generate a synthetic dataset.
+
+  Returns:
+    accuracy of model on the final minibatch of training data.
+  """
+  # Load a dataset.
+  tf.logging.info("Loading MNIST into memory.")
+  tower_batch_size = 128
+  batch_size = tower_batch_size * num_towers
+  tf.logging.info(
+      ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
+       "tower batch size.") % (batch_size, num_towers, tower_batch_size))
+  examples, labels = mnist.load_mnist(
+      data_dir,
+      num_epochs=num_epochs,
+      batch_size=batch_size,
+      use_fake_data=use_fake_data,
+      flatten_images=False)
+
+  # Split minibatch across towers.
+  examples = tf.split(examples, num_towers)
+  labels = tf.split(labels, num_towers)
+
+  # Build a ConvNet. Each tower's layers will be added to the LayerCollection.
+  layer_collection = lc.LayerCollection()
+  tower_results = []
+  for tower_id in range(num_towers):
+    with tf.device("/cpu:%d" % tower_id):
+      with tf.name_scope("tower%d" % tower_id):
+        with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
+          tf.logging.info("Building tower %d." % tower_id)
+          tower_results.append(
+              build_model(examples[tower_id], labels[tower_id], 10,
+                          layer_collection))
+  losses, accuracies = zip(*tower_results)
+
+  # Average across towers.
+  loss = tf.reduce_mean(losses)
+  accuracy = tf.reduce_mean(accuracies)
+
+  # Fit model.
+  session_config = tf.ConfigProto(
+      allow_soft_placement=False, device_count={
+          "CPU": num_towers
+      })
+  return minimize_loss_single_machine(
+      loss, accuracy, layer_collection, session_config=session_config)
 
 
 def train_mnist_distributed(task_id,
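Note: reuse=(tower_id > 0) is what makes the towers replicas rather than independent models: every tower after the first rebinds the same named variables, so each build_model call registers the same parameter tensors with the shared LayerCollection, just paired with that tower's own activations. A self-contained toy illustration of the sharing pattern (hypothetical two-tower model, TF 1.x graph API; the names here are not from the diff):

    import tensorflow as tf

    def toy_tower(x):
      # get_variable, not tf.Variable: reuse can only rebind named variables.
      w = tf.get_variable("w", shape=[4, 2])
      return tf.matmul(x, w)

    inputs = tf.split(tf.random_normal([8, 4]), 2)
    outputs = []
    for tower_id in range(2):
      with tf.name_scope("tower%d" % tower_id):
        with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
          outputs.append(toy_tower(inputs[tower_id]))  # tower 1 reuses w

    loss = tf.reduce_mean(tf.square(tf.concat(outputs, axis=0)))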
@@ -385,13 +441,15 @@ def train_mnist_distributed(task_id,
       flatten_images=False)
 
   # Build a ConvNet.
-  loss, statistics, layer_collection = build_model(
-      examples, labels, num_labels=10, num_ps_tasks=num_ps_tasks)
+  layer_collection = lc.LayerCollection()
+  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
+    loss, accuracy = build_model(
+        examples, labels, num_labels=10, layer_collection=layer_collection)
 
   # Fit model.
   checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
   return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks,
-                                   master, checkpoint_dir, loss, statistics,
+                                   master, checkpoint_dir, loss, accuracy,
                                    layer_collection)
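Note: the replica_device_setter moved out of build_model and into this call site, so only the distributed path pays for it. For reference, a self-contained sketch of what that device function does (TF 1.x distributed runtime; the placement comments assume ps_tasks=2):

    import tensorflow as tf

    with tf.device(tf.train.replica_device_setter(ps_tasks=2)):
      v1 = tf.get_variable("v1", shape=[10])  # -> /job:ps/task:0
      v2 = tf.get_variable("v2", shape=[10])  # -> /job:ps/task:1 (round-robin)
      total = v1 + v2                         # non-variable op stays on the worker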
@@ -33,7 +33,12 @@ FLAGS = None
 
 def main(argv):
   _ = argv
-  convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
+  if FLAGS.num_towers > 1:
+    convnet.train_mnist_multitower(
+        FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
+  else:
+    convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
 
 
 if __name__ == "__main__":
@@ -43,5 +48,10 @@ if __name__ == "__main__":
       type=str,
       default="/tmp/mnist",
       help="Directory to store dataset in.")
+  parser.add_argument(
+      "--num_towers",
+      type=int,
+      default=1,
+      help="Number of CPUs to split minibatch across.")
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
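Note: the new flag plugs into the parse_known_args / tf.app.run split this main file already uses: argparse consumes the flags it defines, and anything unrecognized is handed back so TensorFlow's own flags keep working. A stripped-down, runnable restatement of that pattern (hypothetical standalone script, TF 1.x):

    import argparse
    import sys
    import tensorflow as tf

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_towers", type=int, default=1,
                        help="Number of CPUs to split minibatch across.")
    FLAGS, unparsed = parser.parse_known_args()

    def main(argv):
      del argv  # unused beyond tf.app.run's contract
      print("num_towers =", FLAGS.num_towers)

    if __name__ == "__main__":
      tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)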
@@ -66,8 +66,9 @@ class ConvNetTest(tf.test.TestCase):
     with tf.Graph().as_default():
       x = tf.placeholder(tf.float32, [None, 6, 6, 3])
       y = tf.placeholder(tf.int64, [None])
-      loss, statistics, layer_collection = convnet.build_model(
-          x, y, num_labels=5)
+      layer_collection = lc.LayerCollection()
+      loss, accuracy = convnet.build_model(
+          x, y, num_labels=5, layer_collection=layer_collection)
 
       # Ensure layers and logits were registered.
       self.assertEqual(len(layer_collection.fisher_blocks), 3)
@@ -80,7 +81,7 @@ class ConvNetTest(tf.test.TestCase):
           x: np.random.randn(10, 6, 6, 3).astype(np.float32),
           y: np.random.randint(5, size=10).astype(np.int64),
       }
-      sess.run([loss, statistics], feed_dict=feed_dict)
+      sess.run([loss, accuracy], feed_dict=feed_dict)
 
   def _build_toy_problem(self):
     """Construct a toy linear regression problem.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
loss: 0-D Tensor representing loss to be minimized.
|
loss: 0-D Tensor representing loss to be minimized.
|
||||||
statistics: dict mapping strings to Tensors. Additional model evaluation
|
accuracy: 0-D Tensors representing model accuracy.
|
||||||
statistics.
|
|
||||||
layer_collection: LayerCollection instance describing model architecture.
|
layer_collection: LayerCollection instance describing model architecture.
|
||||||
"""
|
"""
|
||||||
x = np.asarray([[1.], [2.]]).astype(np.float32)
|
x = np.asarray([[1.], [2.]]).astype(np.float32)
|
||||||
|
|
@@ -101,34 +101,34 @@ class ConvNetTest(tf.test.TestCase):
     w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
     y_hat = tf.matmul(x, w)
     loss = tf.reduce_mean(0.5 * tf.square(y_hat - y))
-    statistics = {"loss": loss}
+    accuracy = loss
 
     layer_collection = lc.LayerCollection()
     layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat)
     layer_collection.register_normal_predictive_distribution(y_hat)
 
-    return loss, statistics, layer_collection
+    return loss, accuracy, layer_collection
 
   def testMinimizeLossSingleMachine(self):
     with tf.Graph().as_default():
-      loss, statistics, layer_collection = self._build_toy_problem()
-      statistics_ = convnet.minimize_loss_single_machine(
-          loss, statistics, layer_collection)
-      self.assertLess(statistics_["loss"], 1.0)
+      loss, accuracy, layer_collection = self._build_toy_problem()
+      accuracy_ = convnet.minimize_loss_single_machine(loss, accuracy,
+                                                       layer_collection)
+      self.assertLess(accuracy_, 1.0)
 
   def testMinimizeLossDistributed(self):
     with tf.Graph().as_default():
-      loss, statistics, layer_collection = self._build_toy_problem()
-      statistics_ = convnet.minimize_loss_distributed(
+      loss, accuracy, layer_collection = self._build_toy_problem()
+      accuracy_ = convnet.minimize_loss_distributed(
           task_id=0,
           num_worker_tasks=1,
           num_ps_tasks=0,
           master="",
           checkpoint_dir=None,
           loss=loss,
-          statistics=statistics,
+          accuracy=accuracy,
           layer_collection=layer_collection)
-      self.assertLess(statistics_["loss"], 1.0)
+      self.assertLess(accuracy_, 1.0)
 
   def testTrainMnistSingleMachine(self):
     with tf.Graph().as_default():
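Note: the toy problem is a regression, so there is no classification accuracy to report; the tests simply alias accuracy = loss, which is why assertLess(accuracy_, 1.0) amounts to a convergence check on the loss. A self-contained restatement of the toy graph (TF 1.x; the y targets are illustrative, since their definition falls between the hunks shown):

    import numpy as np
    import tensorflow as tf

    x = tf.constant(np.asarray([[1.], [2.]], dtype=np.float32))
    y = tf.constant(np.asarray([[3.], [6.]], dtype=np.float32))  # illustrative targets
    w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
    y_hat = tf.matmul(x, w)

    loss = tf.reduce_mean(0.5 * tf.square(y_hat - y))
    accuracy = loss  # regression: track the loss in place of accuracy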
@@ -140,6 +140,12 @@ class ConvNetTest(tf.test.TestCase):
       convnet.train_mnist_single_machine(
           data_dir=None, num_epochs=1, use_fake_data=True)
 
+  def testTrainMnistMultitower(self):
+    with tf.Graph().as_default():
+      # Ensure model training doesn't crash.
+      convnet.train_mnist_multitower(
+          data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
+
   def testTrainMnistDistributed(self):
     with tf.Graph().as_default():
       # Ensure model training doesn't crash.