mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[part 1] Add support for int32 & int64 in RandomPoissonOp.
This computes int32 and int64 Poisson samples using double-precision intermediate calculations (analogous to how `half` samples are computed with higher-precision intermediates). Part 2 will switch the Python calls over to the new op once the forward-compatibility period has passed. PiperOrigin-RevId: 171058336
This commit is contained in:
parent
70fc9bf9b6
commit
53cc63a2d9
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
|
@ -69,34 +70,42 @@ struct PoissonComputeType<Eigen::half> {
|
||||||
typedef float ComputeType;
|
typedef float ComputeType;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PoissonComputeType<int32> {
|
||||||
|
typedef double ComputeType;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PoissonComputeType<int64> {
|
||||||
|
typedef double ComputeType;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace functor {
|
namespace functor {
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T, typename U>
|
||||||
struct PoissonFunctor {
|
struct PoissonFunctor {
|
||||||
void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat,
|
void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat,
|
||||||
int num_rate, int num_samples,
|
int num_rate, int num_samples,
|
||||||
const random::PhiloxRandom& rng, T* samples_flat);
|
const random::PhiloxRandom& rng, U* samples_flat);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, typename U>
|
||||||
struct PoissonFunctor<CPUDevice, T> {
|
struct PoissonFunctor<CPUDevice, T, U> {
|
||||||
void operator()(OpKernelContext* ctx, const CPUDevice& d, const T* rate_flat,
|
void operator()(OpKernelContext* ctx, const CPUDevice& d, const T* rate_flat,
|
||||||
int num_rate, int num_samples,
|
int num_rate, int num_samples,
|
||||||
const random::PhiloxRandom& rng, T* samples_flat) {
|
const random::PhiloxRandom& rng, U* samples_flat) {
|
||||||
// Two different algorithms are employed, depending on the size of
|
// Two different algorithms are employed, depending on the size of
|
||||||
// rate.
|
// rate.
|
||||||
// If rate < 10, we use an algorithm attributed to Knuth:
|
// If rate < 10, we use an algorithm attributed to Knuth:
|
||||||
// Seminumerical Algorithms. Art of Computer Programming, Volume 2.
|
// Seminumerical Algorithms. Art of Computer Programming, Volume 2.
|
||||||
//
|
//
|
||||||
// This algorithm runs in O(rate) time, and will require O(rate)
|
// This algorithm runs in O(rate) time, and will require O(rate)
|
||||||
// uniform
|
// uniform variates.
|
||||||
// variates.
|
|
||||||
//
|
//
|
||||||
// If rate >= 10 we use a transformation-rejection algorithm from
|
// If rate >= 10 we use a transformation-rejection algorithm from
|
||||||
// pairs
|
// pairs of uniform random variables due to Hormann.
|
||||||
// of uniform random variables due to Hormann.
|
|
||||||
// http://www.sciencedirect.com/science/article/pii/0167668793909974
|
// http://www.sciencedirect.com/science/article/pii/0167668793909974
|
||||||
//
|
//
|
||||||
// The algorithm has an acceptance rate of ~89% for the smallest rate
|
// The algorithm has an acceptance rate of ~89% for the smallest rate
|
||||||
|
|
@ -154,8 +163,9 @@ struct PoissonFunctor<CPUDevice, T> {
|
||||||
while (true) {
|
while (true) {
|
||||||
UNIFORM(u);
|
UNIFORM(u);
|
||||||
prod = prod * u;
|
prod = prod * u;
|
||||||
if (prod <= exp_neg_rate) {
|
if (prod <= exp_neg_rate &&
|
||||||
samples_rate_output[sample_idx * num_rate] = T(x);
|
x <= CT(Eigen::NumTraits<U>::highest())) {
|
||||||
|
samples_rate_output[sample_idx * num_rate] = U(x);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
x += 1;
|
x += 1;
|
||||||
|
|
@ -216,13 +226,18 @@ struct PoissonFunctor<CPUDevice, T> {
|
||||||
CT k = Eigen::numext::floor((CT(2) * a / u_shifted + b) * u + rate +
|
CT k = Eigen::numext::floor((CT(2) * a / u_shifted + b) * u + rate +
|
||||||
CT(0.43));
|
CT(0.43));
|
||||||
|
|
||||||
|
if (k > CT(Eigen::NumTraits<U>::highest())) {
|
||||||
|
// retry in case of overflow.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// When alpha * f(G(U)) * G'(U) is close to 1, it is possible to
|
// When alpha * f(G(U)) * G'(U) is close to 1, it is possible to
|
||||||
// find a rectangle (-u_r, u_r) x (0, v_r) under the curve, such
|
// find a rectangle (-u_r, u_r) x (0, v_r) under the curve, such
|
||||||
// that if v <= v_r and |u| <= u_r, then we can accept.
|
// that if v <= v_r and |u| <= u_r, then we can accept.
|
||||||
// Here v_r = 0.9277 - 3.6224 / (b - 2) and u_r = 0.43.
|
// Here v_r = 0.9277 - 3.6224 / (b - 2) and u_r = 0.43.
|
||||||
if (u_shifted >= CT(0.07) &&
|
if (u_shifted >= CT(0.07) &&
|
||||||
v <= CT(0.9277) - CT(3.6224) / (b - CT(2))) {
|
v <= CT(0.9277) - CT(3.6224) / (b - CT(2))) {
|
||||||
samples_rate_output[sample_idx * num_rate] = T(k);
|
samples_rate_output[sample_idx * num_rate] = U(k);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -235,7 +250,7 @@ struct PoissonFunctor<CPUDevice, T> {
|
||||||
CT s = log(v * inv_alpha / (a / (u_shifted * u_shifted) + b));
|
CT s = log(v * inv_alpha / (a / (u_shifted * u_shifted) + b));
|
||||||
CT t = -rate + k * log_rate - Eigen::numext::lgamma(k + 1);
|
CT t = -rate + k * log_rate - Eigen::numext::lgamma(k + 1);
|
||||||
if (s <= t) {
|
if (s <= t) {
|
||||||
samples_rate_output[sample_idx * num_rate] = T(k);
|
samples_rate_output[sample_idx * num_rate] = U(k);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -280,7 +295,7 @@ struct PoissonFunctor<CPUDevice, T> {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Samples from one or more Poisson distributions.
|
// Samples from one or more Poisson distributions.
|
||||||
template <typename T>
|
template <typename T, typename U>
|
||||||
class RandomPoissonOp : public OpKernel {
|
class RandomPoissonOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit RandomPoissonOp(OpKernelConstruction* context) : OpKernel(context) {
|
explicit RandomPoissonOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||||
|
|
@ -303,13 +318,13 @@ class RandomPoissonOp : public OpKernel {
|
||||||
|
|
||||||
const auto rate_flat = rate_t.flat<T>().data();
|
const auto rate_flat = rate_t.flat<T>().data();
|
||||||
const int64 num_rate = rate_t.NumElements();
|
const int64 num_rate = rate_t.NumElements();
|
||||||
auto samples_flat = samples_t->flat<T>().data();
|
auto samples_flat = samples_t->flat<U>().data();
|
||||||
random::PhiloxRandom rng = generator_.ReserveRandomOutputs(
|
random::PhiloxRandom rng = generator_.ReserveRandomOutputs(
|
||||||
num_samples * num_rate, kReservedSamplesPerOutput);
|
num_samples * num_rate, kReservedSamplesPerOutput);
|
||||||
|
|
||||||
functor::PoissonFunctor<CPUDevice, T>()(ctx, ctx->eigen_device<CPUDevice>(),
|
functor::PoissonFunctor<CPUDevice, T, U>()(
|
||||||
rate_flat, num_rate, num_samples,
|
ctx, ctx->eigen_device<CPUDevice>(), rate_flat, num_rate, num_samples,
|
||||||
rng, samples_flat);
|
rng, samples_flat);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
@ -324,12 +339,34 @@ class RandomPoissonOp : public OpKernel {
|
||||||
#define REGISTER(TYPE) \
|
#define REGISTER(TYPE) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
Name("RandomPoisson").Device(DEVICE_CPU).TypeConstraint<TYPE>("dtype"), \
|
Name("RandomPoisson").Device(DEVICE_CPU).TypeConstraint<TYPE>("dtype"), \
|
||||||
RandomPoissonOp<TYPE>);
|
RandomPoissonOp<TYPE, TYPE>);
|
||||||
|
|
||||||
TF_CALL_half(REGISTER);
|
TF_CALL_half(REGISTER);
|
||||||
TF_CALL_float(REGISTER);
|
TF_CALL_float(REGISTER);
|
||||||
TF_CALL_double(REGISTER);
|
TF_CALL_double(REGISTER);
|
||||||
|
|
||||||
|
#define REGISTER_V2(RTYPE, OTYPE) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("RandomPoissonV2") \
|
||||||
|
.Device(DEVICE_CPU) \
|
||||||
|
.TypeConstraint<RTYPE>("R") \
|
||||||
|
.TypeConstraint<OTYPE>("dtype"), \
|
||||||
|
RandomPoissonOp<RTYPE, OTYPE>);
|
||||||
|
|
||||||
|
#define REGISTER_ALL(RTYPE) \
|
||||||
|
REGISTER_V2(RTYPE, Eigen::half); \
|
||||||
|
REGISTER_V2(RTYPE, float); \
|
||||||
|
REGISTER_V2(RTYPE, double); \
|
||||||
|
REGISTER_V2(RTYPE, int32); \
|
||||||
|
REGISTER_V2(RTYPE, int64);
|
||||||
|
|
||||||
|
REGISTER_ALL(Eigen::half);
|
||||||
|
REGISTER_ALL(float);
|
||||||
|
REGISTER_ALL(double);
|
||||||
|
REGISTER_ALL(int32);
|
||||||
|
REGISTER_ALL(int64);
|
||||||
|
|
||||||
|
#undef REGISTER_ALL
|
||||||
|
#undef REGISTER_V2
|
||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ namespace tensorflow {
|
||||||
namespace functor {
|
namespace functor {
|
||||||
|
|
||||||
// Generic helper functor for the Random Poisson Op.
|
// Generic helper functor for the Random Poisson Op.
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T /* rate */, typename U /* output */>
|
||||||
struct PoissonFunctor;
|
struct PoissonFunctor;
|
||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
|
|
||||||
|
|
@ -265,6 +265,8 @@ output: A tensor with shape `shape + shape(alpha)`. Each slice
|
||||||
`alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
|
`alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
// TODO(dhananjayn): Deprecate RandomPoisson and switch over to RandomPoissonV2
|
||||||
|
// after forward compatibility period has passed.
|
||||||
REGISTER_OP("RandomPoisson")
|
REGISTER_OP("RandomPoisson")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
.Input("shape: S")
|
.Input("shape: S")
|
||||||
|
|
@ -309,4 +311,48 @@ output: A tensor with shape `shape + shape(rate)`. Each slice
|
||||||
rate.
|
rate.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("RandomPoissonV2")
|
||||||
|
.SetIsStateful()
|
||||||
|
.Input("shape: S")
|
||||||
|
.Input("rate: R")
|
||||||
|
.Output("output: dtype")
|
||||||
|
.Attr("seed: int = 0")
|
||||||
|
.Attr("seed2: int = 0")
|
||||||
|
.Attr("S: {int32, int64}")
|
||||||
|
.Attr("R: {half, float, double, int32, int64} = DT_DOUBLE")
|
||||||
|
.Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeHandle out;
|
||||||
|
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
|
||||||
|
TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
|
||||||
|
c->set_output(0, out);
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
|
.Doc(R"doc(
|
||||||
|
Outputs random values from the Poisson distribution(s) described by rate.
|
||||||
|
|
||||||
|
This op uses two algorithms, depending on rate. If rate >= 10, then
|
||||||
|
the algorithm by Hormann is used to acquire samples via
|
||||||
|
transformation-rejection.
|
||||||
|
See http://www.sciencedirect.com/science/article/pii/0167668793909974.
|
||||||
|
|
||||||
|
Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
|
||||||
|
random variables.
|
||||||
|
See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
|
||||||
|
Programming, Volume 2. Addison Wesley
|
||||||
|
|
||||||
|
shape: 1-D integer tensor. Shape of independent samples to draw from each
|
||||||
|
distribution described by the shape parameters given in rate.
|
||||||
|
rate: A tensor in which each scalar is a "rate" parameter describing the
|
||||||
|
associated poisson distribution.
|
||||||
|
seed: If either `seed` or `seed2` are set to be non-zero, the random number
|
||||||
|
generator is seeded by the given seed. Otherwise, it is seeded by a
|
||||||
|
random seed.
|
||||||
|
seed2: A second seed to avoid seed collision.
|
||||||
|
|
||||||
|
output: A tensor with shape `shape + shape(rate)`. Each slice
|
||||||
|
`[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
|
||||||
|
`rate[i0, i1, ...iN]`.
|
||||||
|
)doc");
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,11 @@ from __future__ import print_function
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gen_random_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.platform import tf_logging
|
from tensorflow.python.platform import tf_logging
|
||||||
|
|
@ -179,6 +181,23 @@ class RandomPoissonTest(test.TestCase):
|
||||||
seed=12345)
|
seed=12345)
|
||||||
self.assertIs(None, rnd.get_shape().ndims)
|
self.assertIs(None, rnd.get_shape().ndims)
|
||||||
|
|
||||||
|
def testDTypeCombinationsV2(self):
|
||||||
|
"""Tests random_poisson_v2() for all supported dtype combinations."""
|
||||||
|
# All supported dtypes by random_poisson_v2().
|
||||||
|
supported_dtypes = [
|
||||||
|
dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
|
||||||
|
dtypes.int64
|
||||||
|
]
|
||||||
|
|
||||||
|
with self.test_session():
|
||||||
|
for lam_dt in supported_dtypes:
|
||||||
|
for out_dt in supported_dtypes:
|
||||||
|
# TODO(dhananjayn): Change this to use random_poisson() after
|
||||||
|
# switching it to RandomPoissonV2.
|
||||||
|
gen_random_ops.random_poisson_v2(
|
||||||
|
[10], constant_op.constant([1], dtype=lam_dt),
|
||||||
|
dtype=out_dt).eval()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user