Return a classifier score of the same type as the logits.

PiperOrigin-RevId: 174184871
This commit is contained in:
A. Unique TensorFlower 2017-11-01 08:45:07 -07:00 committed by TensorFlower Gardener
parent 9da02be116
commit 18bf5b2d91

View File

@ -297,7 +297,8 @@ def classifier_score(images, classifier_fn, num_batches=1):
efficiently run them through the classifier network.
Returns:
The classifier score. A floating-point scalar.
The classifier score. A floating-point scalar of the same type as the output
of `classifier_fn`.
"""
generated_images_list = array_ops.split(
images, num_or_size_splits=num_batches)
@ -316,7 +317,7 @@ def classifier_score(images, classifier_fn, num_batches=1):
# Use maximum precision for best results.
logits_dtype = logits.dtype
if logits_dtype != dtypes.float64:
logits = math_ops.cast(logits, dtypes.float64)
logits = math_ops.to_double(logits)
p = nn_ops.softmax(logits)
q = math_ops.reduce_mean(p, axis=0)
@ -326,7 +327,7 @@ def classifier_score(images, classifier_fn, num_batches=1):
final_score = math_ops.exp(log_score)
if logits_dtype != dtypes.float64:
final_score = math_ops.cast(final_score, dtypes.float64)
final_score = math_ops.cast(final_score, logits_dtype)
return final_score