Summary: Fixes issue https://github.com/pytorch/pytorch/issues/25747 by upcasting to int64 before multiplication. Should be good enough for all reasonable nbins.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/25748
Differential Revision: D17269111
Pulled By: ezyang
fbshipit-source-id: 484be39080571203264a1bb9898ecf23d1aeafab
commit ec8e75ea92
parent a7eaec6cf2
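The failure mode is an ordinary 32-bit overflow: with nbins held in an int, the intermediate product in the bin computation can exceed INT_MAX and wrap to a negative offset. Below is a minimal standalone sketch of that arithmetic, illustrative only and not part of the patch; int32_wrap is a hypothetical helper that emulates two's-complement wraparound.

def int32_wrap(x):
    # Emulate how a signed 32-bit integer wraps around on overflow.
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x

value, nbins = 35488, 65536
print(int32_wrap(value * nbins))  # -1969225728: the product wraps to a negative bin offset in int32
print(value * nbins)              #  2325741568: carried out in 64-bit arithmetic the offset stays positive

Upcasting nbins to int64 makes the multiplication in getBin happen in 64-bit arithmetic, so the offset stays positive for any reasonable nbins, as the commit message notes.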
aten/src/ATen/native/cuda/SummaryOps.cu

@@ -17,7 +17,7 @@ namespace cuda {
 enum class CUDAHistogramMemoryType { SHARED, MULTI_BLOCK, GLOBAL };
 namespace {
   template<typename input_t, typename IndexType>
-  __device__ static IndexType getBin(input_t bVal, input_t minvalue, input_t maxvalue, int nbins) {
+  __device__ static IndexType getBin(input_t bVal, input_t minvalue, input_t maxvalue, int64_t nbins) {
     IndexType bin = (int)((bVal - minvalue) * nbins / (maxvalue - minvalue));
     // (only applicable for histc)
     // while each bin is inclusive at the lower end and exclusive at the higher, i.e. [start, end)
@@ -47,7 +47,7 @@ __global__ void kernelHistogram1D(
     detail::TensorInfo<output_t, IndexType> a, /* output */
     detail::TensorInfo<output_t, IndexType> p, /* partial output */
     detail::TensorInfo<input_t, IndexType> b, /* input */
-    int nbins,
+    int64_t nbins,
     input_t minvalue,
     input_t maxvalue,
     IndexType totalElements,
test/test_cuda.py

@@ -2847,6 +2847,13 @@ class TestCuda(TestCase):
         self.assertEqual(t.cpu().bincount(), t.bincount())
         self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

+        t = torch.zeros([10], dtype=torch.int32, device='cuda')
+        # 35488 * 65536 as int32 would cause overflow to negative value
+        # giving negative bin offset
+        t[0] = 35488
+        counted = t.bincount(minlength=65536)
+        self.assertEqual(torch.sum(counted), 10)
+
     def test_tiny_half_norm_(self):
         a = torch.arange(25).cuda().float()
         a /= 100000000
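For reference, the new test case reduces to the standalone repro below. This is a sketch, not part of the patch: it assumes a CUDA device is available and a PyTorch build that includes this fix, and the per-bin asserts are added here purely for illustration.

import torch

t = torch.zeros([10], dtype=torch.int32, device='cuda')
t[0] = 35488                           # 35488 * 65536 overflows int32, so this bin used to get a negative offset
counted = t.bincount(minlength=65536)

assert counted.sum().item() == 10      # all ten elements are counted
assert counted[0].item() == 9 and counted[35488].item() == 1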