Remove the /replica:0 declaration in device functions and allow them to be freely bound based on the cluster names present.

When more than one device matches the specification, the first
lexicographically available matching device is chosen, which in
practice behaves essentially the same as hardcoding /replica:0.

PiperOrigin-RevId: 165766815
This commit is contained in:
Vijay Vasudevan 2017-08-18 16:15:46 -07:00 committed by TensorFlower Gardener
parent d685bbc54d
commit 575bd01d46

View File

@ -881,7 +881,7 @@ class _EvalMetrics(object):
num_shards = run_config.tpu_config.num_shards
job = _tpu_job(run_config)
job_device = '' if job is None else ('/job:%s/replica:0' % job)
job_device = '' if job is None else ('/job:%s' % job)
# For each i, dequeue_ops[i] is a list containing the tensors from all
# shards. This list is concatenated later.
@ -1144,7 +1144,7 @@ class TPUEstimator(estimator_lib.Estimator):
if job is None:
return '/replica:0/task:0/device:CPU:0'
else:
return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8)
return '/job:%s/task:%d/device:CPU:0' % (job, index / 8)
if mode == model_fn_lib.ModeKeys.TRAIN:
if not config.tpu_config.per_host_input_for_training:
@ -1221,7 +1221,9 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder, run_config,
if job is None:
return '/replica:0/task:0/device:CPU:0'
else:
return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8)
# This assumes that if using more than 8 shards,
# the job configuration varies 'task'.
return '/job:%s/task:%d/device:CPU:0' % (job, index / 8)
return infeed_queue.split_inputs_and_generate_enqueue_ops(
unsharded_inputs, placement_function=placement_function)