mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Remove /replica:0 declaration in device functions and allow them
to be freely bound based on cluster names present. When more than one value matches, it will choose the first lexicographically available device that matches the specification, which in practice will do pretty much the same thing as hardcoding /replica:0. PiperOrigin-RevId: 165766815
This commit is contained in:
parent
d685bbc54d
commit
575bd01d46
@@ -881,7 +881,7 @@ class _EvalMetrics(object):
     num_shards = run_config.tpu_config.num_shards
     job = _tpu_job(run_config)
-    job_device = '' if job is None else ('/job:%s/replica:0' % job)
+    job_device = '' if job is None else ('/job:%s' % job)

     # For each i, dequeue_ops[i] is a list containing the tensors from all
     # shards. This list is concatenated later.
@@ -1144,7 +1144,7 @@ class TPUEstimator(estimator_lib.Estimator):
       if job is None:
         return '/replica:0/task:0/device:CPU:0'
       else:
-        return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8)
+        return '/job:%s/task:%d/device:CPU:0' % (job, index / 8)

     if mode == model_fn_lib.ModeKeys.TRAIN:
       if not config.tpu_config.per_host_input_for_training:
@@ -1221,7 +1221,9 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder, run_config,
     if job is None:
       return '/replica:0/task:0/device:CPU:0'
     else:
-      return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8)
+      # This assumes that if using more than 8 shards,
+      # the job configuration varies 'task'.
+      return '/job:%s/task:%d/device:CPU:0' % (job, index / 8)
   return infeed_queue.split_inputs_and_generate_enqueue_ops(
       unsharded_inputs, placement_function=placement_function)
|
|
|||
Loading…
Reference in New Issue
Block a user