mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[cuda] fix triu/tril int32 overflow for large matrices (#164705)
Fixes #136611.
Cast blockIdx.x to int64_t before the multiplication to prevent overflow when computing linear_idx for matrices with more than 2^31 elements.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164705
Approved by: https://github.com/eqy, https://github.com/ngimel
This commit is contained in:
parent
ba93d5636e
commit
c1eda348be
|
|
@ -44,7 +44,7 @@ __global__ void triu_tril_kernel(
|
|||
const int64_t k,
|
||||
const int64_t N_padded,
|
||||
const IndexType last_dim_padded) {
|
||||
int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * elements_per_thread;
|
||||
int64_t linear_idx = (((int64_t)blockIdx.x) * blockDim.x + threadIdx.x) * elements_per_thread;
|
||||
if (linear_idx >= N_padded) {
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9931,6 +9931,28 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
|||
C = torch.matmul(A, B)
|
||||
self.assertEqual(C, B.sum().expand(B.shape))
|
||||
|
||||
@onlyCUDA
@largeTensorTest("40GB")
def test_triu_tril_large_matrix_64bit(self, device):
    """
    Regression test for https://github.com/pytorch/pytorch/issues/136611.

    Applies an in-place ``triu_`` to a square matrix whose element count
    cannot be addressed with a 32-bit linear index, then spot-checks two rows.
    """
    # 100k x 100k float32 entries: 10**10 elements, so linear indices need 64 bits.
    n = 100000
    mask = torch.full((n, n), float('-inf'), dtype=torch.float32, device=device)
    mask.triu_(1)

    # Row 42950 begins at linear index 42950 * 100000, just past 2**32; this is
    # the row that was wrong before the 64-bit index fix. triu_(1) zeroes
    # columns 0..42950 of that row (42951 entries); the rest stay -inf.
    probe_row = 42950
    zero_count = (mask[probe_row] == 0.0).sum().item()
    self.assertEqual(zero_count, probe_row + 1)

    # Every column of the final row is on or below the main diagonal,
    # so the whole row must have been zeroed.
    self.assertTrue((mask[-1] == 0.0).all())
|
||||
|
||||
@dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
|
||||
def test_triu_tril_extreme_k_values(self, device, dtype):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user