mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Return a classifier score of the same type as the logits.
PiperOrigin-RevId: 174184871
This commit is contained in:
parent
9da02be116
commit
18bf5b2d91
|
|
@ -297,7 +297,8 @@ def classifier_score(images, classifier_fn, num_batches=1):
|
||||||
efficiently run them through the classifier network.
|
efficiently run them through the classifier network.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The classifier score. A floating-point scalar.
|
The classifier score. A floating-point scalar of the same type as the output
|
||||||
|
of `classifier_fn`.
|
||||||
"""
|
"""
|
||||||
generated_images_list = array_ops.split(
|
generated_images_list = array_ops.split(
|
||||||
images, num_or_size_splits=num_batches)
|
images, num_or_size_splits=num_batches)
|
||||||
|
|
@ -316,7 +317,7 @@ def classifier_score(images, classifier_fn, num_batches=1):
|
||||||
# Use maximum precision for best results.
|
# Use maximum precision for best results.
|
||||||
logits_dtype = logits.dtype
|
logits_dtype = logits.dtype
|
||||||
if logits_dtype != dtypes.float64:
|
if logits_dtype != dtypes.float64:
|
||||||
logits = math_ops.cast(logits, dtypes.float64)
|
logits = math_ops.to_double(logits)
|
||||||
|
|
||||||
p = nn_ops.softmax(logits)
|
p = nn_ops.softmax(logits)
|
||||||
q = math_ops.reduce_mean(p, axis=0)
|
q = math_ops.reduce_mean(p, axis=0)
|
||||||
|
|
@ -326,7 +327,7 @@ def classifier_score(images, classifier_fn, num_batches=1):
|
||||||
final_score = math_ops.exp(log_score)
|
final_score = math_ops.exp(log_score)
|
||||||
|
|
||||||
if logits_dtype != dtypes.float64:
|
if logits_dtype != dtypes.float64:
|
||||||
final_score = math_ops.cast(final_score, dtypes.float64)
|
final_score = math_ops.cast(final_score, logits_dtype)
|
||||||
return final_score
|
return final_score
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user