mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Improve TFGAN documentation.
PiperOrigin-RevId: 170940188
This commit is contained in:
parent
0068086b9a
commit
4cf61262ae
|
|
@ -14,10 +14,41 @@
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""TFGAN utilities for loss functions that accept GANModel namedtuples.
|
"""TFGAN utilities for loss functions that accept GANModel namedtuples.
|
||||||
|
|
||||||
Example:
|
The losses and penalties in this file all correspond to losses in
|
||||||
|
`losses_impl.py`. Losses in that file take individual arguments, whereas in this
|
||||||
|
file they take a `GANModel` tuple. For example:
|
||||||
|
|
||||||
|
losses_impl.py:
|
||||||
```python
|
```python
|
||||||
# `tfgan.losses.args` losses take individual arguments.
|
def wasserstein_discriminator_loss(
|
||||||
w_loss = tfgan.losses.args.wasserstein_discriminator_loss(
|
discriminator_real_outputs,
|
||||||
|
discriminator_gen_outputs,
|
||||||
|
real_weights=1.0,
|
||||||
|
generated_weights=1.0,
|
||||||
|
scope=None,
|
||||||
|
loss_collection=ops.GraphKeys.LOSSES,
|
||||||
|
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
|
||||||
|
add_summaries=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
tuple_losses_impl.py:
|
||||||
|
```python
|
||||||
|
def wasserstein_discriminator_loss(
|
||||||
|
gan_model,
|
||||||
|
real_weights=1.0,
|
||||||
|
generated_weights=1.0,
|
||||||
|
scope=None,
|
||||||
|
loss_collection=ops.GraphKeys.LOSSES,
|
||||||
|
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
|
||||||
|
add_summaries=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```python
|
||||||
|
# `tfgan.losses.wargs` losses take individual arguments.
|
||||||
|
w_loss = tfgan.losses.wargs.wasserstein_discriminator_loss(
|
||||||
discriminator_real_outputs,
|
discriminator_real_outputs,
|
||||||
discriminator_gen_outputs)
|
discriminator_gen_outputs)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,12 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Named tuples for TFGAN."""
|
"""Named tuples for TFGAN.
|
||||||
|
|
||||||
|
TFGAN training occurs in four steps, and each step communicates with the next
|
||||||
|
step via one of these named tuples. At each step, you can either use a TFGAN
|
||||||
|
helper function in `train.py`, or you can manually construct a tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,17 @@
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""The TFGAN project provides a lightweight GAN training/testing framework.
|
"""The TFGAN project provides a lightweight GAN training/testing framework.
|
||||||
|
|
||||||
See examples in `tensorflow_models` for details on how to use.
|
This file contains the core helper functions to create and train a GAN model.
|
||||||
|
See the README or examples in `tensorflow_models` for details on how to use.
|
||||||
|
|
||||||
|
TFGAN training occurs in four steps:
|
||||||
|
1) Create a model
|
||||||
|
2) Add a loss
|
||||||
|
3) Create train ops
|
||||||
|
4) Run the train ops
|
||||||
|
|
||||||
|
The functions in this file are organized around these four steps. Each function
|
||||||
|
corresponds to one of the steps.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
@ -51,16 +61,6 @@ __all__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
|
|
||||||
"""Convert input, list of inputs, or dictionary of inputs to Tensors."""
|
|
||||||
if isinstance(tensor_or_l_or_d, (list, tuple)):
|
|
||||||
return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
|
|
||||||
elif isinstance(tensor_or_l_or_d, dict):
|
|
||||||
return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
|
|
||||||
else:
|
|
||||||
return ops.convert_to_tensor(tensor_or_l_or_d)
|
|
||||||
|
|
||||||
|
|
||||||
def gan_model(
|
def gan_model(
|
||||||
# Lambdas defining models.
|
# Lambdas defining models.
|
||||||
generator_fn,
|
generator_fn,
|
||||||
|
|
@ -133,20 +133,6 @@ def gan_model(
|
||||||
discriminator_fn)
|
discriminator_fn)
|
||||||
|
|
||||||
|
|
||||||
def _validate_distributions(distributions_l, noise_l):
|
|
||||||
if not isinstance(distributions_l, (tuple, list)):
|
|
||||||
raise ValueError('`predicted_distributions` must be a list. Instead, found '
|
|
||||||
'%s.' % type(distributions_l))
|
|
||||||
for dist in distributions_l:
|
|
||||||
if not isinstance(dist, ds.Distribution):
|
|
||||||
raise ValueError('Every element in `predicted_distributions` must be a '
|
|
||||||
'`tf.Distribution`. Instead, found %s.' % type(dist))
|
|
||||||
if len(distributions_l) != len(noise_l):
|
|
||||||
raise ValueError('Length of `predicted_distributions` %i must be the same '
|
|
||||||
'as the length of structured noise %i.' %
|
|
||||||
(len(distributions_l), len(noise_l)))
|
|
||||||
|
|
||||||
|
|
||||||
def infogan_model(
|
def infogan_model(
|
||||||
# Lambdas defining models.
|
# Lambdas defining models.
|
||||||
generator_fn,
|
generator_fn,
|
||||||
|
|
@ -231,16 +217,6 @@ def infogan_model(
|
||||||
predicted_distributions)
|
predicted_distributions)
|
||||||
|
|
||||||
|
|
||||||
def _validate_acgan_discriminator_outputs(discriminator_output):
|
|
||||||
try:
|
|
||||||
a, b = discriminator_output
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
raise TypeError(
|
|
||||||
'A discriminator function for ACGAN must output a tuple '
|
|
||||||
'consisting of (discrimination logits, classification logits).')
|
|
||||||
return a, b
|
|
||||||
|
|
||||||
|
|
||||||
def acgan_model(
|
def acgan_model(
|
||||||
# Lambdas defining models.
|
# Lambdas defining models.
|
||||||
generator_fn,
|
generator_fn,
|
||||||
|
|
@ -252,6 +228,7 @@ def acgan_model(
|
||||||
# Optional scopes.
|
# Optional scopes.
|
||||||
generator_scope='Generator',
|
generator_scope='Generator',
|
||||||
discriminator_scope='Discriminator',
|
discriminator_scope='Discriminator',
|
||||||
|
# Options.
|
||||||
check_shapes=True):
|
check_shapes=True):
|
||||||
"""Returns an ACGANModel contains all the pieces needed for ACGAN training.
|
"""Returns an ACGANModel contains all the pieces needed for ACGAN training.
|
||||||
|
|
||||||
|
|
@ -497,11 +474,10 @@ def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
|
||||||
|
|
||||||
|
|
||||||
def gan_train_ops(
|
def gan_train_ops(
|
||||||
model, # GANModel
|
model,
|
||||||
loss, # GANLoss
|
loss,
|
||||||
generator_optimizer,
|
generator_optimizer,
|
||||||
discriminator_optimizer,
|
discriminator_optimizer,
|
||||||
# Optional check flags.
|
|
||||||
check_for_unused_update_ops=True,
|
check_for_unused_update_ops=True,
|
||||||
# Optional args to pass directly to the `create_train_op`.
|
# Optional args to pass directly to the `create_train_op`.
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
@ -801,3 +777,40 @@ def get_sequential_train_steps(
|
||||||
return gen_loss + dis_loss, should_stop
|
return gen_loss + dis_loss, should_stop
|
||||||
|
|
||||||
return sequential_train_steps
|
return sequential_train_steps
|
||||||
|
|
||||||
|
|
||||||
|
# Helpers
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
|
||||||
|
"""Convert input, list of inputs, or dictionary of inputs to Tensors."""
|
||||||
|
if isinstance(tensor_or_l_or_d, (list, tuple)):
|
||||||
|
return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
|
||||||
|
elif isinstance(tensor_or_l_or_d, dict):
|
||||||
|
return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
|
||||||
|
else:
|
||||||
|
return ops.convert_to_tensor(tensor_or_l_or_d)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_distributions(distributions_l, noise_l):
|
||||||
|
if not isinstance(distributions_l, (tuple, list)):
|
||||||
|
raise ValueError('`predicted_distributions` must be a list. Instead, found '
|
||||||
|
'%s.' % type(distributions_l))
|
||||||
|
for dist in distributions_l:
|
||||||
|
if not isinstance(dist, ds.Distribution):
|
||||||
|
raise ValueError('Every element in `predicted_distributions` must be a '
|
||||||
|
'`tf.Distribution`. Instead, found %s.' % type(dist))
|
||||||
|
if len(distributions_l) != len(noise_l):
|
||||||
|
raise ValueError('Length of `predicted_distributions` %i must be the same '
|
||||||
|
'as the length of structured noise %i.' %
|
||||||
|
(len(distributions_l), len(noise_l)))
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_acgan_discriminator_outputs(discriminator_output):
|
||||||
|
try:
|
||||||
|
a, b = discriminator_output
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
raise TypeError(
|
||||||
|
'A discriminator function for ACGAN must output a tuple '
|
||||||
|
'consisting of (discrimination logits, classification logits).')
|
||||||
|
return a, b
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user