mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
fix randint distribution for large max (#143787)
Fixes #ISSUE_NUMBER Similar to #143682, for large maximum values we were sampling integers via % and it doesn't provide uniform distribution. Here we limit the max skew to approx 1% (random32 is used for max values `<= 2**32 / 128`) This comes with significant perf penalty, especially for cuda, but it's a pretty bad bug, so we'll have to figure out what can be done to improve it. `torch.compile` has always been producing correct results for this, and it's performance is also significantly better than current eager (eager is ~660 GB/s on H100, torch.compile 1200 GB/s), so we have to figure out why torch.compile is better. `__launch_bounds__` slightly regress perf, so perhaps we can figure out how to specify them better, but it's only 20-30 GB/s, so the big difference is still unexplained. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143787 Approved by: https://github.com/eqy
This commit is contained in:
parent
1598d45879
commit
8059d56ec3
|
|
@ -40,11 +40,7 @@ struct uniform_int_from_to_distribution {
|
|||
|
||||
template <typename RNG>
|
||||
C10_HOST_DEVICE inline T operator()(RNG generator) {
|
||||
if ((
|
||||
std::is_same_v<T, int64_t> ||
|
||||
std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, float> ||
|
||||
std::is_same_v<T, at::BFloat16>) && range_ >= 1ULL << 32)
|
||||
if (range_ >= 1ULL << 25) // allow approx 1% skew in uniform int generation using %
|
||||
{
|
||||
return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_);
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -280,11 +280,7 @@ namespace cuda {
|
|||
template<typename RNG>
|
||||
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
|
||||
AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
|
||||
if ((
|
||||
std::is_same_v<scalar_t, int64_t> ||
|
||||
std::is_same_v<scalar_t, double> ||
|
||||
std::is_same_v<scalar_t, float> ||
|
||||
std::is_same_v<scalar_t, at::BFloat16>) && range >= 1ULL << 32)
|
||||
if (range >= 1ULL << 25) // allow approx 1% skew in uniform int generation using %
|
||||
{
|
||||
// define lambda to mod with range and add base
|
||||
auto random_func = [range, base] __device__ (uint64_t rand) {
|
||||
|
|
|
|||
|
|
@ -137,7 +137,9 @@ void test_random_from_to(const at::Device& device) {
|
|||
range = static_cast<uint64_t>(max_to) - static_cast<uint64_t>(from) + 1;
|
||||
from_case_covered = true;
|
||||
}
|
||||
if (range < (1ULL << 32)) {
|
||||
// this is leaking details of implementation into test
|
||||
// we are starting to use random64() at 2^25 to minimize skew due to %
|
||||
if (range < (1ULL << 25)) {
|
||||
exp = static_cast<T>(static_cast<int64_t>((static_cast<uint32_t>(val) % range + from)));
|
||||
} else {
|
||||
exp = static_cast<T>(static_cast<int64_t>((val % range + from)));
|
||||
|
|
|
|||
|
|
@ -8551,6 +8551,26 @@ class CommonTemplate:
|
|||
self.assertGreater(c0.max(), 2**40)
|
||||
self.assertLess(c0.max(), 2**50)
|
||||
|
||||
def test_randint_distribution(self):
|
||||
@torch.compile(fullgraph=True)
|
||||
def fn(n_argsmax, size):
|
||||
return torch.randint(n_max, (size,), device=self.device)
|
||||
|
||||
def bin(index, max_size):
|
||||
return index // (max_size // n_bins)
|
||||
|
||||
size = 1_000_000
|
||||
n_max = int(0.75 * 2**32)
|
||||
n_bins = 8
|
||||
|
||||
res = fn(n_max, size)
|
||||
bins = bin(res, n_max).float().cpu()
|
||||
hist, _ = bins.histogram(8, range=(0, n_bins))
|
||||
expected_bin = res.shape[0] / 8
|
||||
expected_error = math.sqrt(expected_bin) / expected_bin * 3
|
||||
error = (hist - expected_bin).abs().max() / expected_bin
|
||||
self.assertTrue(error < expected_error)
|
||||
|
||||
@config.patch(fallback_random=True)
|
||||
def test_like_rands(self):
|
||||
def fn(x):
|
||||
|
|
|
|||
|
|
@ -231,6 +231,7 @@ test_failures = {
|
|||
"test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")),
|
||||
"test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")),
|
||||
"test_polar_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True),
|
||||
"test_randint_distribution_dynamic_shapes": TestFailure(("cuda",)),
|
||||
"test_randn_generator_dynamic_shapes": TestFailure(("cpu",)),
|
||||
"test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
|
||||
"test_single_elem_dynamic_shapes": TestFailure(("cpu",)),
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ test_failures = {
|
|||
("cpu", "cuda", "xpu")
|
||||
),
|
||||
"test_conv_inference_heuristics_dynamic_shapes": TestFailure(("cuda", "xpu")),
|
||||
"test_randint_distribution_dynamic_shapes": TestFailure(("cuda",)),
|
||||
}
|
||||
|
||||
if TEST_WITH_ROCM:
|
||||
|
|
|
|||
|
|
@ -3495,6 +3495,24 @@ class TestRandomTensorCreation(TestCase):
|
|||
self.assertTrue((res1 < 6).all().item())
|
||||
self.assertTrue((res1 >= 0).all().item())
|
||||
|
||||
|
||||
def test_randint_distribution(self, device):
|
||||
size = 1_000_000
|
||||
n_max = int(0.75 * 2 ** 32)
|
||||
n_bins = 8
|
||||
|
||||
def bin(index, max_size):
|
||||
return index // (max_size // n_bins)
|
||||
res = torch.randint(n_max, (size,), device=device)
|
||||
# histogram implemented for float only
|
||||
bins = bin(res, n_max).float().cpu()
|
||||
hist, _ = bins.histogram(8, range=(0, n_bins))
|
||||
expected_bin = res.shape[0] / 8
|
||||
expected_error = math.sqrt(expected_bin) / expected_bin * 3
|
||||
error = (hist - expected_bin).abs().max() / expected_bin
|
||||
self.assertTrue(error < expected_error)
|
||||
|
||||
|
||||
@dtypes(torch.half, torch.float, torch.bfloat16, torch.double,
|
||||
torch.complex32, torch.complex64, torch.complex128)
|
||||
def test_randn(self, device, dtype):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user