[part 1] Add support for int32 & int64 in RandomPoissonOp.

This computes int32/int64-precision poisson samples with double precision intermediate calculations (same as it's done for `half`) respectively.

part 2 will switch over python calls to new op once forward compatibility period has passed.

PiperOrigin-RevId: 171058336
This commit is contained in:
Dhananjay Nakrani 2017-10-04 13:57:18 -07:00 committed by TensorFlower Gardener
parent 70fc9bf9b6
commit 53cc63a2d9
4 changed files with 122 additions and 20 deletions

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <limits>
#include <memory> #include <memory>
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
@ -69,34 +70,42 @@ struct PoissonComputeType<Eigen::half> {
typedef float ComputeType; typedef float ComputeType;
}; };
template <>
struct PoissonComputeType<int32> {
typedef double ComputeType;
};
template <>
struct PoissonComputeType<int64> {
typedef double ComputeType;
};
} // namespace } // namespace
namespace functor { namespace functor {
template <typename Device, typename T> template <typename Device, typename T, typename U>
struct PoissonFunctor { struct PoissonFunctor {
void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat, void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat,
int num_rate, int num_samples, int num_rate, int num_samples,
const random::PhiloxRandom& rng, T* samples_flat); const random::PhiloxRandom& rng, U* samples_flat);
}; };
template <typename T> template <typename T, typename U>
struct PoissonFunctor<CPUDevice, T> { struct PoissonFunctor<CPUDevice, T, U> {
void operator()(OpKernelContext* ctx, const CPUDevice& d, const T* rate_flat, void operator()(OpKernelContext* ctx, const CPUDevice& d, const T* rate_flat,
int num_rate, int num_samples, int num_rate, int num_samples,
const random::PhiloxRandom& rng, T* samples_flat) { const random::PhiloxRandom& rng, U* samples_flat) {
// Two different algorithms are employed, depending on the size of // Two different algorithms are employed, depending on the size of
// rate. // rate.
// If rate < 10, we use an algorithm attributed to Knuth: // If rate < 10, we use an algorithm attributed to Knuth:
// Seminumerical Algorithms. Art of Computer Programming, Volume 2. // Seminumerical Algorithms. Art of Computer Programming, Volume 2.
// //
// This algorithm runs in O(rate) time, and will require O(rate) // This algorithm runs in O(rate) time, and will require O(rate)
// uniform // uniform variates.
// variates.
// //
// If rate >= 10 we use a transformation-rejection algorithm from // If rate >= 10 we use a transformation-rejection algorithm from
// pairs // pairs of uniform random variables due to Hormann.
// of uniform random variables due to Hormann.
// http://www.sciencedirect.com/science/article/pii/0167668793909974 // http://www.sciencedirect.com/science/article/pii/0167668793909974
// //
// The algorithm has an acceptance rate of ~89% for the smallest rate // The algorithm has an acceptance rate of ~89% for the smallest rate
@ -154,8 +163,9 @@ struct PoissonFunctor<CPUDevice, T> {
while (true) { while (true) {
UNIFORM(u); UNIFORM(u);
prod = prod * u; prod = prod * u;
if (prod <= exp_neg_rate) { if (prod <= exp_neg_rate &&
samples_rate_output[sample_idx * num_rate] = T(x); x <= CT(Eigen::NumTraits<U>::highest())) {
samples_rate_output[sample_idx * num_rate] = U(x);
break; break;
} }
x += 1; x += 1;
@ -216,13 +226,18 @@ struct PoissonFunctor<CPUDevice, T> {
CT k = Eigen::numext::floor((CT(2) * a / u_shifted + b) * u + rate + CT k = Eigen::numext::floor((CT(2) * a / u_shifted + b) * u + rate +
CT(0.43)); CT(0.43));
if (k > CT(Eigen::NumTraits<U>::highest())) {
// retry in case of overflow.
continue;
}
// When alpha * f(G(U)) * G'(U) is close to 1, it is possible to // When alpha * f(G(U)) * G'(U) is close to 1, it is possible to
// find a rectangle (-u_r, u_r) x (0, v_r) under the curve, such // find a rectangle (-u_r, u_r) x (0, v_r) under the curve, such
// that if v <= v_r and |u| <= u_r, then we can accept. // that if v <= v_r and |u| <= u_r, then we can accept.
// Here v_r = 0.9227 - 3.6224 / (b - 2) and u_r = 0.43. // Here v_r = 0.9227 - 3.6224 / (b - 2) and u_r = 0.43.
if (u_shifted >= CT(0.07) && if (u_shifted >= CT(0.07) &&
v <= CT(0.9277) - CT(3.6224) / (b - CT(2))) { v <= CT(0.9277) - CT(3.6224) / (b - CT(2))) {
samples_rate_output[sample_idx * num_rate] = T(k); samples_rate_output[sample_idx * num_rate] = U(k);
break; break;
} }
@ -235,7 +250,7 @@ struct PoissonFunctor<CPUDevice, T> {
CT s = log(v * inv_alpha / (a / (u_shifted * u_shifted) + b)); CT s = log(v * inv_alpha / (a / (u_shifted * u_shifted) + b));
CT t = -rate + k * log_rate - Eigen::numext::lgamma(k + 1); CT t = -rate + k * log_rate - Eigen::numext::lgamma(k + 1);
if (s <= t) { if (s <= t) {
samples_rate_output[sample_idx * num_rate] = T(k); samples_rate_output[sample_idx * num_rate] = U(k);
break; break;
} }
} }
@ -280,7 +295,7 @@ struct PoissonFunctor<CPUDevice, T> {
namespace { namespace {
// Samples from one or more Poisson distributions. // Samples from one or more Poisson distributions.
template <typename T> template <typename T, typename U>
class RandomPoissonOp : public OpKernel { class RandomPoissonOp : public OpKernel {
public: public:
explicit RandomPoissonOp(OpKernelConstruction* context) : OpKernel(context) { explicit RandomPoissonOp(OpKernelConstruction* context) : OpKernel(context) {
@ -303,13 +318,13 @@ class RandomPoissonOp : public OpKernel {
const auto rate_flat = rate_t.flat<T>().data(); const auto rate_flat = rate_t.flat<T>().data();
const int64 num_rate = rate_t.NumElements(); const int64 num_rate = rate_t.NumElements();
auto samples_flat = samples_t->flat<T>().data(); auto samples_flat = samples_t->flat<U>().data();
random::PhiloxRandom rng = generator_.ReserveRandomOutputs( random::PhiloxRandom rng = generator_.ReserveRandomOutputs(
num_samples * num_rate, kReservedSamplesPerOutput); num_samples * num_rate, kReservedSamplesPerOutput);
functor::PoissonFunctor<CPUDevice, T>()(ctx, ctx->eigen_device<CPUDevice>(), functor::PoissonFunctor<CPUDevice, T, U>()(
rate_flat, num_rate, num_samples, ctx, ctx->eigen_device<CPUDevice>(), rate_flat, num_rate, num_samples,
rng, samples_flat); rng, samples_flat);
} }
private: private:
@ -324,12 +339,34 @@ class RandomPoissonOp : public OpKernel {
#define REGISTER(TYPE) \ #define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
Name("RandomPoisson").Device(DEVICE_CPU).TypeConstraint<TYPE>("dtype"), \ Name("RandomPoisson").Device(DEVICE_CPU).TypeConstraint<TYPE>("dtype"), \
RandomPoissonOp<TYPE>); RandomPoissonOp<TYPE, TYPE>);
TF_CALL_half(REGISTER); TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER); TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER); TF_CALL_double(REGISTER);
#define REGISTER_V2(RTYPE, OTYPE) \
REGISTER_KERNEL_BUILDER(Name("RandomPoissonV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<RTYPE>("R") \
.TypeConstraint<OTYPE>("dtype"), \
RandomPoissonOp<RTYPE, OTYPE>);
#define REGISTER_ALL(RTYPE) \
REGISTER_V2(RTYPE, Eigen::half); \
REGISTER_V2(RTYPE, float); \
REGISTER_V2(RTYPE, double); \
REGISTER_V2(RTYPE, int32); \
REGISTER_V2(RTYPE, int64);
REGISTER_ALL(Eigen::half);
REGISTER_ALL(float);
REGISTER_ALL(double);
REGISTER_ALL(int32);
REGISTER_ALL(int64);
#undef REGISTER_ALL
#undef REGISTER_V2
#undef REGISTER #undef REGISTER
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -21,7 +21,7 @@ namespace tensorflow {
namespace functor { namespace functor {
// Generic helper functor for the Random Poisson Op. // Generic helper functor for the Random Poisson Op.
template <typename Device, typename T> template <typename Device, typename T /* rate */, typename U /* output */>
struct PoissonFunctor; struct PoissonFunctor;
} // namespace functor } // namespace functor

View File

@ -265,6 +265,8 @@ output: A tensor with shape `shape + shape(alpha)`. Each slice
`alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha. `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
)doc"); )doc");
// TODO(dhananayn): Deprecate RandomPoisson and switch over to RandomPoissonV2
// after forward compatibility period has passed.
REGISTER_OP("RandomPoisson") REGISTER_OP("RandomPoisson")
.SetIsStateful() .SetIsStateful()
.Input("shape: S") .Input("shape: S")
@ -309,4 +311,48 @@ output: A tensor with shape `shape + shape(rate)`. Each slice
rate. rate.
)doc"); )doc");
REGISTER_OP("RandomPoissonV2")
.SetIsStateful()
.Input("shape: S")
.Input("rate: R")
.Output("output: dtype")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Attr("S: {int32, int64}")
.Attr("R: {half, float, double, int32, int64} = DT_DOUBLE")
.Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
c->set_output(0, out);
return Status::OK();
})
.Doc(R"doc(
Outputs random values from the Poisson distribution(s) described by rate.
This op uses two algorithms, depending on rate. If rate >= 10, then
the algorithm by Hormann is used to acquire samples via
transformation-rejection.
See http://www.sciencedirect.com/science/article/pii/0167668793909974.
Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
random variables.
See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
Programming, Volume 2. Addison Wesley
shape: 1-D integer tensor. Shape of independent samples to draw from each
distribution described by the shape parameters given in rate.
rate: A tensor in which each scalar is a "rate" parameter describing the
associated poisson distribution.
seed: If either `seed` or `seed2` are set to be non-zero, the random number
generator is seeded by the given seed. Otherwise, it is seeded by a
random seed.
seed2: A second seed to avoid seed collision.
output: A tensor with shape `shape + shape(rate)`. Each slice
`[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
`rate[i0, i1, ...iN]`.
)doc");
} // namespace tensorflow } // namespace tensorflow

View File

@ -20,9 +20,11 @@ from __future__ import print_function
import numpy as np import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging from tensorflow.python.platform import tf_logging
@ -179,6 +181,23 @@ class RandomPoissonTest(test.TestCase):
seed=12345) seed=12345)
self.assertIs(None, rnd.get_shape().ndims) self.assertIs(None, rnd.get_shape().ndims)
def testDTypeCombinationsV2(self):
"""Tests random_poisson_v2() for all supported dtype combinations."""
# All supported dtypes by random_poisson_v2().
supported_dtypes = [
dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
dtypes.int64
]
with self.test_session():
for lam_dt in supported_dtypes:
for out_dt in supported_dtypes:
# TODO(dhananjayn): Change this to use random_poisson() after
# switching it to RandomPoissonV2.
gen_random_ops.random_poisson_v2(
[10], constant_op.constant([1], dtype=lam_dt),
dtype=out_dt).eval()
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()