Merge pull request #21709 from saeta/fix_tpu
Cherry-picks for Keras on Cloud TPUs
This commit is contained in:
commit 4dcfddc5d1
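For orientation, here is a minimal sketch of the call pattern this change moves Keras-on-Cloud-TPU users to, assembled from the updated `tpu_model` docstring in the diff below. The wrapper function, the default empty TPU name, and the `loss` argument are illustrative assumptions, not part of the PR.

import tensorflow as tf
from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.contrib.tpu.python.tpu import keras_support

def compile_for_tpu(cpu_model, tpu_name=''):
  # An empty TPU name lets the resolver examine the environment for a
  # Cloud TPU, mirroring the default used inside TPUDistributionStrategy.
  resolver = TPUClusterResolver(tpu_name)
  # The strategy is now built from a cluster resolver instead of
  # num_cores_per_host.
  strategy = keras_support.TPUDistributionStrategy(resolver)
  model = keras_support.tpu_model(cpu_model, strategy=strategy)
  model.compile(
      optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
      loss='mse')
  return model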
@@ -54,7 +54,7 @@ import time
 
 import numpy as np
 
-from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver
+from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
 from tensorflow.contrib.framework.python.framework import experimental
 from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
 from tensorflow.contrib.tpu.python.ops import tpu_ops
@@ -80,12 +80,54 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_inspect
+
+
+_SESSIONS = {}
+
+
+def tpu_session(cluster_resolver):
+  """Construct or return a `tf.Session` connected to the given cluster."""
+  global _SESSIONS
+  master = cluster_resolver.master()
+  if master not in _SESSIONS:
+    cluster_spec = cluster_resolver.cluster_spec()
+    config = config_pb2.ConfigProto(isolate_session_state=True)
+    if cluster_spec:
+      config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+
+    graph = ops.Graph()
+    session = tf_session.Session(graph=graph, target=master, config=config)
+
+    with graph.as_default():
+      session.run(tpu.initialize_system())
+
+    _SESSIONS[master] = session
+  return _SESSIONS[master]
+
+
+def reset_tpu_sessions():
+  _SESSIONS.clear()
+
+
 # Work-around dependency cycle between DistributionStrategy and TPU lib.
-def TPUDistributionStrategy(*args, **kw):  # pylint: disable=invalid-name
+def TPUDistributionStrategy(tpu_cluster_resolver=None):  # pylint: disable=invalid-name
   """Construct a TPUDistributionStrategy."""
   from tensorflow.contrib.distribute.python import tpu_strategy  # pylint: disable=g-import-not-at-top
-  return tpu_strategy.TPUStrategy(*args, **kw)
+  # TODO -- remove this when TPUStrategy API is consistent (b/112705069)
+  if tpu_cluster_resolver is None:
+    tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
+
+  args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__)
+  if len(args) == 3:
+    logging.info('Detected new TPUStrategy API.')
+    return tpu_strategy.TPUStrategy(tpu_cluster_resolver, steps_per_run=1)
+  else:
+    logging.info('Detected old TPUStrategy API.')
+    strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8)
+    strategy._tpu_cluster_resolver = tpu_cluster_resolver
+
+    return strategy
 
 
 class TPUEmbedding(embeddings.Embedding):
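The rewritten TPUDistributionStrategy above is a compatibility shim: it inspects TPUStrategy's constructor signature and picks the old or new calling convention accordingly. Below is a self-contained sketch of that detection pattern using the standard-library inspect module and toy classes; the class and function names are illustrative, not TensorFlow APIs.

import inspect

class OldStrategy(object):
  def __init__(self, num_cores_per_host=8):  # legacy signature
    self.num_cores_per_host = num_cores_per_host

class NewStrategy(object):
  def __init__(self, cluster_resolver, steps_per_run=1):  # new signature
    self.cluster_resolver = cluster_resolver
    self.steps_per_run = steps_per_run

def make_strategy(strategy_cls, cluster_resolver):
  # Count the declared constructor arguments (including `self`) to detect
  # which API is present, as the diff does via tf_inspect.getargspec.
  args = inspect.getfullargspec(strategy_cls.__init__).args
  if len(args) == 3:  # (self, cluster_resolver, steps_per_run)
    return strategy_cls(cluster_resolver, steps_per_run=1)
  return strategy_cls(num_cores_per_host=8)  # fall back to the legacy API

print(type(make_strategy(NewStrategy, object())).__name__)  # NewStrategy
print(type(make_strategy(OldStrategy, object())).__name__)  # OldStrategy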
@@ -663,6 +705,7 @@ class TPUFunction(object):
 
       # Clone our CPU model, running within the TPU device context.
       with TPURewriteContext(tpu_input_map):
+        with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
          # TODO(power): Replicate variables.
          with ops.device('/device:TPU:0'):
            self._cloned_model = models.clone_model(self.model)
@@ -841,7 +884,7 @@ class TPUFunction(object):
 class KerasTPUModel(models.Model):
   """TPU compatible Keras model wrapper."""
 
-  def __init__(self, cpu_model, tpu_name_or_address, strategy):
+  def __init__(self, cpu_model, strategy):
     super(models.Model, self).__init__(  # pylint: disable=bad-super-call
         inputs=cpu_model.inputs,
         outputs=cpu_model.outputs,
@@ -858,27 +901,14 @@ class KerasTPUModel(models.Model):
     self.train_function = None
     self._strategy = strategy
 
-    self._tpu_name_or_address = tpu_name_or_address
+    cluster_resolver = self._strategy._tpu_cluster_resolver
+    self._tpu_name_or_address = cluster_resolver.get_master()
     self._cpu_model = cpu_model
     self._tpu_model = None
     self._tpu_weights_initialized = False
-    self._graph = ops.Graph()
-
-    self._cluster_resolver = tpu_cluster_resolver.TPUClusterResolver(
-        tpu_name_or_address)
-    master = self._cluster_resolver.master()
-    cluster_spec = self._cluster_resolver.cluster_spec()
-    self._session = tf_session.Session(
-        graph=self._graph,
-        target=master,
-        config=config_pb2.ConfigProto(isolate_session_state=True))
-
-    # TODO(saeta): Confirm the lines below work in ClusterSpec propagation env.
-    if cluster_spec:
-      self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
-
-    with self._graph.as_default():
-      self._session.run(tpu.initialize_system())
+
+    self._session = tpu_session(cluster_resolver)
+    self._graph = self._session.graph
 
     # If the input CPU model has already been compiled, compile our TPU model
     # immediately.
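With this change KerasTPUModel no longer builds and initializes its own session; it pulls one from the module-level tpu_session cache, so every model pointed at the same TPU master shares a single initialized session (and reset_tpu_sessions clears the cache). A generic sketch of that memoization pattern, using plain-Python stand-ins rather than tf.Session; the address string is a placeholder:

_SESSIONS = {}

class FakeSession(object):
  """Stand-in for tf.Session that records the master it connects to."""
  def __init__(self, master):
    self.master = master
    self.tpu_initialized = False

def get_session(master):
  # The expensive setup (session creation plus TPU system initialization
  # in the real code) runs only once per master address.
  if master not in _SESSIONS:
    session = FakeSession(master)
    session.tpu_initialized = True  # stands in for tpu.initialize_system()
    _SESSIONS[master] = session
  return _SESSIONS[master]

a = get_session('grpc://10.240.1.2:8470')
b = get_session('grpc://10.240.1.2:8470')
assert a is b  # same master address -> same cached, already-initialized session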
@@ -1133,7 +1163,7 @@ Output shape: %(output_shape)s
 
 
 @experimental
-def tpu_model(model, tpu_name_or_address=None, strategy=None):
+def tpu_model(model, strategy=None):
   """Copy `model` along with weights to the TPU. Returns a TPU model.
 
   Usage:
@@ -1144,7 +1174,7 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
 
   # If `num_cores_per_host` is greater than one, batch parallelism will be used
   # to run on multiple TPU cores.
-  strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8)
+  strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
   model = keras_support.tpu_model(model, strategy)
   model.compile(
       optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
@@ -1154,10 +1184,6 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
 
   Args:
     model: A `KerasTPUModel`.
-    tpu_name_or_address: A string that is either the name of the Cloud TPU,
-      the grpc address of the Cloud TPU, or (Googlers only) the BNS name of the
-      Cloud TPU. If tpu_name_or_address is None, the TPUClusterResolver will
-      examine the environment to determine a potential Cloud TPU to use.
     strategy: `TPUDistributionStrategy`. The strategy to use for replicating
       model across multiple TPU cores.
 
@@ -1172,9 +1198,8 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
   # TODO(xiejw): Validate TPU model. TPUModel only?
   # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
   # TODO(xiejw): Adds reduction option.
+
   if strategy is None:
-    strategy = TPUDistributionStrategy(num_cores_per_host=1)
-  return KerasTPUModel(
-      cpu_model=model,
-      tpu_name_or_address=tpu_name_or_address,
-      strategy=strategy)
+    strategy = TPUDistributionStrategy()
+
+  return KerasTPUModel(cpu_model=model, strategy=strategy)
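Net effect for callers: `tpu_model(model)` with no strategy still works, but it now builds a default TPUDistributionStrategy backed by TPUClusterResolver(''), which examines the environment for a Cloud TPU (the behaviour the removed tpu_name_or_address docstring described), while code that passed an explicit name or address has to route it through a resolver instead. A hedged migration sketch; the helper name and import paths are assumptions for a TF 1.x contrib setup:

from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.contrib.tpu.python.tpu import keras_support

def to_tpu(cpu_model, tpu_name_or_address=None):
  # Old call site, removed by this PR:
  #   keras_support.tpu_model(cpu_model, tpu_name_or_address=...)
  if tpu_name_or_address is None:
    # New default: TPUDistributionStrategy() resolves a TPU from the environment.
    return keras_support.tpu_model(cpu_model)
  # An explicit name or address now travels through a cluster resolver.
  strategy = keras_support.TPUDistributionStrategy(
      TPUClusterResolver(tpu_name_or_address))
  return keras_support.tpu_model(cpu_model, strategy=strategy)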