pytorch/caffe2/core/distributions_stubs.h
Tristan Rice 0c9787c758 caffe2: use at::mt19937 instead of std::mt19937 (10x speedup) (#43987)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43987

This replaces the caffe2 CPU random number generator (std::mt19937) with at::mt19937, the engine currently used in pytorch. The ATen RNG is 10x faster than the std one and appears to be more robust, given known bugs in the std implementation (https://fburl.com/diffusion/uhro7lqb).

For large embedding tables (10GB+) we see UniformFillOp taking upwards of 10 minutes because we're bottlenecked on the single-threaded RNG. Swapping to at::mt19937 cuts that time to roughly 10% of the current runtime.
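
A minimal benchmark sketch of the engine swap (illustrative only: the include path ATen/core/MT19937RNGEngine.h for at::mt19937 is assumed, and the draw count, seed, and loop stand in for the real UniformFillOp code):

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <random>

#include <ATen/core/MT19937RNGEngine.h> // assumed header providing at::mt19937

// Draw n raw 32-bit values from a single engine on one thread; for huge
// fills this loop is where the time goes.
template <typename Engine>
double timeDraws(Engine& engine, std::size_t n, std::uint64_t* sink) {
  auto start = std::chrono::steady_clock::now();
  for (std::size_t i = 0; i < n; ++i) {
    *sink += engine(); // both engines expose operator()
  }
  return std::chrono::duration<double>(
             std::chrono::steady_clock::now() - start)
      .count();
}

int main() {
  constexpr std::size_t kDraws = std::size_t(1) << 26; // illustrative size
  std::uint64_t sink = 0;

  std::mt19937 stdEngine(42); // the engine caffe2 used before this change
  at::mt19937 atenEngine(42); // the ATen engine it switches to

  std::printf("std::mt19937: %.3fs\n", timeDraws(stdEngine, kDraws, &sink));
  std::printf("at::mt19937:  %.3fs\n", timeDraws(atenEngine, kDraws, &sink));
  std::printf("checksum: %llu\n", static_cast<unsigned long long>(sink));
  return 0;
}

On workloads like the ones described above we'd expect the at::mt19937 line to come out roughly 10x faster, which is what motivates the swap.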

Test Plan: Ran all relevant tests plus CI. This doesn't introduce new features (and is a core change), so the existing tests and CI should be sufficient to catch regressions.

Reviewed By: dzhulgakov

Differential Revision: D23219710

fbshipit-source-id: bd16ed6415b2933e047bcb283a013d47fb395814
2020-10-16 16:08:35 -07:00

#ifndef CAFFE2_CORE_DISTRIBUTIONS_STUBS_H_
#define CAFFE2_CORE_DISTRIBUTIONS_STUBS_H_

#include <cstdint>
#include <random>
#include <utility>

#include <c10/macros/Macros.h>

/**
 * This file provides distributions compatible with
 * ATen/core/DistributionsHelper.h but backed by the std RNG implementation
 * instead of the ATen one.
 *
 * Caffe2 mobile builds currently do not depend on all of ATen, so this is
 * required to allow normal builds to use the faster ATen RNG while keeping
 * the build size small on mobile. RNG performance typically doesn't matter
 * on mobile builds since the models are small and rarely use random
 * initialization.
 */
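
/*
 * Illustrative usage sketch (the engine type and seed are arbitrary; any
 * std-compatible URBG works): distribution_adapter::operator() below
 * dereferences the generator argument, so callers pass a pointer to the
 * engine, matching the calling convention of the ATen distributions:
 *
 *   std::mt19937 engine(1234);
 *   at::uniform_real_distribution<float> uniform(0.0f, 1.0f);
 *   float sample = uniform(&engine);
 */
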
namespace at {
namespace {

/**
 * Wraps a std::* distribution behind the interface this header's callers
 * expect (see ATen/core/DistributionsHelper.h): constructor arguments are
 * forwarded to the wrapped distribution, and operator() takes a
 * pointer(-like) generator which is dereferenced before sampling.
 */
template <typename R, typename T>
struct distribution_adapter {
  template <typename... Args>
  C10_HOST_DEVICE inline distribution_adapter(Args... args)
      : distribution_(std::forward<Args>(args)...) {}

  template <typename RNG>
  C10_HOST_DEVICE inline R operator()(RNG generator) {
    return distribution_(*generator);
  }

 private:
  T distribution_;
};

template <typename T>
struct uniform_int_from_to_distribution
    : distribution_adapter<T, std::uniform_int_distribution<T>> {
  C10_HOST_DEVICE inline uniform_int_from_to_distribution(
      uint64_t range,
      int64_t base)
      : distribution_adapter<T, std::uniform_int_distribution<T>>(
            base,
            // std's bounds are inclusive, at's upper bound is exclusive:
            // at's [base, base + range) becomes std's [base, base + range - 1]
            base + range - 1) {}
};

template <typename T>
using uniform_real_distribution =
    distribution_adapter<T, std::uniform_real_distribution<T>>;

template <typename T>
using normal_distribution =
    distribution_adapter<T, std::normal_distribution<T>>;

template <typename T>
using bernoulli_distribution =
    distribution_adapter<T, std::bernoulli_distribution>;

template <typename T>
using exponential_distribution =
    distribution_adapter<T, std::exponential_distribution<T>>;

template <typename T>
using cauchy_distribution =
    distribution_adapter<T, std::cauchy_distribution<T>>;

template <typename T>
using lognormal_distribution =
    distribution_adapter<T, std::lognormal_distribution<T>>;

} // namespace
} // namespace at
#endif // CAFFE2_CORE_DISTRIBUTIONS_STUBS_H_