mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Switch over python calls to RandomPoissonV2.
Part 2 of Support int32/64 in tf.random_poisson(). PiperOrigin-RevId: 174071745
This commit is contained in:
parent
b5d5326c62
commit
c911d0f169
|
|
@ -265,8 +265,6 @@ output: A tensor with shape `shape + shape(alpha)`. Each slice
|
||||||
`alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
|
`alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
// TODO(dhananayn): Deprecate RandomPoisson and switch over to RandomPoissonV2
|
|
||||||
// after forward compatibility period has passed.
|
|
||||||
REGISTER_OP("RandomPoisson")
|
REGISTER_OP("RandomPoisson")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.Input("shape: S")
|
.Input("shape: S")
|
||||||
|
|
@ -283,32 +281,9 @@ REGISTER_OP("RandomPoisson")
|
||||||
c->set_output(0, out);
|
c->set_output(0, out);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
})
|
||||||
|
.Deprecated(25, "Replaced by RandomPoissonV2")
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Outputs random values from the Poisson distribution(s) described by rate.
|
Use RandomPoissonV2 instead.
|
||||||
|
|
||||||
This op uses two algorithms, depending on rate. If rate >= 10, then
|
|
||||||
the algorithm by Hormann is used to acquire samples via
|
|
||||||
transformation-rejection.
|
|
||||||
See http://www.sciencedirect.com/science/article/pii/0167668793909974.
|
|
||||||
|
|
||||||
Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
|
|
||||||
random variables.
|
|
||||||
See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
|
|
||||||
Programming, Volume 2. Addison Wesley
|
|
||||||
|
|
||||||
shape: 1-D integer tensor. Shape of independent samples to draw from each
|
|
||||||
distribution described by the shape parameters given in rate.
|
|
||||||
rate: A tensor in which each scalar is a "rate" parameter describing the
|
|
||||||
associated poisson distribution.
|
|
||||||
seed: If either `seed` or `seed2` are set to be non-zero, the random number
|
|
||||||
generator is seeded by the given seed. Otherwise, it is seeded by a
|
|
||||||
random seed.
|
|
||||||
seed2: A second seed to avoid seed collision.
|
|
||||||
|
|
||||||
output: A tensor with shape `shape + shape(rate)`. Each slice
|
|
||||||
`[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
|
|
||||||
`rate[i0, i1, ...iN]`. The dtype of the output matches the dtype of
|
|
||||||
rate.
|
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
REGISTER_OP("RandomPoissonV2")
|
REGISTER_OP("RandomPoissonV2")
|
||||||
|
|
|
||||||
|
|
@ -90,6 +90,7 @@ limitations under the License.
|
||||||
// 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2.
|
// 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2.
|
||||||
// 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017)
|
// 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017)
|
||||||
// 25. Deprecate stack (v1) ops in favor of v2 (2017/6/15).
|
// 25. Deprecate stack (v1) ops in favor of v2 (2017/6/15).
|
||||||
|
// 25. Deprecate RandomPoisson (v1) ops in favor of v2 (2017/10/25).
|
||||||
|
|
||||||
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
|
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
|
||||||
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
|
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
|
||||||
|
|
|
||||||
|
|
@ -24,11 +24,14 @@ from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gen_random_ops
|
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.platform import tf_logging
|
from tensorflow.python.platform import tf_logging
|
||||||
|
|
||||||
|
# All supported dtypes for random_poisson().
|
||||||
|
_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64,
|
||||||
|
dtypes.int32, dtypes.int64)
|
||||||
|
|
||||||
|
|
||||||
class RandomPoissonTest(test.TestCase):
|
class RandomPoissonTest(test.TestCase):
|
||||||
"""This is a large test due to the moments computation taking some time."""
|
"""This is a large test due to the moments computation taking some time."""
|
||||||
|
|
@ -57,7 +60,7 @@ class RandomPoissonTest(test.TestCase):
|
||||||
# we want to tolerate. Since the z-test approximates a unit normal
|
# we want to tolerate. Since the z-test approximates a unit normal
|
||||||
# distribution, it should almost definitely never exceed 6.
|
# distribution, it should almost definitely never exceed 6.
|
||||||
z_limit = 6.0
|
z_limit = 6.0
|
||||||
for dt in dtypes.float16, dtypes.float32, dtypes.float64:
|
for dt in _SUPPORTED_DTYPES:
|
||||||
# Test when lam < 10 and when lam >= 10
|
# Test when lam < 10 and when lam >= 10
|
||||||
for stride in 0, 4, 10:
|
for stride in 0, 4, 10:
|
||||||
for lam in (3., 20):
|
for lam in (3., 20):
|
||||||
|
|
@ -102,7 +105,7 @@ class RandomPoissonTest(test.TestCase):
|
||||||
# Checks that the CPU and GPU implementation returns the same results,
|
# Checks that the CPU and GPU implementation returns the same results,
|
||||||
# given the same random seed
|
# given the same random seed
|
||||||
def testCPUGPUMatch(self):
|
def testCPUGPUMatch(self):
|
||||||
for dt in dtypes.float16, dtypes.float32, dtypes.float64:
|
for dt in _SUPPORTED_DTYPES:
|
||||||
results = {}
|
results = {}
|
||||||
for use_gpu in [False, True]:
|
for use_gpu in [False, True]:
|
||||||
sampler = self._Sampler(1000, 1.0, dt, use_gpu=use_gpu, seed=12345)
|
sampler = self._Sampler(1000, 1.0, dt, use_gpu=use_gpu, seed=12345)
|
||||||
|
|
@ -183,19 +186,11 @@ class RandomPoissonTest(test.TestCase):
|
||||||
|
|
||||||
def testDTypeCombinationsV2(self):
|
def testDTypeCombinationsV2(self):
|
||||||
"""Tests random_poisson_v2() for all supported dtype combinations."""
|
"""Tests random_poisson_v2() for all supported dtype combinations."""
|
||||||
# All supported dtypes by random_poisson_v2().
|
|
||||||
supported_dtypes = [
|
|
||||||
dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
|
|
||||||
dtypes.int64
|
|
||||||
]
|
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
for lam_dt in supported_dtypes:
|
for lam_dt in _SUPPORTED_DTYPES:
|
||||||
for out_dt in supported_dtypes:
|
for out_dt in _SUPPORTED_DTYPES:
|
||||||
# TODO(dhananjayn): Change this to use random_poisson() after
|
random_ops.random_poisson(
|
||||||
# switching it to RandomPoissonV2.
|
constant_op.constant([1], dtype=lam_dt), [10],
|
||||||
gen_random_ops.random_poisson_v2(
|
|
||||||
[10], constant_op.constant([1], dtype=lam_dt),
|
|
||||||
dtype=out_dt).eval()
|
dtype=out_dt).eval()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -438,8 +438,8 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
|
||||||
distribution(s) to sample.
|
distribution(s) to sample.
|
||||||
shape: A 1-D integer Tensor or Python array. The shape of the output samples
|
shape: A 1-D integer Tensor or Python array. The shape of the output samples
|
||||||
to be drawn per "rate"-parameterized distribution.
|
to be drawn per "rate"-parameterized distribution.
|
||||||
dtype: The type of `lam` and the output: `float16`, `float32`, or
|
dtype: The type of the output: `float16`, `float32`, `float64`, `int32` or
|
||||||
`float64`.
|
`int64`.
|
||||||
seed: A Python integer. Used to create a random seed for the distributions.
|
seed: A Python integer. Used to create a random seed for the distributions.
|
||||||
See
|
See
|
||||||
@{tf.set_random_seed}
|
@{tf.set_random_seed}
|
||||||
|
|
@ -451,7 +451,7 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
|
||||||
values of type `dtype`.
|
values of type `dtype`.
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, "random_poisson", [lam, shape]):
|
with ops.name_scope(name, "random_poisson", [lam, shape]):
|
||||||
lam = ops.convert_to_tensor(lam, name="lam", dtype=dtype)
|
|
||||||
shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32)
|
shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32)
|
||||||
seed1, seed2 = random_seed.get_seed(seed)
|
seed1, seed2 = random_seed.get_seed(seed)
|
||||||
return gen_random_ops._random_poisson(shape, lam, seed=seed1, seed2=seed2)
|
return gen_random_ops.random_poisson_v2(
|
||||||
|
shape, lam, dtype=dtype, seed=seed1, seed2=seed2)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user