mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[part 1] Add support for int32 & int64 in RandomPoissonOp.
This computes int32/int64-precision poisson samples with double precision intermediate calculations (same as it's done for `half`) respectively. part 2 will switch over python calls to new op once forward compatibility period has passed. PiperOrigin-RevId: 171058336
This commit is contained in:
parent
70fc9bf9b6
commit
53cc63a2d9
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
|
@ -69,34 +70,42 @@ struct PoissonComputeType<Eigen::half> {
|
|||
typedef float ComputeType;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PoissonComputeType<int32> {
|
||||
typedef double ComputeType;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PoissonComputeType<int64> {
|
||||
typedef double ComputeType;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T, typename U>
|
||||
struct PoissonFunctor {
|
||||
void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat,
|
||||
int num_rate, int num_samples,
|
||||
const random::PhiloxRandom& rng, T* samples_flat);
|
||||
const random::PhiloxRandom& rng, U* samples_flat);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct PoissonFunctor<CPUDevice, T> {
|
||||
template <typename T, typename U>
|
||||
struct PoissonFunctor<CPUDevice, T, U> {
|
||||
void operator()(OpKernelContext* ctx, const CPUDevice& d, const T* rate_flat,
|
||||
int num_rate, int num_samples,
|
||||
const random::PhiloxRandom& rng, T* samples_flat) {
|
||||
const random::PhiloxRandom& rng, U* samples_flat) {
|
||||
// Two different algorithms are employed, depending on the size of
|
||||
// rate.
|
||||
// If rate < 10, we use an algorithm attributed to Knuth:
|
||||
// Seminumerical Algorithms. Art of Computer Programming, Volume 2.
|
||||
//
|
||||
// This algorithm runs in O(rate) time, and will require O(rate)
|
||||
// uniform
|
||||
// variates.
|
||||
// uniform variates.
|
||||
//
|
||||
// If rate >= 10 we use a transformation-rejection algorithm from
|
||||
// pairs
|
||||
// of uniform random variables due to Hormann.
|
||||
// pairs of uniform random variables due to Hormann.
|
||||
// http://www.sciencedirect.com/science/article/pii/0167668793909974
|
||||
//
|
||||
// The algorithm has an acceptance rate of ~89% for the smallest rate
|
||||
|
|
@ -154,8 +163,9 @@ struct PoissonFunctor<CPUDevice, T> {
|
|||
while (true) {
|
||||
UNIFORM(u);
|
||||
prod = prod * u;
|
||||
if (prod <= exp_neg_rate) {
|
||||
samples_rate_output[sample_idx * num_rate] = T(x);
|
||||
if (prod <= exp_neg_rate &&
|
||||
x <= CT(Eigen::NumTraits<U>::highest())) {
|
||||
samples_rate_output[sample_idx * num_rate] = U(x);
|
||||
break;
|
||||
}
|
||||
x += 1;
|
||||
|
|
@ -216,13 +226,18 @@ struct PoissonFunctor<CPUDevice, T> {
|
|||
CT k = Eigen::numext::floor((CT(2) * a / u_shifted + b) * u + rate +
|
||||
CT(0.43));
|
||||
|
||||
if (k > CT(Eigen::NumTraits<U>::highest())) {
|
||||
// retry in case of overflow.
|
||||
continue;
|
||||
}
|
||||
|
||||
// When alpha * f(G(U)) * G'(U) is close to 1, it is possible to
|
||||
// find a rectangle (-u_r, u_r) x (0, v_r) under the curve, such
|
||||
// that if v <= v_r and |u| <= u_r, then we can accept.
|
||||
// Here v_r = 0.9227 - 3.6224 / (b - 2) and u_r = 0.43.
|
||||
if (u_shifted >= CT(0.07) &&
|
||||
v <= CT(0.9277) - CT(3.6224) / (b - CT(2))) {
|
||||
samples_rate_output[sample_idx * num_rate] = T(k);
|
||||
samples_rate_output[sample_idx * num_rate] = U(k);
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -235,7 +250,7 @@ struct PoissonFunctor<CPUDevice, T> {
|
|||
CT s = log(v * inv_alpha / (a / (u_shifted * u_shifted) + b));
|
||||
CT t = -rate + k * log_rate - Eigen::numext::lgamma(k + 1);
|
||||
if (s <= t) {
|
||||
samples_rate_output[sample_idx * num_rate] = T(k);
|
||||
samples_rate_output[sample_idx * num_rate] = U(k);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -280,7 +295,7 @@ struct PoissonFunctor<CPUDevice, T> {
|
|||
namespace {
|
||||
|
||||
// Samples from one or more Poisson distributions.
|
||||
template <typename T>
|
||||
template <typename T, typename U>
|
||||
class RandomPoissonOp : public OpKernel {
|
||||
public:
|
||||
explicit RandomPoissonOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
|
|
@ -303,12 +318,12 @@ class RandomPoissonOp : public OpKernel {
|
|||
|
||||
const auto rate_flat = rate_t.flat<T>().data();
|
||||
const int64 num_rate = rate_t.NumElements();
|
||||
auto samples_flat = samples_t->flat<T>().data();
|
||||
auto samples_flat = samples_t->flat<U>().data();
|
||||
random::PhiloxRandom rng = generator_.ReserveRandomOutputs(
|
||||
num_samples * num_rate, kReservedSamplesPerOutput);
|
||||
|
||||
functor::PoissonFunctor<CPUDevice, T>()(ctx, ctx->eigen_device<CPUDevice>(),
|
||||
rate_flat, num_rate, num_samples,
|
||||
functor::PoissonFunctor<CPUDevice, T, U>()(
|
||||
ctx, ctx->eigen_device<CPUDevice>(), rate_flat, num_rate, num_samples,
|
||||
rng, samples_flat);
|
||||
}
|
||||
|
||||
|
|
@ -324,12 +339,34 @@ class RandomPoissonOp : public OpKernel {
|
|||
#define REGISTER(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("RandomPoisson").Device(DEVICE_CPU).TypeConstraint<TYPE>("dtype"), \
|
||||
RandomPoissonOp<TYPE>);
|
||||
RandomPoissonOp<TYPE, TYPE>);
|
||||
|
||||
TF_CALL_half(REGISTER);
|
||||
TF_CALL_float(REGISTER);
|
||||
TF_CALL_double(REGISTER);
|
||||
|
||||
#define REGISTER_V2(RTYPE, OTYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("RandomPoissonV2") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<RTYPE>("R") \
|
||||
.TypeConstraint<OTYPE>("dtype"), \
|
||||
RandomPoissonOp<RTYPE, OTYPE>);
|
||||
|
||||
#define REGISTER_ALL(RTYPE) \
|
||||
REGISTER_V2(RTYPE, Eigen::half); \
|
||||
REGISTER_V2(RTYPE, float); \
|
||||
REGISTER_V2(RTYPE, double); \
|
||||
REGISTER_V2(RTYPE, int32); \
|
||||
REGISTER_V2(RTYPE, int64);
|
||||
|
||||
REGISTER_ALL(Eigen::half);
|
||||
REGISTER_ALL(float);
|
||||
REGISTER_ALL(double);
|
||||
REGISTER_ALL(int32);
|
||||
REGISTER_ALL(int64);
|
||||
|
||||
#undef REGISTER_ALL
|
||||
#undef REGISTER_V2
|
||||
#undef REGISTER
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ namespace tensorflow {
|
|||
namespace functor {
|
||||
|
||||
// Generic helper functor for the Random Poisson Op.
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T /* rate */, typename U /* output */>
|
||||
struct PoissonFunctor;
|
||||
|
||||
} // namespace functor
|
||||
|
|
|
|||
|
|
@ -265,6 +265,8 @@ output: A tensor with shape `shape + shape(alpha)`. Each slice
|
|||
`alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
|
||||
)doc");
|
||||
|
||||
// TODO(dhananayn): Deprecate RandomPoisson and switch over to RandomPoissonV2
|
||||
// after forward compatibility period has passed.
|
||||
REGISTER_OP("RandomPoisson")
|
||||
.SetIsStateful()
|
||||
.Input("shape: S")
|
||||
|
|
@ -309,4 +311,48 @@ output: A tensor with shape `shape + shape(rate)`. Each slice
|
|||
rate.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("RandomPoissonV2")
|
||||
.SetIsStateful()
|
||||
.Input("shape: S")
|
||||
.Input("rate: R")
|
||||
.Output("output: dtype")
|
||||
.Attr("seed: int = 0")
|
||||
.Attr("seed2: int = 0")
|
||||
.Attr("S: {int32, int64}")
|
||||
.Attr("R: {half, float, double, int32, int64} = DT_DOUBLE")
|
||||
.Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
|
||||
TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
|
||||
c->set_output(0, out);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Outputs random values from the Poisson distribution(s) described by rate.
|
||||
|
||||
This op uses two algorithms, depending on rate. If rate >= 10, then
|
||||
the algorithm by Hormann is used to acquire samples via
|
||||
transformation-rejection.
|
||||
See http://www.sciencedirect.com/science/article/pii/0167668793909974.
|
||||
|
||||
Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
|
||||
random variables.
|
||||
See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
|
||||
Programming, Volume 2. Addison Wesley
|
||||
|
||||
shape: 1-D integer tensor. Shape of independent samples to draw from each
|
||||
distribution described by the shape parameters given in rate.
|
||||
rate: A tensor in which each scalar is a "rate" parameter describing the
|
||||
associated poisson distribution.
|
||||
seed: If either `seed` or `seed2` are set to be non-zero, the random number
|
||||
generator is seeded by the given seed. Otherwise, it is seeded by a
|
||||
random seed.
|
||||
seed2: A second seed to avoid seed collision.
|
||||
|
||||
output: A tensor with shape `shape + shape(rate)`. Each slice
|
||||
`[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
|
||||
`rate[i0, i1, ...iN]`.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -20,9 +20,11 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_random_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
|
|
@ -179,6 +181,23 @@ class RandomPoissonTest(test.TestCase):
|
|||
seed=12345)
|
||||
self.assertIs(None, rnd.get_shape().ndims)
|
||||
|
||||
def testDTypeCombinationsV2(self):
|
||||
"""Tests random_poisson_v2() for all supported dtype combinations."""
|
||||
# All supported dtypes by random_poisson_v2().
|
||||
supported_dtypes = [
|
||||
dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
|
||||
dtypes.int64
|
||||
]
|
||||
|
||||
with self.test_session():
|
||||
for lam_dt in supported_dtypes:
|
||||
for out_dt in supported_dtypes:
|
||||
# TODO(dhananjayn): Change this to use random_poisson() after
|
||||
# switching it to RandomPoissonV2.
|
||||
gen_random_ops.random_poisson_v2(
|
||||
[10], constant_op.constant([1], dtype=lam_dt),
|
||||
dtype=out_dt).eval()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user