Mirror of https://github.com/zebrajr/tensorflow.git, synced 2025-12-06 00:19:58 +01:00
Adds `tf.nn.experimental.general_dropout`, which is similar to `tf.random.experimental.stateless_dropout` but accepts a custom sampler function. It is intended for use in Keras' RandomGenerator, to avoid the unnecessary seed generation and scrambling (i.e. a round trip from (key, counter) to seed and back) that `stateless_dropout` incurs.
PiperOrigin-RevId: 495937292
This commit is contained in:
parent 512e9b694e
commit f361f1b71c
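To make the motivation concrete, here is a small usage sketch (an illustration assembled from the doctests in this commit, not code taken from the commit itself): with a `tf.random.Generator`, `stateless_dropout` forces a round trip through `make_seeds`, while `general_dropout` can consume the generator's `uniform` method directly.

```python
import tensorflow as tf

g = tf.random.Generator.from_seed(7)
x = tf.constant([1.1, 2.2, 3.3, 4.4, 5.5])

# Round-trip path: derive a fresh seed from the generator's (key, counter)
# state, which stateless_dropout then scrambles back into a key/counter pair.
y1 = tf.nn.experimental.stateless_dropout(x, rate=0.5, seed=g.make_seeds(1)[:, 0])

# Direct path: hand the generator's own sampler to general_dropout and skip
# the seed generation and scrambling entirely.
y2 = tf.nn.experimental.general_dropout(x, rate=0.5, uniform_sampler=g.uniform)
```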
@@ -114,6 +114,10 @@
 *   `stream_executor`
     *   Top level `stream_executor` directory has been deleted, users should use
         equivalent headers and targets under `compiler/xla/stream_executor`.
+*   `tf.nn`
+    *   Added `tf.nn.experimental.general_dropout`, which is similar to
+        `tf.random.experimental.stateless_dropout` but accepts a custom sampler
+        function.
 
 
 # Thanks to our Contributors
@@ -5647,6 +5647,77 @@ def stateless_dropout(x, rate, seed, rng_alg=None, noise_shape=None, name=None):
                   default_name="stateless_dropout")
 
 
+@tf_export("nn.experimental.general_dropout")
+@dispatch.add_dispatch_support
+def general_dropout(x, rate, uniform_sampler, noise_shape=None, name=None):
+  """Computes dropout: randomly sets elements to zero to prevent overfitting.
+
+  Please see `tf.nn.experimental.stateless_dropout` for an overview
+  of dropout.
+
+  Unlike `tf.nn.experimental.stateless_dropout`, here you can supply a
+  custom sampler function `uniform_sampler` that (given a shape and a
+  dtype) generates a random, `Uniform[0, 1)`-distributed tensor (of
+  that shape and dtype). `uniform_sampler` can be
+  e.g. `tf.random.stateless_random_uniform` or
+  `tf.random.Generator.uniform`.
+
+  For example, if you are using `tf.random.Generator` to generate
+  random numbers, you can use this code to do dropouts:
+
+  >>> g = tf.random.Generator.from_seed(7)
+  >>> sampler = g.uniform
+  >>> x = tf.constant([1.1, 2.2, 3.3, 4.4, 5.5])
+  >>> rate = 0.5
+  >>> tf.nn.experimental.general_dropout(x, rate, sampler)
+  <tf.Tensor: shape=(5,), ..., numpy=array([ 0. , 4.4, 6.6, 8.8, 11. ], ...)>
+  >>> tf.nn.experimental.general_dropout(x, rate, sampler)
+  <tf.Tensor: shape=(5,), ..., numpy=array([2.2, 0. , 0. , 8.8, 0. ], ...)>
+
+  It has better performance than using
+  `tf.nn.experimental.stateless_dropout` and
+  `tf.random.Generator.make_seeds`:
+
+  >>> g = tf.random.Generator.from_seed(7)
+  >>> x = tf.constant([1.1, 2.2, 3.3, 4.4, 5.5])
+  >>> rate = 0.5
+  >>> tf.nn.experimental.stateless_dropout(x, rate, g.make_seeds(1)[:, 0])
+  <tf.Tensor: shape=(5,), ..., numpy=array([ 2.2, 4.4, 6.6, 0. , 11. ], ...)>
+  >>> tf.nn.experimental.stateless_dropout(x, rate, g.make_seeds(1)[:, 0])
+  <tf.Tensor: shape=(5,), ..., numpy=array([2.2, 0. , 6.6, 8.8, 0. ], ...)>
+
+  because generating and consuming seeds cost extra
+  computation. `tf.nn.experimental.general_dropout` lets you avoid
+  that cost.
+
+  Args:
+    x: A floating point tensor.
+    rate: A scalar `Tensor` with the same type as `x`. The probability
+      that each element is dropped. For example, setting rate=0.1 would drop
+      10% of input elements.
+    uniform_sampler: a callable of signature `(shape, dtype) ->
+      Tensor[shape, dtype]`, used to generate a tensor of uniformly-distributed
+      random numbers in the range `[0, 1)`, of the given shape and dtype.
+    noise_shape: A 1-D integer `Tensor`, representing the
+      shape for randomly generated keep/drop flags.
+    name: A name for this operation.
+
+  Returns:
+    A Tensor of the same shape and dtype as `x`.
+
+  Raises:
+    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating point
+      tensor. `rate=1` is disallowed, because the output would be all zeros,
+      which is likely not what was intended.
+  """
+  def dummy_rng_step():
+    pass
+  return _dropout(x=x, rate=rate, noise_shape=noise_shape,
+                  uniform_sampler=uniform_sampler,
+                  dummy_rng_step=dummy_rng_step, name=name,
+                  default_name="general_dropout")
+
+
 def _dropout(x, rate, noise_shape, uniform_sampler, dummy_rng_step, name,
              default_name):
   """Shared implementation of the various dropout functions.
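The docstring names `tf.random.stateless_random_uniform` as another admissible sampler but only demonstrates `Generator.uniform`. A minimal sketch of the stateless variant (the wrapper and the fixed `seed=[1, 2]` are illustrative assumptions, not code from this commit):

```python
import tensorflow as tf

# Wrap stateless_random_uniform so the callable matches the
# `(shape, dtype) -> Tensor` contract that general_dropout expects.
# The fixed seed is arbitrary; with no maxval given, the output is
# Uniform[0, 1)-distributed, as the contract requires.
sampler = lambda shape, dtype: tf.random.stateless_random_uniform(
    shape, seed=[1, 2], dtype=dtype)

x = tf.constant([1.1, 2.2, 3.3, 4.4, 5.5])
y = tf.nn.experimental.general_dropout(x, rate=0.5, uniform_sampler=sampler)
# Because the seed is fixed, the same mask is drawn on every call.
```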
@@ -5657,7 +5728,7 @@ def _dropout(x, rate, noise_shape, uniform_sampler, dummy_rng_step, name,
     noise_shape: same as the namesake in `dropout_v2`.
     uniform_sampler: a callable of signature `(shape, dtype) ->
       Tensor`, used to generate a tensor of uniformly-distributed
-      random numbers, of the given shape and dtype.
+      random numbers in the range `[0, 1)`, of the given shape and dtype.
     dummy_rng_step: a callable of signature `() -> None`, to make a
       dummy RNG call in the fast path. In the fast path where rate is
       0, we don't need to generate random numbers, but some samplers
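The `dummy_rng_step` description is cut off by the hunk boundary, but the idea it gestures at can be illustrated (an illustration under stated assumptions, not code from this commit): a stateful sampler advances its generator's state on every draw, so a `rate == 0` fast path that skips sampling would leave that state behind unless the dropout helper performs a compensating dummy step. `general_dropout` passes a no-op (`pass`), i.e. it deliberately advances nothing on the fast path.

```python
import tensorflow as tf

# Two generators with identical seeds start with identical states.
g_a = tf.random.Generator.from_seed(7)
g_b = tf.random.Generator.from_seed(7)

_ = g_a.uniform(shape=[5])  # a real draw advances g_a's counter
# g_a and g_b have now diverged, so their next draws will differ:
print(g_a.state.numpy())
print(g_b.state.numpy())

# If dropout skips sampling when rate == 0, a sampler like g_b.uniform is
# never invoked and g_b's state stays put. The dummy_rng_step hook is where
# a dropout variant can advance (or, as in general_dropout, not advance)
# its RNG on that fast path.
```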
@@ -20,6 +20,7 @@ import math
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -35,6 +36,7 @@ from tensorflow.python.ops import nn
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import stateful_random_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
@@ -316,21 +318,28 @@ DROPOUT_FNS = [
     ("stateful_v2", nn_ops.dropout_v2),
     ("stateless", functools.partial(nn_ops.stateless_dropout, seed=(1, 2))),
     ("stateless_philox", functools.partial(
-        nn_ops.stateless_dropout, seed=(1, 2), rng_alg="philox"))]
+        nn_ops.stateless_dropout, seed=(1, 2), rng_alg="philox")),
+    ("generator", functools.partial(  # pylint: disable=g-long-lambda
+        nn_ops.general_dropout, uniform_sampler=lambda shape, dtype: (  # pylint: disable=g-long-lambda
+            stateful_random_ops.Generator.from_seed(1).uniform(
+                shape=shape, dtype=dtype)))),
+]
 
 
 class DropoutTest(test_lib.TestCase, parameterized.TestCase):
 
   @parameterized.named_parameters(
-      ("_%s_%s_%s" % (case_name, use_noise_shape, keep_prob), dropout_fn,  # pylint: disable=g-complex-comprehension
-       use_noise_shape, keep_prob)
+      ("_%s_%s_%s" % (case_name, use_noise_shape, keep_prob), case_name,  # pylint: disable=g-complex-comprehension
+       dropout_fn, use_noise_shape, keep_prob)
       for keep_prob in [0.1, 0.5, 0.8]
       for use_noise_shape in ["no", "concrete", "partial"]
       for case_name, dropout_fn in DROPOUT_FNS)
-  def testDropout(self, dropout_fn, use_noise_shape, keep_prob):
+  def testDropout(self, case_name, dropout_fn, use_noise_shape, keep_prob):
     # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate
     # that it is producing approximately the right number of ones over a large
     # number of samples, based on the keep probability.
+    if "generator" in case_name and not context.executing_eagerly():
+      self.skipTest("tf.random.Generator can only be used in TF2.")
     if use_noise_shape == "no":
       x_dim = 70
       y_dim = 30
@@ -362,11 +371,13 @@ class DropoutTest(test_lib.TestCase, parameterized.TestCase):
     self.assertLess(rel_error, 0.15)
 
   @parameterized.named_parameters(
-      ("_%s_%s" % (case_name, keep_prob), dropout_fn, keep_prob)  # pylint: disable=g-complex-comprehension
+      ("_%s_%s" % (case_name, keep_prob), case_name, dropout_fn, keep_prob)  # pylint: disable=g-complex-comprehension
       for keep_prob in [0.1, 0.5, 0.8]
       for case_name, dropout_fn in DROPOUT_FNS)
-  def testShapedDropoutCorrelation(self, dropout_fn, keep_prob):
+  def testShapedDropoutCorrelation(self, case_name, dropout_fn, keep_prob):
     # Runs a shaped dropout and tests that the correlations are correct.
+    if "generator" in case_name and not context.executing_eagerly():
+      self.skipTest("tf.random.Generator can only be used in TF2.")
     x_dim = 40
     y_dim = 30
     num_iter = 10
@@ -386,7 +397,6 @@ class DropoutTest(test_lib.TestCase, parameterized.TestCase):
       for use_keep_prob in [False, True]
       for keep_prob in [0.1, 0.5, 0.8]
       for case_name, dropout_fn in DROPOUT_FNS)
-  @test_util.run_deprecated_v1
   def testDropoutPlaceholderRateAndKeepProb(self, case_name, dropout_fn,
                                             keep_prob, use_keep_prob):
     # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate
@@ -394,26 +404,26 @@ class DropoutTest(test_lib.TestCase, parameterized.TestCase):
     # number of samples, based on the keep probability.
     if use_keep_prob and case_name != "stateful_v1":
       self.skipTest("Only V1 `dropout` has the `keep_prob` argument.")
+    if "generator" in case_name and not context.executing_eagerly():
+      self.skipTest("tf.random.Generator can only be used in TF2.")
     x_dim = 70
     y_dim = 30
     num_iter = 10
     with self.cached_session():
-      t = constant_op.constant(
-          1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
-      keep_prob_placeholder = array_ops.placeholder(dtypes.float32)
+      t = constant_op.constant(
+          1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+      final_count = 0
+      for _ in range(0, num_iter):
       if use_keep_prob:
-        dropout = dropout_fn(t, keep_prob=keep_prob_placeholder)
+          dropout = dropout_fn(t, keep_prob=keep_prob)
       else:
-        dropout = dropout_fn(t, rate=1 - keep_prob_placeholder)
-      final_count = 0
+          dropout = dropout_fn(t, rate=1 - keep_prob)
       self.assertEqual([x_dim, y_dim], dropout.get_shape())
-      for _ in range(0, num_iter):
-        value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob})
-        final_count += np.count_nonzero(value)
-        # Verifies that there are only two values: 0 and 1/keep_prob.
-        sorted_value = np.unique(np.sort(value))
-        self.assertEqual(0, sorted_value[0])
-        self.assertAllClose(1 / keep_prob, sorted_value[1])
+        value = self.evaluate(dropout)
+        final_count += np.count_nonzero(value)
+        # Verifies that there are only two values: 0 and 1/keep_prob.
+        sorted_value = np.unique(np.sort(value))
+        self.assertEqual(0, sorted_value[0])
+        self.assertAllClose(1 / keep_prob, sorted_value[1])
     # Check that we are in the 15% error range
     expected_count = x_dim * y_dim * keep_prob * num_iter
     rel_error = math.fabs(final_count - expected_count) / expected_count
@@ -1,5 +1,9 @@
 path: "tensorflow.nn.experimental"
 tf_module {
+  member_method {
+    name: "general_dropout"
+    argspec: "args=[\'x\', \'rate\', \'uniform_sampler\', \'noise_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
   member_method {
     name: "stateless_dropout"
     argspec: "args=[\'x\', \'rate\', \'seed\', \'rng_alg\', \'noise_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
@@ -1,5 +1,9 @@
 path: "tensorflow.nn.experimental"
 tf_module {
+  member_method {
+    name: "general_dropout"
+    argspec: "args=[\'x\', \'rate\', \'uniform_sampler\', \'noise_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
   member_method {
     name: "stateless_dropout"
     argspec: "args=[\'x\', \'rate\', \'seed\', \'rng_alg\', \'noise_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "