mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Improve TFGAN documentation.
PiperOrigin-RevId: 170940188
This commit is contained in:
parent
0068086b9a
commit
4cf61262ae
|
|
@ -14,10 +14,41 @@
|
|||
# ==============================================================================
|
||||
"""TFGAN utilities for loss functions that accept GANModel namedtuples.
|
||||
|
||||
Example:
|
||||
The losses and penalties in this file all correspond to losses in
|
||||
`losses_impl.py`. Losses in that file take individual arguments, whereas in this
|
||||
file they take a `GANModel` tuple. For example:
|
||||
|
||||
losses_impl.py:
|
||||
```python
|
||||
# `tfgan.losses.args` losses take individual arguments.
|
||||
w_loss = tfgan.losses.args.wasserstein_discriminator_loss(
|
||||
def wasserstein_discriminator_loss(
|
||||
discriminator_real_outputs,
|
||||
discriminator_gen_outputs,
|
||||
real_weights=1.0,
|
||||
generated_weights=1.0,
|
||||
scope=None,
|
||||
loss_collection=ops.GraphKeys.LOSSES,
|
||||
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
|
||||
add_summaries=False)
|
||||
```
|
||||
|
||||
tuple_losses_impl.py:
|
||||
```python
|
||||
def wasserstein_discriminator_loss(
|
||||
gan_model,
|
||||
real_weights=1.0,
|
||||
generated_weights=1.0,
|
||||
scope=None,
|
||||
loss_collection=ops.GraphKeys.LOSSES,
|
||||
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
|
||||
add_summaries=False)
|
||||
```
|
||||
|
||||
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
# `tfgan.losses.wargs` losses take individual arguments.
|
||||
w_loss = tfgan.losses.wargs.wasserstein_discriminator_loss(
|
||||
discriminator_real_outputs,
|
||||
discriminator_gen_outputs)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Named tuples for TFGAN."""
|
||||
"""Named tuples for TFGAN.
|
||||
|
||||
TFGAN training occurs in four steps, and each step communicates with the next
|
||||
step via one of these named tuples. At each step, you can either use a TFGAN
|
||||
helper function in `train.py`, or you can manually construct a tuple.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
|
|
|||
|
|
@ -14,7 +14,17 @@
|
|||
# ==============================================================================
|
||||
"""The TFGAN project provides a lightweight GAN training/testing framework.
|
||||
|
||||
See examples in `tensorflow_models` for details on how to use.
|
||||
This file contains the core helper functions to create and train a GAN model.
|
||||
See the README or examples in `tensorflow_models` for details on how to use.
|
||||
|
||||
TFGAN training occurs in four steps:
|
||||
1) Create a model
|
||||
2) Add a loss
|
||||
3) Create train ops
|
||||
4) Run the train ops
|
||||
|
||||
The functions in this file are organized around these four steps. Each function
|
||||
corresponds to one of the steps.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
|
@ -51,16 +61,6 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
|
||||
"""Convert input, list of inputs, or dictionary of inputs to Tensors."""
|
||||
if isinstance(tensor_or_l_or_d, (list, tuple)):
|
||||
return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
|
||||
elif isinstance(tensor_or_l_or_d, dict):
|
||||
return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
|
||||
else:
|
||||
return ops.convert_to_tensor(tensor_or_l_or_d)
|
||||
|
||||
|
||||
def gan_model(
|
||||
# Lambdas defining models.
|
||||
generator_fn,
|
||||
|
|
@ -133,20 +133,6 @@ def gan_model(
|
|||
discriminator_fn)
|
||||
|
||||
|
||||
def _validate_distributions(distributions_l, noise_l):
|
||||
if not isinstance(distributions_l, (tuple, list)):
|
||||
raise ValueError('`predicted_distributions` must be a list. Instead, found '
|
||||
'%s.' % type(distributions_l))
|
||||
for dist in distributions_l:
|
||||
if not isinstance(dist, ds.Distribution):
|
||||
raise ValueError('Every element in `predicted_distributions` must be a '
|
||||
'`tf.Distribution`. Instead, found %s.' % type(dist))
|
||||
if len(distributions_l) != len(noise_l):
|
||||
raise ValueError('Length of `predicted_distributions` %i must be the same '
|
||||
'as the length of structured noise %i.' %
|
||||
(len(distributions_l), len(noise_l)))
|
||||
|
||||
|
||||
def infogan_model(
|
||||
# Lambdas defining models.
|
||||
generator_fn,
|
||||
|
|
@ -231,16 +217,6 @@ def infogan_model(
|
|||
predicted_distributions)
|
||||
|
||||
|
||||
def _validate_acgan_discriminator_outputs(discriminator_output):
|
||||
try:
|
||||
a, b = discriminator_output
|
||||
except (TypeError, ValueError):
|
||||
raise TypeError(
|
||||
'A discriminator function for ACGAN must output a tuple '
|
||||
'consisting of (discrimination logits, classification logits).')
|
||||
return a, b
|
||||
|
||||
|
||||
def acgan_model(
|
||||
# Lambdas defining models.
|
||||
generator_fn,
|
||||
|
|
@ -252,6 +228,7 @@ def acgan_model(
|
|||
# Optional scopes.
|
||||
generator_scope='Generator',
|
||||
discriminator_scope='Discriminator',
|
||||
# Options.
|
||||
check_shapes=True):
|
||||
"""Returns an ACGANModel contains all the pieces needed for ACGAN training.
|
||||
|
||||
|
|
@ -497,11 +474,10 @@ def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
|
|||
|
||||
|
||||
def gan_train_ops(
|
||||
model, # GANModel
|
||||
loss, # GANLoss
|
||||
model,
|
||||
loss,
|
||||
generator_optimizer,
|
||||
discriminator_optimizer,
|
||||
# Optional check flags.
|
||||
check_for_unused_update_ops=True,
|
||||
# Optional args to pass directly to the `create_train_op`.
|
||||
**kwargs):
|
||||
|
|
@ -801,3 +777,40 @@ def get_sequential_train_steps(
|
|||
return gen_loss + dis_loss, should_stop
|
||||
|
||||
return sequential_train_steps
|
||||
|
||||
|
||||
# Helpers
|
||||
|
||||
|
||||
def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
|
||||
"""Convert input, list of inputs, or dictionary of inputs to Tensors."""
|
||||
if isinstance(tensor_or_l_or_d, (list, tuple)):
|
||||
return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
|
||||
elif isinstance(tensor_or_l_or_d, dict):
|
||||
return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
|
||||
else:
|
||||
return ops.convert_to_tensor(tensor_or_l_or_d)
|
||||
|
||||
|
||||
def _validate_distributions(distributions_l, noise_l):
|
||||
if not isinstance(distributions_l, (tuple, list)):
|
||||
raise ValueError('`predicted_distributions` must be a list. Instead, found '
|
||||
'%s.' % type(distributions_l))
|
||||
for dist in distributions_l:
|
||||
if not isinstance(dist, ds.Distribution):
|
||||
raise ValueError('Every element in `predicted_distributions` must be a '
|
||||
'`tf.Distribution`. Instead, found %s.' % type(dist))
|
||||
if len(distributions_l) != len(noise_l):
|
||||
raise ValueError('Length of `predicted_distributions` %i must be the same '
|
||||
'as the length of structured noise %i.' %
|
||||
(len(distributions_l), len(noise_l)))
|
||||
|
||||
|
||||
def _validate_acgan_discriminator_outputs(discriminator_output):
|
||||
try:
|
||||
a, b = discriminator_output
|
||||
except (TypeError, ValueError):
|
||||
raise TypeError(
|
||||
'A discriminator function for ACGAN must output a tuple '
|
||||
'consisting of (discrimination logits, classification logits).')
|
||||
return a, b
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user