Improve TFGAN documentation.

PiperOrigin-RevId: 170940188
A. Unique TensorFlower 2017-10-03 17:17:07 -07:00 committed by TensorFlower Gardener
parent 0068086b9a
commit 4cf61262ae
3 changed files with 91 additions and 42 deletions

View File

@@ -14,10 +14,41 @@
# ==============================================================================
"""TFGAN utilities for loss functions that accept GANModel namedtuples.
The losses and penalties in this file all correspond to losses in
`losses_impl.py`. Losses in that file take individual arguments, whereas in this
file they take a `GANModel` tuple. For example:
losses_impl.py:
```python
def wasserstein_discriminator_loss(
    discriminator_real_outputs,
    discriminator_gen_outputs,
    real_weights=1.0,
    generated_weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False)
```
tuple_losses_impl.py:
```python
def wasserstein_discriminator_loss(
    gan_model,
    real_weights=1.0,
    generated_weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False)
```
Example usage:
```python
# `tfgan.losses.wargs` losses take individual arguments.
w_loss = tfgan.losses.wargs.wasserstein_discriminator_loss(
    discriminator_real_outputs,
    discriminator_gen_outputs)
```
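For comparison, a minimal sketch of the equivalent tuple-based call; it assumes a `GANModel` namedtuple named `gan_model` has already been created (for instance with `tfgan.gan_model`), which is an illustrative assumption rather than part of this commit.

```python
# Tuple losses take the whole GANModel namedtuple instead of individual
# tensors (sketch; `gan_model` is assumed to exist already).
w_loss = tfgan.losses.wasserstein_discriminator_loss(gan_model)
```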

View File

@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Named tuples for TFGAN.
TFGAN training occurs in four steps, and each step communicates with the next
step via one of these named tuples. At each step, you can either use a TFGAN
helper function in `train.py`, or you can manually construct a tuple.
"""
from __future__ import absolute_import
from __future__ import division
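As a hedged illustration of the docstring above, the sketch below builds the first of these tuples with the `train.py` helper instead of constructing it by hand; `generator_fn`, `discriminator_fn`, `images`, and `noise` are assumed user-defined names, not part of this commit.

```python
import tensorflow as tf

tfgan = tf.contrib.gan

# Step 1 returns a GANModel namedtuple; the same tuple could also be built
# manually via tfgan.GANModel(...) when more control is needed.
gan_model = tfgan.gan_model(
    generator_fn=generator_fn,          # user-defined generator function
    discriminator_fn=discriminator_fn,  # user-defined discriminator function
    real_data=images,                   # tensor of real examples
    generator_inputs=noise)             # tensor of generator inputs
```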

View File

@@ -14,7 +14,17 @@
# ==============================================================================
"""The TFGAN project provides a lightweight GAN training/testing framework.
This file contains the core helper functions to create and train a GAN model.
See the README or examples in `tensorflow_models` for details on how to use.
TFGAN training occurs in four steps:
1) Create a model
2) Add a loss
3) Create train ops
4) Run the train ops
The functions in this file are organized around these four steps. Each function
corresponds to one of the steps.
""" """
from __future__ import absolute_import from __future__ import absolute_import
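To make the four steps concrete, here is a minimal end-to-end sketch using the helpers this file provides. The model functions, input tensors, optimizer settings, and log directory are illustrative assumptions, not part of this commit.

```python
import tensorflow as tf

tfgan = tf.contrib.gan

# 1) Create a model (returns a GANModel namedtuple).
gan_model = tfgan.gan_model(
    generator_fn, discriminator_fn,
    real_data=images, generator_inputs=noise)

# 2) Add a loss (returns a GANLoss namedtuple).
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss)

# 3) Create train ops (returns a GANTrainOps namedtuple).
train_ops = tfgan.gan_train_ops(
    gan_model, gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5))

# 4) Run the train ops.
tfgan.gan_train(train_ops, logdir='/tmp/tfgan_logdir')
```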
@@ -51,16 +61,6 @@ __all__ = [
]
def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
  """Convert input, list of inputs, or dictionary of inputs to Tensors."""
  if isinstance(tensor_or_l_or_d, (list, tuple)):
    return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
  elif isinstance(tensor_or_l_or_d, dict):
    return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
  else:
    return ops.convert_to_tensor(tensor_or_l_or_d)
def gan_model(
    # Lambdas defining models.
    generator_fn,
@@ -133,20 +133,6 @@ def gan_model(
      discriminator_fn)
def _validate_distributions(distributions_l, noise_l):
  if not isinstance(distributions_l, (tuple, list)):
    raise ValueError('`predicted_distributions` must be a list. Instead, found '
                     '%s.' % type(distributions_l))
  for dist in distributions_l:
    if not isinstance(dist, ds.Distribution):
      raise ValueError('Every element in `predicted_distributions` must be a '
                       '`tf.Distribution`. Instead, found %s.' % type(dist))
  if len(distributions_l) != len(noise_l):
    raise ValueError('Length of `predicted_distributions` %i must be the same '
                     'as the length of structured noise %i.' %
                     (len(distributions_l), len(noise_l)))
def infogan_model(
    # Lambdas defining models.
    generator_fn,
@@ -231,16 +217,6 @@ def infogan_model(
      predicted_distributions)
def _validate_acgan_discriminator_outputs(discriminator_output):
  try:
    a, b = discriminator_output
  except (TypeError, ValueError):
    raise TypeError(
        'A discriminator function for ACGAN must output a tuple '
        'consisting of (discrimination logits, classification logits).')
  return a, b
def acgan_model(
    # Lambdas defining models.
    generator_fn,
@@ -252,6 +228,7 @@ def acgan_model(
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
"""Returns an ACGANModel contains all the pieces needed for ACGAN training. """Returns an ACGANModel contains all the pieces needed for ACGAN training.
@@ -497,11 +474,10 @@ def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
def gan_train_ops(
    model,
    loss,
    generator_optimizer,
    discriminator_optimizer,
    # Optional check flags.
    check_for_unused_update_ops=True,
    # Optional args to pass directly to the `create_train_op`.
    **kwargs):
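A hedged sketch of calling this function; `gan_model` and `gan_loss` are assumed to come from the earlier steps, and `summarize_gradients` is just one example of a keyword argument forwarded to `create_train_op`, not something this commit adds.

```python
train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    # Extra keyword arguments are passed through to `create_train_op`.
    summarize_gradients=True)
```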
@@ -801,3 +777,40 @@ def get_sequential_train_steps(
    return gen_loss + dis_loss, should_stop
  return sequential_train_steps
# Helpers
def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
  """Convert input, list of inputs, or dictionary of inputs to Tensors."""
  if isinstance(tensor_or_l_or_d, (list, tuple)):
    return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
  elif isinstance(tensor_or_l_or_d, dict):
    return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
  else:
    return ops.convert_to_tensor(tensor_or_l_or_d)
def _validate_distributions(distributions_l, noise_l):
  if not isinstance(distributions_l, (tuple, list)):
    raise ValueError('`predicted_distributions` must be a list. Instead, found '
                     '%s.' % type(distributions_l))
  for dist in distributions_l:
    if not isinstance(dist, ds.Distribution):
      raise ValueError('Every element in `predicted_distributions` must be a '
                       '`tf.Distribution`. Instead, found %s.' % type(dist))
  if len(distributions_l) != len(noise_l):
    raise ValueError('Length of `predicted_distributions` %i must be the same '
                     'as the length of structured noise %i.' %
                     (len(distributions_l), len(noise_l)))
def _validate_acgan_discriminator_outputs(discriminator_output):
  try:
    a, b = discriminator_output
  except (TypeError, ValueError):
    raise TypeError(
        'A discriminator function for ACGAN must output a tuple '
        'consisting of (discrimination logits, classification logits).')
  return a, b
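The last helper encodes the ACGAN discriminator contract: the function must return a `(discrimination logits, classification logits)` tuple. Below is a hedged sketch of a discriminator satisfying that contract; the argument names, layer sizes, and `num_classes` are illustrative assumptions, not part of this commit.

```python
import tensorflow as tf

layers = tf.contrib.layers

def acgan_discriminator_fn(data, conditioning, num_classes=10):
  """Toy ACGAN discriminator that returns the required 2-tuple."""
  del conditioning  # Unused in this toy sketch.
  net = layers.flatten(data)
  net = layers.fully_connected(net, 64)
  # One real/fake logit per example.
  discrimination_logits = layers.fully_connected(net, 1, activation_fn=None)
  # One logit per class per example.
  classification_logits = layers.fully_connected(
      net, num_classes, activation_fn=None)
  return discrimination_logits, classification_logits
```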