diff --git a/aten/src/ATen/core/DistributionsHelper.h b/aten/src/ATen/core/DistributionsHelper.h index acca669503d..bbf8c648fca 100644 --- a/aten/src/ATen/core/DistributionsHelper.h +++ b/aten/src/ATen/core/DistributionsHelper.h @@ -40,7 +40,15 @@ struct uniform_int_from_to_distribution { template C10_HOST_DEVICE inline T operator()(RNG generator) { +#ifdef FBCODE_CAFFE2 + if (( + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) && range_ >= 1ULL << 32) +#else if (range_ >= 1ULL << 28) // allow approx 5% skew in uniform int generation using % +#endif { return transformation::uniform_int_from_to(generator->random64(), range_, base_); } else { diff --git a/aten/src/ATen/native/cuda/DistributionTemplates.h b/aten/src/ATen/native/cuda/DistributionTemplates.h index cbb1f076c57..1fc9195ac53 100644 --- a/aten/src/ATen/native/cuda/DistributionTemplates.h +++ b/aten/src/ATen/native/cuda/DistributionTemplates.h @@ -279,6 +279,41 @@ namespace cuda { template void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) { +#ifdef FBCODE_CAFFE2 + AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] { + if (( + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) && range >= 1ULL << 32) + { + // define lambda to mod with range and add base + auto random_func = [range, base] __device__ (uint64_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = curand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + auto random_func = [range, base] __device__ (uint32_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 { + return curand4(state); + }, + random_func); + } + }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +#else AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] { if (range >= 1ULL << 28) // allow approx 5% skew in uniform int generation using % { @@ -308,6 +343,7 @@ void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t bas random_func); } }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +#endif } // This is the special kernel to handle single specific case: diff --git a/aten/src/ATen/test/rng_test.h b/aten/src/ATen/test/rng_test.h index a785163e8f8..49ca378b1be 100644 --- a/aten/src/ATen/test/rng_test.h +++ b/aten/src/ATen/test/rng_test.h @@ -137,9 +137,13 @@ void test_random_from_to(const at::Device& device) { range = static_cast(max_to) - static_cast(from) + 1; from_case_covered = true; } +#ifdef FBCODE_CAFFE2 + if (range < (1ULL << 32)) { +#else // this is leaking details of implementation into test // we are starting to use random64() at 2^28 to minimize skew due to % if (range < (1ULL << 28)) { +#endif exp = static_cast(static_cast((static_cast(val) % range + from))); } else { exp = static_cast(static_cast((val % range + from))); diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index cd05e4254c8..c94444fba57 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -3502,6 +3502,7 @@ class TestRandomTensorCreation(TestCase): self.assertTrue((res1 >= 0).all().item()) + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "For fb compatibility random not changed in fbcode") def test_randint_distribution(self, device): size = 1_000_000 n_max = int(0.75 * 2 ** 32)