Fix device colocation for KMeans in case of multiple parameter servers.

PiperOrigin-RevId: 157795360
A. Unique TensorFlower 2017-06-01 19:36:22 -07:00 committed by TensorFlower Gardener
parent b659bc39f2
commit 07710014d2


@@ -164,9 +164,10 @@ class KMeans(object):
       with ops.colocate_with(inp):
         # Computes Euclidean distance. Note the first and third terms are
         # broadcast additions.
-        squared_distance = (math_ops.reduce_sum(
-            math_ops.square(inp), 1, keep_dims=True) - 2 * math_ops.matmul(
-                inp, clusters, transpose_b=True) + array_ops.transpose(
+        squared_distance = (
+            math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) -
+            2 * math_ops.matmul(inp, clusters, transpose_b=True) +
+            array_ops.transpose(
                 math_ops.reduce_sum(
                     math_ops.square(clusters), 1, keep_dims=True)))
         output.append(squared_distance)
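
The reformatted expression above is the standard expansion of the squared Euclidean distance, ||x - c||^2 = ||x||^2 - 2<x, c> + ||c||^2, evaluated with broadcasting. A quick NumPy check of the same arrangement of terms (illustrative only, not part of this commit):

    import numpy as np

    inp = np.random.randn(4, 3)       # 4 input points, 3 dimensions
    clusters = np.random.randn(2, 3)  # 2 cluster centers

    # Same three terms as above: ||x||^2 (column), -2 x.c^T, ||c||^2 (row).
    squared_distance = (
        np.sum(np.square(inp), 1, keepdims=True) -
        2 * inp.dot(clusters.T) +
        np.sum(np.square(clusters), 1, keepdims=True).T)

    # Direct computation for comparison.
    expected = np.square(inp[:, None, :] - clusters[None, :, :]).sum(-1)
    assert np.allclose(squared_distance, expected)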
@@ -229,12 +230,12 @@ class KMeans(object):
       clusters = nn_impl.l2_normalize(clusters, dim=1)
     for inp, score in zip(inputs, scores):
       with ops.colocate_with(inp):
-        (indices,
-         distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1)
+        (indices, distances) = gen_clustering_ops.nearest_neighbors(
+            inp, clusters, 1)
         if self._distance_metric == COSINE_DISTANCE:
           distances *= 0.5
-        output.append(
-            (score, array_ops.squeeze(distances), array_ops.squeeze(indices)))
+        output.append((score, array_ops.squeeze(distances),
+                       array_ops.squeeze(indices)))
     return zip(*output)

   def _init_clusters_random(self):
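
The `distances *= 0.5` above relies on the clusters (and inputs) being l2-normalized for COSINE_DISTANCE: assuming nearest_neighbors returns squared Euclidean distances, unit vectors give ||x - c||^2 = 2 - 2*cos(x, c), so halving yields the cosine distance 1 - cos(x, c). A small NumPy check of that identity (illustrative only):

    import numpy as np

    x = np.random.randn(3)
    c = np.random.randn(3)
    x /= np.linalg.norm(x)  # unit-normalize, as l2_normalize does
    c /= np.linalg.norm(c)

    cosine_distance = 1.0 - x.dot(c)
    assert np.isclose(0.5 * np.sum((x - c) ** 2), cosine_distance)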
@ -265,9 +266,7 @@ class KMeans(object):
(not self._use_mini_batch or
self._mini_batch_steps_per_iteration > 1))
def _initialize_clusters(self,
cluster_centers,
cluster_centers_initialized,
def _initialize_clusters(self, cluster_centers, cluster_centers_initialized,
cluster_centers_updated):
"""Returns an op to initialize the cluster centers."""
@@ -294,21 +293,19 @@ class KMeans(object):
     with ops.colocate_with(cluster_centers_initialized):
       initialized = control_flow_ops.with_dependencies(
-          [clusters_init],
-          array_ops.identity(cluster_centers_initialized))
+          [clusters_init], array_ops.identity(cluster_centers_initialized))
     with ops.colocate_with(cluster_centers):
-      assign_centers = state_ops.assign(cluster_centers, clusters_init,
-                                        validate_shape=False)
+      assign_centers = state_ops.assign(
+          cluster_centers, clusters_init, validate_shape=False)
       if cluster_centers_updated != cluster_centers:
-        assign_centers = control_flow_ops.group(
-            assign_centers,
-            state_ops.assign(cluster_centers_updated, clusters_init,
+        assign_centers = control_flow_ops.group(assign_centers,
+                                                state_ops.assign(
+                                                    cluster_centers_updated,
+                                                    clusters_init,
                                                     validate_shape=False))
       assign_centers = control_flow_ops.with_dependencies(
-          [assign_centers],
-          state_ops.assign(cluster_centers_initialized, True))
-      return control_flow_ops.cond(initialized,
-                                   control_flow_ops.no_op,
+          [assign_centers], state_ops.assign(cluster_centers_initialized, True))
+      return control_flow_ops.cond(initialized, control_flow_ops.no_op,
                                    lambda: assign_centers).op

   def _create_variables(self):
@@ -327,19 +324,16 @@ class KMeans(object):
       cluster_centers_updated back to cluster_centers.
     """
     init_value = array_ops.constant([], dtype=dtypes.float32)
-    cluster_centers = variable_scope.variable(init_value,
-                                              name='clusters',
-                                              validate_shape=False)
-    cluster_centers_initialized = variable_scope.variable(False,
-                                                          dtype=dtypes.bool,
-                                                          name='initialized')
+    cluster_centers = variable_scope.variable(
+        init_value, name='clusters', validate_shape=False)
+    cluster_centers_initialized = variable_scope.variable(
+        False, dtype=dtypes.bool, name='initialized')
     if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
       # Copy of cluster centers actively updated each step according to
       # mini-batch update rule.
-      cluster_centers_updated = variable_scope.variable(init_value,
-                                                        name='clusters_updated',
-                                                        validate_shape=False)
+      cluster_centers_updated = variable_scope.variable(
+          init_value, name='clusters_updated', validate_shape=False)
       # How many steps till we copy the updated clusters to cluster_centers.
       update_in_steps = variable_scope.variable(
           self._mini_batch_steps_per_iteration,
@@ -347,20 +341,15 @@ class KMeans(object):
           name='update_in_steps')
       # Count of points assigned to cluster_centers_updated.
       cluster_counts = variable_scope.variable(
-          array_ops.zeros([self._num_clusters],
-                          dtype=dtypes.int64))
+          array_ops.zeros([self._num_clusters], dtype=dtypes.int64))
     else:
       cluster_centers_updated = cluster_centers
       update_in_steps = None
-      cluster_counts = (variable_scope.variable(array_ops.ones(
-          [self._num_clusters],
-          dtype=dtypes.int64))
+      cluster_counts = (variable_scope.variable(
+          array_ops.ones([self._num_clusters], dtype=dtypes.int64))
                         if self._use_mini_batch else None)
-    return (cluster_centers,
-            cluster_centers_initialized,
-            cluster_counts,
-            cluster_centers_updated,
-            update_in_steps)
+    return (cluster_centers, cluster_centers_initialized, cluster_counts,
+            cluster_centers_updated, update_in_steps)

   @classmethod
   def _l2_normalize_data(cls, inputs):
@@ -391,11 +380,8 @@ class KMeans(object):
     """
     # Implementation of kmeans.
     inputs = self._inputs
-    (cluster_centers_var,
-     cluster_centers_initialized,
-     total_counts,
-     cluster_centers_updated,
-     update_in_steps) = self._create_variables()
+    (cluster_centers_var, cluster_centers_initialized, total_counts,
+     cluster_centers_updated, update_in_steps) = self._create_variables()
     init_op = self._initialize_clusters(cluster_centers_var,
                                         cluster_centers_initialized,
                                         cluster_centers_updated)
@@ -409,8 +395,7 @@ class KMeans(object):
     all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
     if self._use_mini_batch:
       sync_updates_op = self._mini_batch_sync_updates_op(
-          update_in_steps,
-          cluster_centers_var, cluster_centers_updated,
+          update_in_steps, cluster_centers_var, cluster_centers_updated,
           total_counts)
       assert sync_updates_op is not None
       with ops.control_dependencies([sync_updates_op]):
@@ -421,15 +406,15 @@ class KMeans(object):
       training_op = self._full_batch_training_op(inputs, cluster_idx,
                                                  cluster_centers_var)
-    return (all_scores, cluster_idx, scores,
-            cluster_centers_initialized, init_op, training_op)
+    return (all_scores, cluster_idx, scores, cluster_centers_initialized,
+            init_op, training_op)

-  def _mini_batch_sync_updates_op(self, update_in_steps,
-                                  cluster_centers_var, cluster_centers_updated,
-                                  total_counts):
+  def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
+                                  cluster_centers_updated, total_counts):
     if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
       assert update_in_steps is not None
       with ops.colocate_with(update_in_steps):
+
         def _f():
           # Note that there is a race condition here, so we do a best effort
           # updates here. We reset update_in_steps first so that other workers
@@ -437,33 +422,36 @@ class KMeans(object):
           # before resetting total_counts to avoid large updates to
           # cluster_centers_updated based on partially updated
           # cluster_center_vars.
-          with ops.control_dependencies([state_ops.assign(
-              update_in_steps,
-              self._mini_batch_steps_per_iteration - 1)]):
-            with ops.colocate_with(cluster_centers_updated):
+          with ops.control_dependencies([
+              state_ops.assign(update_in_steps,
+                               self._mini_batch_steps_per_iteration - 1)
+          ]):
+            with ops.colocate_with(
+                cluster_centers_updated, ignore_existing=True):
               if self._distance_metric == COSINE_DISTANCE:
-                cluster_centers = nn_impl.l2_normalize(cluster_centers_updated,
-                                                       dim=1)
+                cluster_centers = nn_impl.l2_normalize(
+                    cluster_centers_updated, dim=1)
               else:
                 cluster_centers = cluster_centers_updated
             with ops.colocate_with(cluster_centers_var):
-              with ops.control_dependencies([state_ops.assign(
-                  cluster_centers_var,
-                  cluster_centers)]):
-                with ops.colocate_with(cluster_centers_var):
+              with ops.control_dependencies(
+                  [state_ops.assign(cluster_centers_var, cluster_centers)]):
+                with ops.colocate_with(
+                    cluster_centers_var, ignore_existing=True):
                   with ops.control_dependencies([
                       state_ops.assign(total_counts,
-                                       array_ops.zeros_like(total_counts))]):
+                                       array_ops.zeros_like(total_counts))
+                  ]):
                     return array_ops.identity(update_in_steps)
+
         return control_flow_ops.cond(
-            update_in_steps <= 0,
-            _f,
+            update_in_steps <= 0, _f,
             lambda: state_ops.assign_sub(update_in_steps, 1))
     else:
       return control_flow_ops.no_op()

-  def _mini_batch_training_op(self, inputs, cluster_idx_list,
-                              cluster_centers, total_counts):
+  def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
+                              total_counts):
     """Creates an op for training for mini batch case.

     Args:
@@ -487,17 +475,15 @@ class KMeans(object):
       unique_ids, unique_idx = array_ops.unique(cluster_idx)
       num_unique_cluster_idx = array_ops.size(unique_ids)
       # Fetch the old values of counts and cluster_centers.
-      with ops.colocate_with(total_counts):
+      with ops.colocate_with(total_counts, ignore_existing=True):
         old_counts = array_ops.gather(total_counts, unique_ids)
       # TODO(agarwal): This colocation seems to run into problems. Fix it.
-      # with ops.colocate_with(cluster_centers):
-      old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
+      with ops.colocate_with(cluster_centers, ignore_existing=True):
+        old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
       # Locally aggregate the increment to counts.
       count_updates = math_ops.unsorted_segment_sum(
-          array_ops.ones_like(
-              unique_idx, dtype=total_counts.dtype),
-          unique_idx,
-          num_unique_cluster_idx)
+          array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
+          unique_idx, num_unique_cluster_idx)
       # Locally compute the sum of inputs mapped to each id.
       # For a cluster with old cluster value x, old count n, and with data
       # d_1,...d_k newly assigned to it, we recompute the new value as
@@ -507,13 +493,12 @@ class KMeans(object):
           inp, unique_idx, num_unique_cluster_idx)
       # Shape to enable broadcasting count_updates and learning_rate to inp.
       # It extends the shape with 1's to match the rank of inp.
-      broadcast_shape = array_ops.concat(
-          [
-              array_ops.reshape(num_unique_cluster_idx, [1]), array_ops.ones(
+      broadcast_shape = array_ops.concat([
+          array_ops.reshape(num_unique_cluster_idx, [1]),
+          array_ops.ones(
               array_ops.reshape(array_ops.rank(inp) - 1, [1]),
               dtype=dtypes.int32)
-          ],
-          0)
+      ], 0)
       # Subtract k * x, see comment above.
       cluster_center_updates -= math_ops.cast(
           array_ops.reshape(count_updates, broadcast_shape),
@@ -524,14 +509,10 @@ class KMeans(object):
       # scale by 1 / (n + k), see comment above.
       cluster_center_updates *= learning_rate
       # Apply the updates.
-      update_counts = state_ops.scatter_add(
-          total_counts,
-          unique_ids,
+      update_counts = state_ops.scatter_add(total_counts, unique_ids,
                                             count_updates)
       update_cluster_centers = state_ops.scatter_add(
-          cluster_centers,
-          unique_ids,
-          cluster_center_updates)
+          cluster_centers, unique_ids, cluster_center_updates)
       update_ops.extend([update_counts, update_cluster_centers])
     return control_flow_ops.group(*update_ops)

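The substantive change in this commit is threading ignore_existing=True through the nested ops.colocate_with scopes above. Without it, an inner colocation scope adds to the constraints already active in the enclosing scope, which becomes unsatisfiable when the variables involved (e.g. total_counts and cluster_centers) are placed on different parameter servers, as happens with multiple PS tasks. A hedged TF 1.x sketch of the pattern (variable names here are illustrative, not from this file):

    import tensorflow as tf

    # Two variables that a parameter-server placer could pin to
    # different tasks.
    counts = tf.Variable(tf.zeros([4], dtype=tf.int64), name='counts')
    centers = tf.Variable(tf.zeros([4, 2]), name='centers')

    with tf.colocate_with(counts):
      # ignore_existing=True drops the outer colocation constraint, so
      # this op is colocated with `centers` only. Without the flag it
      # would have to satisfy both constraints at once, which fails when
      # the two variables live on different parameter servers.
      with tf.colocate_with(centers, ignore_existing=True):
        gathered_centers = tf.gather(centers, [0, 1])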
@@ -552,7 +533,7 @@ class KMeans(object):
     cluster_counts = []
     epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
     for inp, cluster_idx in zip(inputs, cluster_idx_list):
-      with ops.colocate_with(inp):
+      with ops.colocate_with(inp, ignore_existing=True):
         cluster_sums.append(
             math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters))
         cluster_counts.append(
@@ -561,7 +542,7 @@ class KMeans(object):
                 array_ops.ones(
                     array_ops.reshape(array_ops.shape(inp)[0], [-1])),
                 [-1, 1]), cluster_idx, self._num_clusters))
-    with ops.colocate_with(cluster_centers):
+    with ops.colocate_with(cluster_centers, ignore_existing=True):
       new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
           math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
       if self._clusters_l2_normalized():
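
The mini-batch update in _mini_batch_training_op follows the running-mean rule sketched in its comments: for a center x with prior count n and k newly assigned points d_1..d_k, the new center is x*n/(n+k) + sum(d_i)/(n+k), realized as a scatter_add of (sum(d_i) - k*x) scaled by learning_rate = 1/(n+k). A NumPy check of that algebra (illustrative only, not part of this commit):

    import numpy as np

    x = np.array([1.0, 2.0])                # old cluster center
    n = 10.0                                # old count for this cluster
    d = np.array([[3.0, 4.0], [5.0, 6.0]])  # k = 2 newly assigned points
    k = float(len(d))

    learning_rate = 1.0 / (n + k)                     # "scale by 1 / (n + k)"
    update = (d.sum(axis=0) - k * x) * learning_rate  # what scatter_add applies
    x_new = x + update

    assert np.allclose(x_new, x * n / (n + k) + d.sum(axis=0) / (n + k))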