Fix device colocation for KMeans in case of multiple parameter servers.

PiperOrigin-RevId: 157795360
This commit is contained in:
A. Unique TensorFlower 2017-06-01 19:36:22 -07:00 committed by TensorFlower Gardener
parent b659bc39f2
commit 07710014d2

View File

@@ -164,11 +164,12 @@ class KMeans(object):
with ops.colocate_with(inp): with ops.colocate_with(inp):
# Computes Euclidean distance. Note the first and third terms are # Computes Euclidean distance. Note the first and third terms are
# broadcast additions. # broadcast additions.
squared_distance = (math_ops.reduce_sum( squared_distance = (
math_ops.square(inp), 1, keep_dims=True) - 2 * math_ops.matmul( math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) -
inp, clusters, transpose_b=True) + array_ops.transpose( 2 * math_ops.matmul(inp, clusters, transpose_b=True) +
math_ops.reduce_sum( array_ops.transpose(
math_ops.square(clusters), 1, keep_dims=True))) math_ops.reduce_sum(
math_ops.square(clusters), 1, keep_dims=True)))
output.append(squared_distance) output.append(squared_distance)
return output return output
@@ -229,12 +230,12 @@ class KMeans(object):
clusters = nn_impl.l2_normalize(clusters, dim=1) clusters = nn_impl.l2_normalize(clusters, dim=1)
for inp, score in zip(inputs, scores): for inp, score in zip(inputs, scores):
with ops.colocate_with(inp): with ops.colocate_with(inp):
(indices, (indices, distances) = gen_clustering_ops.nearest_neighbors(
distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1) inp, clusters, 1)
if self._distance_metric == COSINE_DISTANCE: if self._distance_metric == COSINE_DISTANCE:
distances *= 0.5 distances *= 0.5
output.append( output.append((score, array_ops.squeeze(distances),
(score, array_ops.squeeze(distances), array_ops.squeeze(indices))) array_ops.squeeze(indices)))
return zip(*output) return zip(*output)
def _init_clusters_random(self): def _init_clusters_random(self):
@@ -265,9 +266,7 @@ class KMeans(object):
(not self._use_mini_batch or (not self._use_mini_batch or
self._mini_batch_steps_per_iteration > 1)) self._mini_batch_steps_per_iteration > 1))
def _initialize_clusters(self, def _initialize_clusters(self, cluster_centers, cluster_centers_initialized,
cluster_centers,
cluster_centers_initialized,
cluster_centers_updated): cluster_centers_updated):
"""Returns an op to initialize the cluster centers.""" """Returns an op to initialize the cluster centers."""
@@ -294,22 +293,20 @@ class KMeans(object):
with ops.colocate_with(cluster_centers_initialized): with ops.colocate_with(cluster_centers_initialized):
initialized = control_flow_ops.with_dependencies( initialized = control_flow_ops.with_dependencies(
[clusters_init], [clusters_init], array_ops.identity(cluster_centers_initialized))
array_ops.identity(cluster_centers_initialized))
with ops.colocate_with(cluster_centers): with ops.colocate_with(cluster_centers):
assign_centers = state_ops.assign(cluster_centers, clusters_init, assign_centers = state_ops.assign(
validate_shape=False) cluster_centers, clusters_init, validate_shape=False)
if cluster_centers_updated != cluster_centers: if cluster_centers_updated != cluster_centers:
assign_centers = control_flow_ops.group( assign_centers = control_flow_ops.group(assign_centers,
assign_centers, state_ops.assign(
state_ops.assign(cluster_centers_updated, clusters_init, cluster_centers_updated,
validate_shape=False)) clusters_init,
assign_centers = control_flow_ops.with_dependencies( validate_shape=False))
[assign_centers], assign_centers = control_flow_ops.with_dependencies(
state_ops.assign(cluster_centers_initialized, True)) [assign_centers], state_ops.assign(cluster_centers_initialized, True))
return control_flow_ops.cond(initialized, return control_flow_ops.cond(initialized, control_flow_ops.no_op,
control_flow_ops.no_op, lambda: assign_centers).op
lambda: assign_centers).op
def _create_variables(self): def _create_variables(self):
"""Creates variables. """Creates variables.
@@ -327,19 +324,16 @@ class KMeans(object):
cluster_centers_updated back to cluster_centers. cluster_centers_updated back to cluster_centers.
""" """
init_value = array_ops.constant([], dtype=dtypes.float32) init_value = array_ops.constant([], dtype=dtypes.float32)
cluster_centers = variable_scope.variable(init_value, cluster_centers = variable_scope.variable(
name='clusters', init_value, name='clusters', validate_shape=False)
validate_shape=False) cluster_centers_initialized = variable_scope.variable(
cluster_centers_initialized = variable_scope.variable(False, False, dtype=dtypes.bool, name='initialized')
dtype=dtypes.bool,
name='initialized')
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1: if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
# Copy of cluster centers actively updated each step according to # Copy of cluster centers actively updated each step according to
# mini-batch update rule. # mini-batch update rule.
cluster_centers_updated = variable_scope.variable(init_value, cluster_centers_updated = variable_scope.variable(
name='clusters_updated', init_value, name='clusters_updated', validate_shape=False)
validate_shape=False)
# How many steps till we copy the updated clusters to cluster_centers. # How many steps till we copy the updated clusters to cluster_centers.
update_in_steps = variable_scope.variable( update_in_steps = variable_scope.variable(
self._mini_batch_steps_per_iteration, self._mini_batch_steps_per_iteration,
@@ -347,20 +341,15 @@ class KMeans(object):
name='update_in_steps') name='update_in_steps')
# Count of points assigned to cluster_centers_updated. # Count of points assigned to cluster_centers_updated.
cluster_counts = variable_scope.variable( cluster_counts = variable_scope.variable(
array_ops.zeros([self._num_clusters], array_ops.zeros([self._num_clusters], dtype=dtypes.int64))
dtype=dtypes.int64))
else: else:
cluster_centers_updated = cluster_centers cluster_centers_updated = cluster_centers
update_in_steps = None update_in_steps = None
cluster_counts = (variable_scope.variable(array_ops.ones( cluster_counts = (variable_scope.variable(
[self._num_clusters], array_ops.ones([self._num_clusters], dtype=dtypes.int64))
dtype=dtypes.int64))
if self._use_mini_batch else None) if self._use_mini_batch else None)
return (cluster_centers, return (cluster_centers, cluster_centers_initialized, cluster_counts,
cluster_centers_initialized, cluster_centers_updated, update_in_steps)
cluster_counts,
cluster_centers_updated,
update_in_steps)
@classmethod @classmethod
def _l2_normalize_data(cls, inputs): def _l2_normalize_data(cls, inputs):
@@ -391,11 +380,8 @@ class KMeans(object):
""" """
# Implementation of kmeans. # Implementation of kmeans.
inputs = self._inputs inputs = self._inputs
(cluster_centers_var, (cluster_centers_var, cluster_centers_initialized, total_counts,
cluster_centers_initialized, cluster_centers_updated, update_in_steps) = self._create_variables()
total_counts,
cluster_centers_updated,
update_in_steps) = self._create_variables()
init_op = self._initialize_clusters(cluster_centers_var, init_op = self._initialize_clusters(cluster_centers_var,
cluster_centers_initialized, cluster_centers_initialized,
cluster_centers_updated) cluster_centers_updated)
@@ -409,8 +395,7 @@ class KMeans(object):
all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers) all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
if self._use_mini_batch: if self._use_mini_batch:
sync_updates_op = self._mini_batch_sync_updates_op( sync_updates_op = self._mini_batch_sync_updates_op(
update_in_steps, update_in_steps, cluster_centers_var, cluster_centers_updated,
cluster_centers_var, cluster_centers_updated,
total_counts) total_counts)
assert sync_updates_op is not None assert sync_updates_op is not None
with ops.control_dependencies([sync_updates_op]): with ops.control_dependencies([sync_updates_op]):
@@ -421,15 +406,15 @@ class KMeans(object):
training_op = self._full_batch_training_op(inputs, cluster_idx, training_op = self._full_batch_training_op(inputs, cluster_idx,
cluster_centers_var) cluster_centers_var)
return (all_scores, cluster_idx, scores, return (all_scores, cluster_idx, scores, cluster_centers_initialized,
cluster_centers_initialized, init_op, training_op) init_op, training_op)
def _mini_batch_sync_updates_op(self, update_in_steps, def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
cluster_centers_var, cluster_centers_updated, cluster_centers_updated, total_counts):
total_counts):
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1: if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
assert update_in_steps is not None assert update_in_steps is not None
with ops.colocate_with(update_in_steps): with ops.colocate_with(update_in_steps):
def _f(): def _f():
# Note that there is a race condition here, so we do best-effort # Note that there is a race condition here, so we do best-effort
# updates here. We reset update_in_steps first so that other workers # updates here. We reset update_in_steps first so that other workers
@@ -437,33 +422,36 @@ class KMeans(object):
# before resetting total_counts to avoid large updates to # before resetting total_counts to avoid large updates to
# cluster_centers_updated based on partially updated # cluster_centers_updated based on partially updated
# cluster_center_vars. # cluster_center_vars.
with ops.control_dependencies([state_ops.assign( with ops.control_dependencies([
update_in_steps, state_ops.assign(update_in_steps,
self._mini_batch_steps_per_iteration - 1)]): self._mini_batch_steps_per_iteration - 1)
with ops.colocate_with(cluster_centers_updated): ]):
with ops.colocate_with(
cluster_centers_updated, ignore_existing=True):
if self._distance_metric == COSINE_DISTANCE: if self._distance_metric == COSINE_DISTANCE:
cluster_centers = nn_impl.l2_normalize(cluster_centers_updated, cluster_centers = nn_impl.l2_normalize(
dim=1) cluster_centers_updated, dim=1)
else: else:
cluster_centers = cluster_centers_updated cluster_centers = cluster_centers_updated
with ops.colocate_with(cluster_centers_var): with ops.colocate_with(cluster_centers_var):
with ops.control_dependencies([state_ops.assign( with ops.control_dependencies(
cluster_centers_var, [state_ops.assign(cluster_centers_var, cluster_centers)]):
cluster_centers)]): with ops.colocate_with(
with ops.colocate_with(cluster_centers_var): cluster_centers_var, ignore_existing=True):
with ops.control_dependencies([ with ops.control_dependencies([
state_ops.assign(total_counts, state_ops.assign(total_counts,
array_ops.zeros_like(total_counts))]): array_ops.zeros_like(total_counts))
]):
return array_ops.identity(update_in_steps) return array_ops.identity(update_in_steps)
return control_flow_ops.cond( return control_flow_ops.cond(
update_in_steps <= 0, update_in_steps <= 0, _f,
_f,
lambda: state_ops.assign_sub(update_in_steps, 1)) lambda: state_ops.assign_sub(update_in_steps, 1))
else: else:
return control_flow_ops.no_op() return control_flow_ops.no_op()
def _mini_batch_training_op(self, inputs, cluster_idx_list, def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
cluster_centers, total_counts): total_counts):
"""Creates an op for training for mini batch case. """Creates an op for training for mini batch case.
Args: Args:
@@ -487,17 +475,15 @@ class KMeans(object):
unique_ids, unique_idx = array_ops.unique(cluster_idx) unique_ids, unique_idx = array_ops.unique(cluster_idx)
num_unique_cluster_idx = array_ops.size(unique_ids) num_unique_cluster_idx = array_ops.size(unique_ids)
# Fetch the old values of counts and cluster_centers. # Fetch the old values of counts and cluster_centers.
with ops.colocate_with(total_counts): with ops.colocate_with(total_counts, ignore_existing=True):
old_counts = array_ops.gather(total_counts, unique_ids) old_counts = array_ops.gather(total_counts, unique_ids)
# TODO(agarwal): This colocation seems to run into problems. Fix it. # TODO(agarwal): This colocation seems to run into problems. Fix it.
# with ops.colocate_with(cluster_centers): with ops.colocate_with(cluster_centers, ignore_existing=True):
old_cluster_centers = array_ops.gather(cluster_centers, unique_ids) old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
# Locally aggregate the increment to counts. # Locally aggregate the increment to counts.
count_updates = math_ops.unsorted_segment_sum( count_updates = math_ops.unsorted_segment_sum(
array_ops.ones_like( array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
unique_idx, dtype=total_counts.dtype), unique_idx, num_unique_cluster_idx)
unique_idx,
num_unique_cluster_idx)
# Locally compute the sum of inputs mapped to each id. # Locally compute the sum of inputs mapped to each id.
# For a cluster with old cluster value x, old count n, and with data # For a cluster with old cluster value x, old count n, and with data
# d_1,...d_k newly assigned to it, we recompute the new value as # d_1,...d_k newly assigned to it, we recompute the new value as
@@ -507,13 +493,12 @@ class KMeans(object):
inp, unique_idx, num_unique_cluster_idx) inp, unique_idx, num_unique_cluster_idx)
# Shape to enable broadcasting count_updates and learning_rate to inp. # Shape to enable broadcasting count_updates and learning_rate to inp.
# It extends the shape with 1's to match the rank of inp. # It extends the shape with 1's to match the rank of inp.
broadcast_shape = array_ops.concat( broadcast_shape = array_ops.concat([
[ array_ops.reshape(num_unique_cluster_idx, [1]),
array_ops.reshape(num_unique_cluster_idx, [1]), array_ops.ones( array_ops.ones(
array_ops.reshape(array_ops.rank(inp) - 1, [1]), array_ops.reshape(array_ops.rank(inp) - 1, [1]),
dtype=dtypes.int32) dtype=dtypes.int32)
], ], 0)
0)
# Subtract k * x, see comment above. # Subtract k * x, see comment above.
cluster_center_updates -= math_ops.cast( cluster_center_updates -= math_ops.cast(
array_ops.reshape(count_updates, broadcast_shape), array_ops.reshape(count_updates, broadcast_shape),
@@ -524,14 +509,10 @@ class KMeans(object):
# scale by 1 / (n + k), see comment above. # scale by 1 / (n + k), see comment above.
cluster_center_updates *= learning_rate cluster_center_updates *= learning_rate
# Apply the updates. # Apply the updates.
update_counts = state_ops.scatter_add( update_counts = state_ops.scatter_add(total_counts, unique_ids,
total_counts, count_updates)
unique_ids,
count_updates)
update_cluster_centers = state_ops.scatter_add( update_cluster_centers = state_ops.scatter_add(
cluster_centers, cluster_centers, unique_ids, cluster_center_updates)
unique_ids,
cluster_center_updates)
update_ops.extend([update_counts, update_cluster_centers]) update_ops.extend([update_counts, update_cluster_centers])
return control_flow_ops.group(*update_ops) return control_flow_ops.group(*update_ops)
@@ -552,7 +533,7 @@ class KMeans(object):
cluster_counts = [] cluster_counts = []
epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype) epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
for inp, cluster_idx in zip(inputs, cluster_idx_list): for inp, cluster_idx in zip(inputs, cluster_idx_list):
with ops.colocate_with(inp): with ops.colocate_with(inp, ignore_existing=True):
cluster_sums.append( cluster_sums.append(
math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters)) math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters))
cluster_counts.append( cluster_counts.append(
@@ -561,7 +542,7 @@ class KMeans(object):
array_ops.ones( array_ops.ones(
array_ops.reshape(array_ops.shape(inp)[0], [-1])), array_ops.reshape(array_ops.shape(inp)[0], [-1])),
[-1, 1]), cluster_idx, self._num_clusters)) [-1, 1]), cluster_idx, self._num_clusters))
with ops.colocate_with(cluster_centers): with ops.colocate_with(cluster_centers, ignore_existing=True):
new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast( new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon) math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
if self._clusters_l2_normalized(): if self._clusters_l2_normalized():