Return a classifier score of the same type as the logits.

PiperOrigin-RevId: 174184871
This commit is contained in:
A. Unique TensorFlower 2017-11-01 08:45:07 -07:00 committed by TensorFlower Gardener
parent 9da02be116
commit 18bf5b2d91

View File

@ -297,7 +297,8 @@ def classifier_score(images, classifier_fn, num_batches=1):
efficiently run them through the classifier network. efficiently run them through the classifier network.
Returns: Returns:
The classifier score. A floating-point scalar. The classifier score. A floating-point scalar of the same type as the output
of `classifier_fn`.
""" """
generated_images_list = array_ops.split( generated_images_list = array_ops.split(
images, num_or_size_splits=num_batches) images, num_or_size_splits=num_batches)
@ -316,7 +317,7 @@ def classifier_score(images, classifier_fn, num_batches=1):
# Use maximum precision for best results. # Use maximum precision for best results.
logits_dtype = logits.dtype logits_dtype = logits.dtype
if logits_dtype != dtypes.float64: if logits_dtype != dtypes.float64:
logits = math_ops.cast(logits, dtypes.float64) logits = math_ops.to_double(logits)
p = nn_ops.softmax(logits) p = nn_ops.softmax(logits)
q = math_ops.reduce_mean(p, axis=0) q = math_ops.reduce_mean(p, axis=0)
@ -326,7 +327,7 @@ def classifier_score(images, classifier_fn, num_batches=1):
final_score = math_ops.exp(log_score) final_score = math_ops.exp(log_score)
if logits_dtype != dtypes.float64: if logits_dtype != dtypes.float64:
final_score = math_ops.cast(final_score, dtypes.float64) final_score = math_ops.cast(final_score, logits_dtype)
return final_score return final_score