Merge pull request #21709 from saeta/fix_tpu

Cherry-picks for Keras on Cloud TPUs
Authored by Amit Patankar on 2018-08-23 12:58:44 -07:00, committed via GitHub
commit 4dcfddc5d1


@@ -54,7 +54,7 @@ import time
 import numpy as np
-from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver
+from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
 from tensorflow.contrib.framework.python.framework import experimental
 from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
 from tensorflow.contrib.tpu.python.ops import tpu_ops
@@ -80,12 +80,54 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_inspect
+
+
+_SESSIONS = {}
+
+
+def tpu_session(cluster_resolver):
+  """Construct or return a `tf.Session` connected to the given cluster."""
+  global _SESSIONS
+  master = cluster_resolver.master()
+  if master not in _SESSIONS:
+    cluster_spec = cluster_resolver.cluster_spec()
+    config = config_pb2.ConfigProto(isolate_session_state=True)
+    if cluster_spec:
+      config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+    graph = ops.Graph()
+    session = tf_session.Session(graph=graph, target=master, config=config)
+    with graph.as_default():
+      session.run(tpu.initialize_system())
+    _SESSIONS[master] = session
+  return _SESSIONS[master]
+
+
+def reset_tpu_sessions():
+  _SESSIONS.clear()
+
+
 # Work-around dependency cycle between DistributionStrategy and TPU lib.
-def TPUDistributionStrategy(*args, **kw):  # pylint: disable=invalid-name
+def TPUDistributionStrategy(tpu_cluster_resolver=None):  # pylint: disable=invalid-name
   """Construct a TPUDistributionStrategy."""
   from tensorflow.contrib.distribute.python import tpu_strategy  # pylint: disable=g-import-not-at-top
-  return tpu_strategy.TPUStrategy(*args, **kw)
+  # TODO -- remove this when TPUStrategy API is consistent (b/112705069)
+  if tpu_cluster_resolver is None:
+    tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
+
+  args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__)
+  if len(args) == 3:
+    logging.info('Detected new TPUStrategy API.')
+    return tpu_strategy.TPUStrategy(tpu_cluster_resolver, steps_per_run=1)
+  else:
+    logging.info('Detected old TPUStrategy API.')
+    strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8)
+    strategy._tpu_cluster_resolver = tpu_cluster_resolver
+    return strategy
+
+
 class TPUEmbedding(embeddings.Embedding):
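The compatibility shim above picks between the two TPUStrategy generations by inspecting the constructor's argument list. Below is a standalone sketch of that arity check, runnable without a TPU; OldStyle, NewStyle, and make_strategy are made-up stand-ins, and inspect.getfullargspec plays the role that tf_inspect.getargspec plays in the patch.

import inspect


class OldStyle(object):
  """Stand-in for the old TPUStrategy: __init__(self, num_cores_per_host)."""

  def __init__(self, num_cores_per_host=8):
    self.num_cores_per_host = num_cores_per_host


class NewStyle(object):
  """Stand-in for the new TPUStrategy: __init__(self, resolver, steps_per_run)."""

  def __init__(self, cluster_resolver, steps_per_run=1):
    self.cluster_resolver = cluster_resolver
    self.steps_per_run = steps_per_run


def make_strategy(strategy_cls, resolver):
  # Three positional args (self, cluster_resolver, steps_per_run) marks the
  # new API; anything else falls back to the old constructor.
  args = inspect.getfullargspec(strategy_cls.__init__).args
  if len(args) == 3:
    return strategy_cls(resolver, steps_per_run=1)
  return strategy_cls(num_cores_per_host=8)


print(type(make_strategy(NewStyle, object())).__name__)  # -> NewStyle
print(type(make_strategy(OldStyle, object())).__name__)  # -> OldStyle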
@@ -663,9 +705,10 @@ class TPUFunction(object):
     # Clone our CPU model, running within the TPU device context.
     with TPURewriteContext(tpu_input_map):
-      # TODO(power): Replicate variables.
-      with ops.device('/device:TPU:0'):
-        self._cloned_model = models.clone_model(self.model)
+      with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
+        # TODO(power): Replicate variables.
+        with ops.device('/device:TPU:0'):
+          self._cloned_model = models.clone_model(self.model)
 
     # Create a copy of the optimizer for this graph.
     if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
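Wrapping the clone in a variable scope keyed by id(self.model) keeps one cloned model's variables from colliding with any other model cloned into the same graph. A minimal TF 1.x graph-mode sketch of that naming effect, using plain get_variable calls in place of models.clone_model (the 'kernel' variable and the throwaway objects are illustrative only):

import tensorflow as tf  # assumes a TF 1.x runtime, matching this file

graph = tf.Graph()
with graph.as_default():
  model_a, model_b = object(), object()  # stand-ins for two Keras models
  for model_like in (model_a, model_b):
    # Without the enclosing scope, the second get_variable('kernel') call in
    # this graph would raise "ValueError: Variable kernel already exists".
    with tf.variable_scope('tpu_model_%s' % id(model_like)):
      kernel = tf.get_variable('kernel', shape=[2, 2])
      print(kernel.name)  # tpu_model_<id>/kernel:0, unique per scope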
@@ -841,7 +884,7 @@
 class KerasTPUModel(models.Model):
   """TPU compatible Keras model wrapper."""
 
-  def __init__(self, cpu_model, tpu_name_or_address, strategy):
+  def __init__(self, cpu_model, strategy):
     super(models.Model, self).__init__(  # pylint: disable=bad-super-call
         inputs=cpu_model.inputs,
         outputs=cpu_model.outputs,
@@ -858,27 +901,14 @@ class KerasTPUModel(models.Model):
     self.train_function = None
     self._strategy = strategy
 
-    self._tpu_name_or_address = tpu_name_or_address
+    cluster_resolver = self._strategy._tpu_cluster_resolver
+    self._tpu_name_or_address = cluster_resolver.get_master()
     self._cpu_model = cpu_model
     self._tpu_model = None
     self._tpu_weights_initialized = False
-    self._graph = ops.Graph()
-    self._cluster_resolver = tpu_cluster_resolver.TPUClusterResolver(
-        tpu_name_or_address)
-    master = self._cluster_resolver.master()
-    cluster_spec = self._cluster_resolver.cluster_spec()
-    self._session = tf_session.Session(
-        graph=self._graph,
-        target=master,
-        config=config_pb2.ConfigProto(isolate_session_state=True))
-    # TODO(saeta): Confirm the lines below work in ClusterSpec propagation env.
-    if cluster_spec:
-      self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
-    with self._graph.as_default():
-      self._session.run(tpu.initialize_system())
+
+    self._session = tpu_session(cluster_resolver)
+    self._graph = self._session.graph
 
     # If the input CPU model has already been compiled, compile our TPU model
     # immediately.
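With this change, every KerasTPUModel that resolves to the same master shares one initialized session from the _SESSIONS cache instead of building (and re-initializing) its own. A pure-Python sketch of that caching behaviour, runnable without a TPU; FakeSession, FakeResolver, and the grpc address are hypothetical stand-ins for tf.Session, TPUClusterResolver, and a real TPU endpoint:

_CACHE = {}


class FakeSession(object):
  """Stand-in for a tf.Session that has run tpu.initialize_system()."""


class FakeResolver(object):
  """Stand-in for a TPUClusterResolver; only master() matters here."""

  def __init__(self, master):
    self._master = master

  def master(self):
    return self._master


def cached_session(resolver):
  master = resolver.master()
  if master not in _CACHE:
    _CACHE[master] = FakeSession()  # the real helper also initializes the TPU here
  return _CACHE[master]


a = cached_session(FakeResolver('grpc://10.240.1.2:8470'))
b = cached_session(FakeResolver('grpc://10.240.1.2:8470'))
assert a is b          # same master -> the session (and TPU init) is reused
_CACHE.clear()         # mirrors reset_tpu_sessions()
assert cached_session(FakeResolver('grpc://10.240.1.2:8470')) is not a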
@@ -1133,7 +1163,7 @@ Output shape: %(output_shape)s
 @experimental
-def tpu_model(model, tpu_name_or_address=None, strategy=None):
+def tpu_model(model, strategy=None):
   """Copy `model` along with weights to the TPU. Returns a TPU model.
 
   Usage:
@@ -1144,7 +1174,7 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
   # If `num_cores_per_host` is greater than one, batch parallelism will be used
   # to run on multiple TPU cores.
-  strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8)
+  strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
   model = keras_support.tpu_model(model, strategy)
   model.compile(
       optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
@@ -1154,10 +1184,6 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
   Args:
     model: A `KerasTPUModel`.
-    tpu_name_or_address: A string that is either the name of the Cloud TPU,
-      the grpc address of the Cloud TPU, or (Googlers only) the BNS name of the
-      Cloud TPU. If tpu_name_or_address is None, the TPUClusterResolver will
-      examine the environment to determine a potential Cloud TPU to use.
     strategy: `TPUDistributionStrategy`. The strategy to use for replicating
       model across multiple TPU cores.
@@ -1172,9 +1198,8 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
   # TODO(xiejw): Validate TPU model. TPUModel only?
   # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
   # TODO(xiejw): Adds reduction option.
   if strategy is None:
-    strategy = TPUDistributionStrategy(num_cores_per_host=1)
-  return KerasTPUModel(
-      cpu_model=model,
-      tpu_name_or_address=tpu_name_or_address,
-      strategy=strategy)
+    strategy = TPUDistributionStrategy()
+  return KerasTPUModel(cpu_model=model, strategy=strategy)
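Putting the pieces together, here is a minimal end-to-end sketch of the updated calling convention described in the revised docstring. It assumes a TF 1.x runtime with tf.contrib, a reachable Cloud TPU (the name 'my-tpu' is hypothetical; passing '' would let the resolver probe the environment, as in the new default branch), and that keras_support is importable from tensorflow.contrib.tpu.python.tpu:

import numpy as np
import tensorflow as tf
from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
from tensorflow.contrib.tpu.python.tpu import keras_support  # assumed module path

resolver = tpu_cluster_resolver_lib.TPUClusterResolver('my-tpu')  # hypothetical TPU name
strategy = keras_support.TPUDistributionStrategy(resolver)        # resolver, not num_cores_per_host

inputs = tf.keras.layers.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1)(tf.keras.layers.Dense(8, activation='relu')(inputs))
cpu_model = tf.keras.Model(inputs, outputs)

model = keras_support.tpu_model(cpu_model, strategy=strategy)     # tpu_name_or_address is gone
model.compile(
    optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
    loss='mse')
model.fit(np.random.rand(32, 4), np.random.rand(32, 1), batch_size=8, epochs=1)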