mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Return a classifier score of the same type as the logits.
PiperOrigin-RevId: 174184871
This commit is contained in:
parent
9da02be116
commit
18bf5b2d91
|
|
@ -297,7 +297,8 @@ def classifier_score(images, classifier_fn, num_batches=1):
|
|||
efficiently run them through the classifier network.
|
||||
|
||||
Returns:
|
||||
The classifier score. A floating-point scalar.
|
||||
The classifier score. A floating-point scalar of the same type as the output
|
||||
of `classifier_fn`.
|
||||
"""
|
||||
generated_images_list = array_ops.split(
|
||||
images, num_or_size_splits=num_batches)
|
||||
|
|
@ -316,7 +317,7 @@ def classifier_score(images, classifier_fn, num_batches=1):
|
|||
# Use maximum precision for best results.
|
||||
logits_dtype = logits.dtype
|
||||
if logits_dtype != dtypes.float64:
|
||||
logits = math_ops.cast(logits, dtypes.float64)
|
||||
logits = math_ops.to_double(logits)
|
||||
|
||||
p = nn_ops.softmax(logits)
|
||||
q = math_ops.reduce_mean(p, axis=0)
|
||||
|
|
@ -326,7 +327,7 @@ def classifier_score(images, classifier_fn, num_batches=1):
|
|||
final_score = math_ops.exp(log_score)
|
||||
|
||||
if logits_dtype != dtypes.float64:
|
||||
final_score = math_ops.cast(final_score, dtypes.float64)
|
||||
final_score = math_ops.cast(final_score, logits_dtype)
|
||||
return final_score
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user