mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Fix device colocation for KMeans in case of multiple parameter servers.
PiperOrigin-RevId: 157795360
This commit is contained in:
parent
b659bc39f2
commit
07710014d2
|
|
@ -164,11 +164,12 @@ class KMeans(object):
|
|||
with ops.colocate_with(inp):
|
||||
# Computes Euclidean distance. Note the first and third terms are
|
||||
# broadcast additions.
|
||||
squared_distance = (math_ops.reduce_sum(
|
||||
math_ops.square(inp), 1, keep_dims=True) - 2 * math_ops.matmul(
|
||||
inp, clusters, transpose_b=True) + array_ops.transpose(
|
||||
math_ops.reduce_sum(
|
||||
math_ops.square(clusters), 1, keep_dims=True)))
|
||||
squared_distance = (
|
||||
math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) -
|
||||
2 * math_ops.matmul(inp, clusters, transpose_b=True) +
|
||||
array_ops.transpose(
|
||||
math_ops.reduce_sum(
|
||||
math_ops.square(clusters), 1, keep_dims=True)))
|
||||
output.append(squared_distance)
|
||||
|
||||
return output
|
||||
|
|
@ -229,12 +230,12 @@ class KMeans(object):
|
|||
clusters = nn_impl.l2_normalize(clusters, dim=1)
|
||||
for inp, score in zip(inputs, scores):
|
||||
with ops.colocate_with(inp):
|
||||
(indices,
|
||||
distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1)
|
||||
(indices, distances) = gen_clustering_ops.nearest_neighbors(
|
||||
inp, clusters, 1)
|
||||
if self._distance_metric == COSINE_DISTANCE:
|
||||
distances *= 0.5
|
||||
output.append(
|
||||
(score, array_ops.squeeze(distances), array_ops.squeeze(indices)))
|
||||
output.append((score, array_ops.squeeze(distances),
|
||||
array_ops.squeeze(indices)))
|
||||
return zip(*output)
|
||||
|
||||
def _init_clusters_random(self):
|
||||
|
|
@ -265,9 +266,7 @@ class KMeans(object):
|
|||
(not self._use_mini_batch or
|
||||
self._mini_batch_steps_per_iteration > 1))
|
||||
|
||||
def _initialize_clusters(self,
|
||||
cluster_centers,
|
||||
cluster_centers_initialized,
|
||||
def _initialize_clusters(self, cluster_centers, cluster_centers_initialized,
|
||||
cluster_centers_updated):
|
||||
"""Returns an op to initialize the cluster centers."""
|
||||
|
||||
|
|
@ -294,22 +293,20 @@ class KMeans(object):
|
|||
|
||||
with ops.colocate_with(cluster_centers_initialized):
|
||||
initialized = control_flow_ops.with_dependencies(
|
||||
[clusters_init],
|
||||
array_ops.identity(cluster_centers_initialized))
|
||||
[clusters_init], array_ops.identity(cluster_centers_initialized))
|
||||
with ops.colocate_with(cluster_centers):
|
||||
assign_centers = state_ops.assign(cluster_centers, clusters_init,
|
||||
validate_shape=False)
|
||||
assign_centers = state_ops.assign(
|
||||
cluster_centers, clusters_init, validate_shape=False)
|
||||
if cluster_centers_updated != cluster_centers:
|
||||
assign_centers = control_flow_ops.group(
|
||||
assign_centers,
|
||||
state_ops.assign(cluster_centers_updated, clusters_init,
|
||||
validate_shape=False))
|
||||
assign_centers = control_flow_ops.with_dependencies(
|
||||
[assign_centers],
|
||||
state_ops.assign(cluster_centers_initialized, True))
|
||||
return control_flow_ops.cond(initialized,
|
||||
control_flow_ops.no_op,
|
||||
lambda: assign_centers).op
|
||||
assign_centers = control_flow_ops.group(assign_centers,
|
||||
state_ops.assign(
|
||||
cluster_centers_updated,
|
||||
clusters_init,
|
||||
validate_shape=False))
|
||||
assign_centers = control_flow_ops.with_dependencies(
|
||||
[assign_centers], state_ops.assign(cluster_centers_initialized, True))
|
||||
return control_flow_ops.cond(initialized, control_flow_ops.no_op,
|
||||
lambda: assign_centers).op
|
||||
|
||||
def _create_variables(self):
|
||||
"""Creates variables.
|
||||
|
|
@ -327,19 +324,16 @@ class KMeans(object):
|
|||
cluster_centers_updated back to cluster_centers.
|
||||
"""
|
||||
init_value = array_ops.constant([], dtype=dtypes.float32)
|
||||
cluster_centers = variable_scope.variable(init_value,
|
||||
name='clusters',
|
||||
validate_shape=False)
|
||||
cluster_centers_initialized = variable_scope.variable(False,
|
||||
dtype=dtypes.bool,
|
||||
name='initialized')
|
||||
cluster_centers = variable_scope.variable(
|
||||
init_value, name='clusters', validate_shape=False)
|
||||
cluster_centers_initialized = variable_scope.variable(
|
||||
False, dtype=dtypes.bool, name='initialized')
|
||||
|
||||
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
|
||||
# Copy of cluster centers actively updated each step according to
|
||||
# mini-batch update rule.
|
||||
cluster_centers_updated = variable_scope.variable(init_value,
|
||||
name='clusters_updated',
|
||||
validate_shape=False)
|
||||
cluster_centers_updated = variable_scope.variable(
|
||||
init_value, name='clusters_updated', validate_shape=False)
|
||||
# How many steps till we copy the updated clusters to cluster_centers.
|
||||
update_in_steps = variable_scope.variable(
|
||||
self._mini_batch_steps_per_iteration,
|
||||
|
|
@ -347,20 +341,15 @@ class KMeans(object):
|
|||
name='update_in_steps')
|
||||
# Count of points assigned to cluster_centers_updated.
|
||||
cluster_counts = variable_scope.variable(
|
||||
array_ops.zeros([self._num_clusters],
|
||||
dtype=dtypes.int64))
|
||||
array_ops.zeros([self._num_clusters], dtype=dtypes.int64))
|
||||
else:
|
||||
cluster_centers_updated = cluster_centers
|
||||
update_in_steps = None
|
||||
cluster_counts = (variable_scope.variable(array_ops.ones(
|
||||
[self._num_clusters],
|
||||
dtype=dtypes.int64))
|
||||
cluster_counts = (variable_scope.variable(
|
||||
array_ops.ones([self._num_clusters], dtype=dtypes.int64))
|
||||
if self._use_mini_batch else None)
|
||||
return (cluster_centers,
|
||||
cluster_centers_initialized,
|
||||
cluster_counts,
|
||||
cluster_centers_updated,
|
||||
update_in_steps)
|
||||
return (cluster_centers, cluster_centers_initialized, cluster_counts,
|
||||
cluster_centers_updated, update_in_steps)
|
||||
|
||||
@classmethod
|
||||
def _l2_normalize_data(cls, inputs):
|
||||
|
|
@ -391,11 +380,8 @@ class KMeans(object):
|
|||
"""
|
||||
# Implementation of kmeans.
|
||||
inputs = self._inputs
|
||||
(cluster_centers_var,
|
||||
cluster_centers_initialized,
|
||||
total_counts,
|
||||
cluster_centers_updated,
|
||||
update_in_steps) = self._create_variables()
|
||||
(cluster_centers_var, cluster_centers_initialized, total_counts,
|
||||
cluster_centers_updated, update_in_steps) = self._create_variables()
|
||||
init_op = self._initialize_clusters(cluster_centers_var,
|
||||
cluster_centers_initialized,
|
||||
cluster_centers_updated)
|
||||
|
|
@ -409,8 +395,7 @@ class KMeans(object):
|
|||
all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
|
||||
if self._use_mini_batch:
|
||||
sync_updates_op = self._mini_batch_sync_updates_op(
|
||||
update_in_steps,
|
||||
cluster_centers_var, cluster_centers_updated,
|
||||
update_in_steps, cluster_centers_var, cluster_centers_updated,
|
||||
total_counts)
|
||||
assert sync_updates_op is not None
|
||||
with ops.control_dependencies([sync_updates_op]):
|
||||
|
|
@ -421,15 +406,15 @@ class KMeans(object):
|
|||
training_op = self._full_batch_training_op(inputs, cluster_idx,
|
||||
cluster_centers_var)
|
||||
|
||||
return (all_scores, cluster_idx, scores,
|
||||
cluster_centers_initialized, init_op, training_op)
|
||||
return (all_scores, cluster_idx, scores, cluster_centers_initialized,
|
||||
init_op, training_op)
|
||||
|
||||
def _mini_batch_sync_updates_op(self, update_in_steps,
|
||||
cluster_centers_var, cluster_centers_updated,
|
||||
total_counts):
|
||||
def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
|
||||
cluster_centers_updated, total_counts):
|
||||
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
|
||||
assert update_in_steps is not None
|
||||
with ops.colocate_with(update_in_steps):
|
||||
|
||||
def _f():
|
||||
# Note that there is a race condition here, so we do a best effort
|
||||
# updates here. We reset update_in_steps first so that other workers
|
||||
|
|
@ -437,33 +422,36 @@ class KMeans(object):
|
|||
# before resetting total_counts to avoid large updates to
|
||||
# cluster_centers_updated based on partially updated
|
||||
# cluster_center_vars.
|
||||
with ops.control_dependencies([state_ops.assign(
|
||||
update_in_steps,
|
||||
self._mini_batch_steps_per_iteration - 1)]):
|
||||
with ops.colocate_with(cluster_centers_updated):
|
||||
with ops.control_dependencies([
|
||||
state_ops.assign(update_in_steps,
|
||||
self._mini_batch_steps_per_iteration - 1)
|
||||
]):
|
||||
with ops.colocate_with(
|
||||
cluster_centers_updated, ignore_existing=True):
|
||||
if self._distance_metric == COSINE_DISTANCE:
|
||||
cluster_centers = nn_impl.l2_normalize(cluster_centers_updated,
|
||||
dim=1)
|
||||
cluster_centers = nn_impl.l2_normalize(
|
||||
cluster_centers_updated, dim=1)
|
||||
else:
|
||||
cluster_centers = cluster_centers_updated
|
||||
with ops.colocate_with(cluster_centers_var):
|
||||
with ops.control_dependencies([state_ops.assign(
|
||||
cluster_centers_var,
|
||||
cluster_centers)]):
|
||||
with ops.colocate_with(cluster_centers_var):
|
||||
with ops.control_dependencies(
|
||||
[state_ops.assign(cluster_centers_var, cluster_centers)]):
|
||||
with ops.colocate_with(
|
||||
cluster_centers_var, ignore_existing=True):
|
||||
with ops.control_dependencies([
|
||||
state_ops.assign(total_counts,
|
||||
array_ops.zeros_like(total_counts))]):
|
||||
array_ops.zeros_like(total_counts))
|
||||
]):
|
||||
return array_ops.identity(update_in_steps)
|
||||
|
||||
return control_flow_ops.cond(
|
||||
update_in_steps <= 0,
|
||||
_f,
|
||||
update_in_steps <= 0, _f,
|
||||
lambda: state_ops.assign_sub(update_in_steps, 1))
|
||||
else:
|
||||
return control_flow_ops.no_op()
|
||||
|
||||
def _mini_batch_training_op(self, inputs, cluster_idx_list,
|
||||
cluster_centers, total_counts):
|
||||
def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
|
||||
total_counts):
|
||||
"""Creates an op for training for mini batch case.
|
||||
|
||||
Args:
|
||||
|
|
@ -487,17 +475,15 @@ class KMeans(object):
|
|||
unique_ids, unique_idx = array_ops.unique(cluster_idx)
|
||||
num_unique_cluster_idx = array_ops.size(unique_ids)
|
||||
# Fetch the old values of counts and cluster_centers.
|
||||
with ops.colocate_with(total_counts):
|
||||
with ops.colocate_with(total_counts, ignore_existing=True):
|
||||
old_counts = array_ops.gather(total_counts, unique_ids)
|
||||
# TODO(agarwal): This colocation seems to run into problems. Fix it.
|
||||
# with ops.colocate_with(cluster_centers):
|
||||
old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
|
||||
with ops.colocate_with(cluster_centers, ignore_existing=True):
|
||||
old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
|
||||
# Locally aggregate the increment to counts.
|
||||
count_updates = math_ops.unsorted_segment_sum(
|
||||
array_ops.ones_like(
|
||||
unique_idx, dtype=total_counts.dtype),
|
||||
unique_idx,
|
||||
num_unique_cluster_idx)
|
||||
array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
|
||||
unique_idx, num_unique_cluster_idx)
|
||||
# Locally compute the sum of inputs mapped to each id.
|
||||
# For a cluster with old cluster value x, old count n, and with data
|
||||
# d_1,...d_k newly assigned to it, we recompute the new value as
|
||||
|
|
@ -507,13 +493,12 @@ class KMeans(object):
|
|||
inp, unique_idx, num_unique_cluster_idx)
|
||||
# Shape to enable broadcasting count_updates and learning_rate to inp.
|
||||
# It extends the shape with 1's to match the rank of inp.
|
||||
broadcast_shape = array_ops.concat(
|
||||
[
|
||||
array_ops.reshape(num_unique_cluster_idx, [1]), array_ops.ones(
|
||||
array_ops.reshape(array_ops.rank(inp) - 1, [1]),
|
||||
dtype=dtypes.int32)
|
||||
],
|
||||
0)
|
||||
broadcast_shape = array_ops.concat([
|
||||
array_ops.reshape(num_unique_cluster_idx, [1]),
|
||||
array_ops.ones(
|
||||
array_ops.reshape(array_ops.rank(inp) - 1, [1]),
|
||||
dtype=dtypes.int32)
|
||||
], 0)
|
||||
# Subtract k * x, see comment above.
|
||||
cluster_center_updates -= math_ops.cast(
|
||||
array_ops.reshape(count_updates, broadcast_shape),
|
||||
|
|
@ -524,14 +509,10 @@ class KMeans(object):
|
|||
# scale by 1 / (n + k), see comment above.
|
||||
cluster_center_updates *= learning_rate
|
||||
# Apply the updates.
|
||||
update_counts = state_ops.scatter_add(
|
||||
total_counts,
|
||||
unique_ids,
|
||||
count_updates)
|
||||
update_counts = state_ops.scatter_add(total_counts, unique_ids,
|
||||
count_updates)
|
||||
update_cluster_centers = state_ops.scatter_add(
|
||||
cluster_centers,
|
||||
unique_ids,
|
||||
cluster_center_updates)
|
||||
cluster_centers, unique_ids, cluster_center_updates)
|
||||
update_ops.extend([update_counts, update_cluster_centers])
|
||||
return control_flow_ops.group(*update_ops)
|
||||
|
||||
|
|
@ -552,7 +533,7 @@ class KMeans(object):
|
|||
cluster_counts = []
|
||||
epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
|
||||
for inp, cluster_idx in zip(inputs, cluster_idx_list):
|
||||
with ops.colocate_with(inp):
|
||||
with ops.colocate_with(inp, ignore_existing=True):
|
||||
cluster_sums.append(
|
||||
math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters))
|
||||
cluster_counts.append(
|
||||
|
|
@ -561,7 +542,7 @@ class KMeans(object):
|
|||
array_ops.ones(
|
||||
array_ops.reshape(array_ops.shape(inp)[0], [-1])),
|
||||
[-1, 1]), cluster_idx, self._num_clusters))
|
||||
with ops.colocate_with(cluster_centers):
|
||||
with ops.colocate_with(cluster_centers, ignore_existing=True):
|
||||
new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
|
||||
math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
|
||||
if self._clusters_l2_normalized():
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user