mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Fix device colocation for KMeans in case of multiple parameter servers.
PiperOrigin-RevId: 157795360
This commit is contained in:
parent
b659bc39f2
commit
07710014d2
|
|
@ -164,11 +164,12 @@ class KMeans(object):
|
||||||
with ops.colocate_with(inp):
|
with ops.colocate_with(inp):
|
||||||
# Computes Euclidean distance. Note the first and third terms are
|
# Computes Euclidean distance. Note the first and third terms are
|
||||||
# broadcast additions.
|
# broadcast additions.
|
||||||
squared_distance = (math_ops.reduce_sum(
|
squared_distance = (
|
||||||
math_ops.square(inp), 1, keep_dims=True) - 2 * math_ops.matmul(
|
math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) -
|
||||||
inp, clusters, transpose_b=True) + array_ops.transpose(
|
2 * math_ops.matmul(inp, clusters, transpose_b=True) +
|
||||||
math_ops.reduce_sum(
|
array_ops.transpose(
|
||||||
math_ops.square(clusters), 1, keep_dims=True)))
|
math_ops.reduce_sum(
|
||||||
|
math_ops.square(clusters), 1, keep_dims=True)))
|
||||||
output.append(squared_distance)
|
output.append(squared_distance)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
@ -229,12 +230,12 @@ class KMeans(object):
|
||||||
clusters = nn_impl.l2_normalize(clusters, dim=1)
|
clusters = nn_impl.l2_normalize(clusters, dim=1)
|
||||||
for inp, score in zip(inputs, scores):
|
for inp, score in zip(inputs, scores):
|
||||||
with ops.colocate_with(inp):
|
with ops.colocate_with(inp):
|
||||||
(indices,
|
(indices, distances) = gen_clustering_ops.nearest_neighbors(
|
||||||
distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1)
|
inp, clusters, 1)
|
||||||
if self._distance_metric == COSINE_DISTANCE:
|
if self._distance_metric == COSINE_DISTANCE:
|
||||||
distances *= 0.5
|
distances *= 0.5
|
||||||
output.append(
|
output.append((score, array_ops.squeeze(distances),
|
||||||
(score, array_ops.squeeze(distances), array_ops.squeeze(indices)))
|
array_ops.squeeze(indices)))
|
||||||
return zip(*output)
|
return zip(*output)
|
||||||
|
|
||||||
def _init_clusters_random(self):
|
def _init_clusters_random(self):
|
||||||
|
|
@ -265,9 +266,7 @@ class KMeans(object):
|
||||||
(not self._use_mini_batch or
|
(not self._use_mini_batch or
|
||||||
self._mini_batch_steps_per_iteration > 1))
|
self._mini_batch_steps_per_iteration > 1))
|
||||||
|
|
||||||
def _initialize_clusters(self,
|
def _initialize_clusters(self, cluster_centers, cluster_centers_initialized,
|
||||||
cluster_centers,
|
|
||||||
cluster_centers_initialized,
|
|
||||||
cluster_centers_updated):
|
cluster_centers_updated):
|
||||||
"""Returns an op to initialize the cluster centers."""
|
"""Returns an op to initialize the cluster centers."""
|
||||||
|
|
||||||
|
|
@ -294,22 +293,20 @@ class KMeans(object):
|
||||||
|
|
||||||
with ops.colocate_with(cluster_centers_initialized):
|
with ops.colocate_with(cluster_centers_initialized):
|
||||||
initialized = control_flow_ops.with_dependencies(
|
initialized = control_flow_ops.with_dependencies(
|
||||||
[clusters_init],
|
[clusters_init], array_ops.identity(cluster_centers_initialized))
|
||||||
array_ops.identity(cluster_centers_initialized))
|
|
||||||
with ops.colocate_with(cluster_centers):
|
with ops.colocate_with(cluster_centers):
|
||||||
assign_centers = state_ops.assign(cluster_centers, clusters_init,
|
assign_centers = state_ops.assign(
|
||||||
validate_shape=False)
|
cluster_centers, clusters_init, validate_shape=False)
|
||||||
if cluster_centers_updated != cluster_centers:
|
if cluster_centers_updated != cluster_centers:
|
||||||
assign_centers = control_flow_ops.group(
|
assign_centers = control_flow_ops.group(assign_centers,
|
||||||
assign_centers,
|
state_ops.assign(
|
||||||
state_ops.assign(cluster_centers_updated, clusters_init,
|
cluster_centers_updated,
|
||||||
validate_shape=False))
|
clusters_init,
|
||||||
assign_centers = control_flow_ops.with_dependencies(
|
validate_shape=False))
|
||||||
[assign_centers],
|
assign_centers = control_flow_ops.with_dependencies(
|
||||||
state_ops.assign(cluster_centers_initialized, True))
|
[assign_centers], state_ops.assign(cluster_centers_initialized, True))
|
||||||
return control_flow_ops.cond(initialized,
|
return control_flow_ops.cond(initialized, control_flow_ops.no_op,
|
||||||
control_flow_ops.no_op,
|
lambda: assign_centers).op
|
||||||
lambda: assign_centers).op
|
|
||||||
|
|
||||||
def _create_variables(self):
|
def _create_variables(self):
|
||||||
"""Creates variables.
|
"""Creates variables.
|
||||||
|
|
@ -327,19 +324,16 @@ class KMeans(object):
|
||||||
cluster_centers_updated back to cluster_centers.
|
cluster_centers_updated back to cluster_centers.
|
||||||
"""
|
"""
|
||||||
init_value = array_ops.constant([], dtype=dtypes.float32)
|
init_value = array_ops.constant([], dtype=dtypes.float32)
|
||||||
cluster_centers = variable_scope.variable(init_value,
|
cluster_centers = variable_scope.variable(
|
||||||
name='clusters',
|
init_value, name='clusters', validate_shape=False)
|
||||||
validate_shape=False)
|
cluster_centers_initialized = variable_scope.variable(
|
||||||
cluster_centers_initialized = variable_scope.variable(False,
|
False, dtype=dtypes.bool, name='initialized')
|
||||||
dtype=dtypes.bool,
|
|
||||||
name='initialized')
|
|
||||||
|
|
||||||
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
|
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
|
||||||
# Copy of cluster centers actively updated each step according to
|
# Copy of cluster centers actively updated each step according to
|
||||||
# mini-batch update rule.
|
# mini-batch update rule.
|
||||||
cluster_centers_updated = variable_scope.variable(init_value,
|
cluster_centers_updated = variable_scope.variable(
|
||||||
name='clusters_updated',
|
init_value, name='clusters_updated', validate_shape=False)
|
||||||
validate_shape=False)
|
|
||||||
# How many steps till we copy the updated clusters to cluster_centers.
|
# How many steps till we copy the updated clusters to cluster_centers.
|
||||||
update_in_steps = variable_scope.variable(
|
update_in_steps = variable_scope.variable(
|
||||||
self._mini_batch_steps_per_iteration,
|
self._mini_batch_steps_per_iteration,
|
||||||
|
|
@ -347,20 +341,15 @@ class KMeans(object):
|
||||||
name='update_in_steps')
|
name='update_in_steps')
|
||||||
# Count of points assigned to cluster_centers_updated.
|
# Count of points assigned to cluster_centers_updated.
|
||||||
cluster_counts = variable_scope.variable(
|
cluster_counts = variable_scope.variable(
|
||||||
array_ops.zeros([self._num_clusters],
|
array_ops.zeros([self._num_clusters], dtype=dtypes.int64))
|
||||||
dtype=dtypes.int64))
|
|
||||||
else:
|
else:
|
||||||
cluster_centers_updated = cluster_centers
|
cluster_centers_updated = cluster_centers
|
||||||
update_in_steps = None
|
update_in_steps = None
|
||||||
cluster_counts = (variable_scope.variable(array_ops.ones(
|
cluster_counts = (variable_scope.variable(
|
||||||
[self._num_clusters],
|
array_ops.ones([self._num_clusters], dtype=dtypes.int64))
|
||||||
dtype=dtypes.int64))
|
|
||||||
if self._use_mini_batch else None)
|
if self._use_mini_batch else None)
|
||||||
return (cluster_centers,
|
return (cluster_centers, cluster_centers_initialized, cluster_counts,
|
||||||
cluster_centers_initialized,
|
cluster_centers_updated, update_in_steps)
|
||||||
cluster_counts,
|
|
||||||
cluster_centers_updated,
|
|
||||||
update_in_steps)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _l2_normalize_data(cls, inputs):
|
def _l2_normalize_data(cls, inputs):
|
||||||
|
|
@ -391,11 +380,8 @@ class KMeans(object):
|
||||||
"""
|
"""
|
||||||
# Implementation of kmeans.
|
# Implementation of kmeans.
|
||||||
inputs = self._inputs
|
inputs = self._inputs
|
||||||
(cluster_centers_var,
|
(cluster_centers_var, cluster_centers_initialized, total_counts,
|
||||||
cluster_centers_initialized,
|
cluster_centers_updated, update_in_steps) = self._create_variables()
|
||||||
total_counts,
|
|
||||||
cluster_centers_updated,
|
|
||||||
update_in_steps) = self._create_variables()
|
|
||||||
init_op = self._initialize_clusters(cluster_centers_var,
|
init_op = self._initialize_clusters(cluster_centers_var,
|
||||||
cluster_centers_initialized,
|
cluster_centers_initialized,
|
||||||
cluster_centers_updated)
|
cluster_centers_updated)
|
||||||
|
|
@ -409,8 +395,7 @@ class KMeans(object):
|
||||||
all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
|
all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
|
||||||
if self._use_mini_batch:
|
if self._use_mini_batch:
|
||||||
sync_updates_op = self._mini_batch_sync_updates_op(
|
sync_updates_op = self._mini_batch_sync_updates_op(
|
||||||
update_in_steps,
|
update_in_steps, cluster_centers_var, cluster_centers_updated,
|
||||||
cluster_centers_var, cluster_centers_updated,
|
|
||||||
total_counts)
|
total_counts)
|
||||||
assert sync_updates_op is not None
|
assert sync_updates_op is not None
|
||||||
with ops.control_dependencies([sync_updates_op]):
|
with ops.control_dependencies([sync_updates_op]):
|
||||||
|
|
@ -421,15 +406,15 @@ class KMeans(object):
|
||||||
training_op = self._full_batch_training_op(inputs, cluster_idx,
|
training_op = self._full_batch_training_op(inputs, cluster_idx,
|
||||||
cluster_centers_var)
|
cluster_centers_var)
|
||||||
|
|
||||||
return (all_scores, cluster_idx, scores,
|
return (all_scores, cluster_idx, scores, cluster_centers_initialized,
|
||||||
cluster_centers_initialized, init_op, training_op)
|
init_op, training_op)
|
||||||
|
|
||||||
def _mini_batch_sync_updates_op(self, update_in_steps,
|
def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
|
||||||
cluster_centers_var, cluster_centers_updated,
|
cluster_centers_updated, total_counts):
|
||||||
total_counts):
|
|
||||||
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
|
if self._use_mini_batch and self._mini_batch_steps_per_iteration > 1:
|
||||||
assert update_in_steps is not None
|
assert update_in_steps is not None
|
||||||
with ops.colocate_with(update_in_steps):
|
with ops.colocate_with(update_in_steps):
|
||||||
|
|
||||||
def _f():
|
def _f():
|
||||||
# Note that there is a race condition here, so we only make best-effort
|
# Note that there is a race condition here, so we only make best-effort
|
||||||
# updates. We reset update_in_steps first so that other workers
|
# updates. We reset update_in_steps first so that other workers
|
||||||
|
|
@ -437,33 +422,36 @@ class KMeans(object):
|
||||||
# before resetting total_counts to avoid large updates to
|
# before resetting total_counts to avoid large updates to
|
||||||
# cluster_centers_updated based on partially updated
|
# cluster_centers_updated based on partially updated
|
||||||
# cluster_center_vars.
|
# cluster_center_vars.
|
||||||
with ops.control_dependencies([state_ops.assign(
|
with ops.control_dependencies([
|
||||||
update_in_steps,
|
state_ops.assign(update_in_steps,
|
||||||
self._mini_batch_steps_per_iteration - 1)]):
|
self._mini_batch_steps_per_iteration - 1)
|
||||||
with ops.colocate_with(cluster_centers_updated):
|
]):
|
||||||
|
with ops.colocate_with(
|
||||||
|
cluster_centers_updated, ignore_existing=True):
|
||||||
if self._distance_metric == COSINE_DISTANCE:
|
if self._distance_metric == COSINE_DISTANCE:
|
||||||
cluster_centers = nn_impl.l2_normalize(cluster_centers_updated,
|
cluster_centers = nn_impl.l2_normalize(
|
||||||
dim=1)
|
cluster_centers_updated, dim=1)
|
||||||
else:
|
else:
|
||||||
cluster_centers = cluster_centers_updated
|
cluster_centers = cluster_centers_updated
|
||||||
with ops.colocate_with(cluster_centers_var):
|
with ops.colocate_with(cluster_centers_var):
|
||||||
with ops.control_dependencies([state_ops.assign(
|
with ops.control_dependencies(
|
||||||
cluster_centers_var,
|
[state_ops.assign(cluster_centers_var, cluster_centers)]):
|
||||||
cluster_centers)]):
|
with ops.colocate_with(
|
||||||
with ops.colocate_with(cluster_centers_var):
|
cluster_centers_var, ignore_existing=True):
|
||||||
with ops.control_dependencies([
|
with ops.control_dependencies([
|
||||||
state_ops.assign(total_counts,
|
state_ops.assign(total_counts,
|
||||||
array_ops.zeros_like(total_counts))]):
|
array_ops.zeros_like(total_counts))
|
||||||
|
]):
|
||||||
return array_ops.identity(update_in_steps)
|
return array_ops.identity(update_in_steps)
|
||||||
|
|
||||||
return control_flow_ops.cond(
|
return control_flow_ops.cond(
|
||||||
update_in_steps <= 0,
|
update_in_steps <= 0, _f,
|
||||||
_f,
|
|
||||||
lambda: state_ops.assign_sub(update_in_steps, 1))
|
lambda: state_ops.assign_sub(update_in_steps, 1))
|
||||||
else:
|
else:
|
||||||
return control_flow_ops.no_op()
|
return control_flow_ops.no_op()
|
||||||
|
|
||||||
def _mini_batch_training_op(self, inputs, cluster_idx_list,
|
def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
|
||||||
cluster_centers, total_counts):
|
total_counts):
|
||||||
"""Creates an op for training for mini batch case.
|
"""Creates an op for training for mini batch case.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -487,17 +475,15 @@ class KMeans(object):
|
||||||
unique_ids, unique_idx = array_ops.unique(cluster_idx)
|
unique_ids, unique_idx = array_ops.unique(cluster_idx)
|
||||||
num_unique_cluster_idx = array_ops.size(unique_ids)
|
num_unique_cluster_idx = array_ops.size(unique_ids)
|
||||||
# Fetch the old values of counts and cluster_centers.
|
# Fetch the old values of counts and cluster_centers.
|
||||||
with ops.colocate_with(total_counts):
|
with ops.colocate_with(total_counts, ignore_existing=True):
|
||||||
old_counts = array_ops.gather(total_counts, unique_ids)
|
old_counts = array_ops.gather(total_counts, unique_ids)
|
||||||
# TODO(agarwal): This colocation seems to run into problems. Fix it.
|
# TODO(agarwal): This colocation seems to run into problems. Fix it.
|
||||||
# with ops.colocate_with(cluster_centers):
|
with ops.colocate_with(cluster_centers, ignore_existing=True):
|
||||||
old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
|
old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
|
||||||
# Locally aggregate the increment to counts.
|
# Locally aggregate the increment to counts.
|
||||||
count_updates = math_ops.unsorted_segment_sum(
|
count_updates = math_ops.unsorted_segment_sum(
|
||||||
array_ops.ones_like(
|
array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
|
||||||
unique_idx, dtype=total_counts.dtype),
|
unique_idx, num_unique_cluster_idx)
|
||||||
unique_idx,
|
|
||||||
num_unique_cluster_idx)
|
|
||||||
# Locally compute the sum of inputs mapped to each id.
|
# Locally compute the sum of inputs mapped to each id.
|
||||||
# For a cluster with old cluster value x, old count n, and with data
|
# For a cluster with old cluster value x, old count n, and with data
|
||||||
# d_1,...d_k newly assigned to it, we recompute the new value as
|
# d_1,...d_k newly assigned to it, we recompute the new value as
|
||||||
|
|
@ -507,13 +493,12 @@ class KMeans(object):
|
||||||
inp, unique_idx, num_unique_cluster_idx)
|
inp, unique_idx, num_unique_cluster_idx)
|
||||||
# Shape to enable broadcasting count_updates and learning_rate to inp.
|
# Shape to enable broadcasting count_updates and learning_rate to inp.
|
||||||
# It extends the shape with 1's to match the rank of inp.
|
# It extends the shape with 1's to match the rank of inp.
|
||||||
broadcast_shape = array_ops.concat(
|
broadcast_shape = array_ops.concat([
|
||||||
[
|
array_ops.reshape(num_unique_cluster_idx, [1]),
|
||||||
array_ops.reshape(num_unique_cluster_idx, [1]), array_ops.ones(
|
array_ops.ones(
|
||||||
array_ops.reshape(array_ops.rank(inp) - 1, [1]),
|
array_ops.reshape(array_ops.rank(inp) - 1, [1]),
|
||||||
dtype=dtypes.int32)
|
dtype=dtypes.int32)
|
||||||
],
|
], 0)
|
||||||
0)
|
|
||||||
# Subtract k * x, see comment above.
|
# Subtract k * x, see comment above.
|
||||||
cluster_center_updates -= math_ops.cast(
|
cluster_center_updates -= math_ops.cast(
|
||||||
array_ops.reshape(count_updates, broadcast_shape),
|
array_ops.reshape(count_updates, broadcast_shape),
|
||||||
|
|
@ -524,14 +509,10 @@ class KMeans(object):
|
||||||
# scale by 1 / (n + k), see comment above.
|
# scale by 1 / (n + k), see comment above.
|
||||||
cluster_center_updates *= learning_rate
|
cluster_center_updates *= learning_rate
|
||||||
# Apply the updates.
|
# Apply the updates.
|
||||||
update_counts = state_ops.scatter_add(
|
update_counts = state_ops.scatter_add(total_counts, unique_ids,
|
||||||
total_counts,
|
count_updates)
|
||||||
unique_ids,
|
|
||||||
count_updates)
|
|
||||||
update_cluster_centers = state_ops.scatter_add(
|
update_cluster_centers = state_ops.scatter_add(
|
||||||
cluster_centers,
|
cluster_centers, unique_ids, cluster_center_updates)
|
||||||
unique_ids,
|
|
||||||
cluster_center_updates)
|
|
||||||
update_ops.extend([update_counts, update_cluster_centers])
|
update_ops.extend([update_counts, update_cluster_centers])
|
||||||
return control_flow_ops.group(*update_ops)
|
return control_flow_ops.group(*update_ops)
|
||||||
|
|
||||||
|
|
@ -552,7 +533,7 @@ class KMeans(object):
|
||||||
cluster_counts = []
|
cluster_counts = []
|
||||||
epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
|
epsilon = constant_op.constant(1e-6, dtype=inputs[0].dtype)
|
||||||
for inp, cluster_idx in zip(inputs, cluster_idx_list):
|
for inp, cluster_idx in zip(inputs, cluster_idx_list):
|
||||||
with ops.colocate_with(inp):
|
with ops.colocate_with(inp, ignore_existing=True):
|
||||||
cluster_sums.append(
|
cluster_sums.append(
|
||||||
math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters))
|
math_ops.unsorted_segment_sum(inp, cluster_idx, self._num_clusters))
|
||||||
cluster_counts.append(
|
cluster_counts.append(
|
||||||
|
|
@ -561,7 +542,7 @@ class KMeans(object):
|
||||||
array_ops.ones(
|
array_ops.ones(
|
||||||
array_ops.reshape(array_ops.shape(inp)[0], [-1])),
|
array_ops.reshape(array_ops.shape(inp)[0], [-1])),
|
||||||
[-1, 1]), cluster_idx, self._num_clusters))
|
[-1, 1]), cluster_idx, self._num_clusters))
|
||||||
with ops.colocate_with(cluster_centers):
|
with ops.colocate_with(cluster_centers, ignore_existing=True):
|
||||||
new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
|
new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
|
||||||
math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
|
math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
|
||||||
if self._clusters_l2_normalized():
|
if self._clusters_l2_normalized():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user