[indexing] Prevent integer overflow from large step values in C++ (#161707)

Fixes https://github.com/pytorch/pytorch/issues/160868
Note: I found an existing PR with a similar fix only after I had finished this one. For reference, the older PR is
https://github.com/pytorch/pytorch/pull/147433/files.
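
For context: with a very large step such as 2**63 - 1, the old round-up (len + step - 1) / step overflows the signed 64-bit addition before the division, which is undefined behavior in C++. The replacement (len / step) + (len % step != 0) computes the same ceiling while every intermediate value stays at or below len. A minimal standalone sketch (not part of the diff), using the values from the regression test below where len = 875 - 449 = 426:

#include <cstdint>
#include <iostream>

int main() {
  // Values from the regression test: a length-875 tensor sliced with
  // start=449 and step=2**63 - 1, so len = 875 - 449 = 426.
  const int64_t len = 426;
  const int64_t step = INT64_MAX;  // 2**63 - 1

  // Old rounding: (len + step - 1) / step. Here len + step - 1 would
  // exceed INT64_MAX, so the addition overflows a signed 64-bit integer
  // (undefined behavior). Check the headroom instead of performing it:
  std::cout << "old formula overflows: "
            << (len > INT64_MAX - (step - 1)) << "\n";   // prints 1

  // New rounding: overflow-free ceiling division; both operands are <= len.
  const int64_t size = (len / step) + (len % step != 0);
  std::cout << "sizes[dim] = " << size << "\n";          // prints 1
  return 0;
}

For the normal case in the test (step=10), the new expression gives (426 / 10) + 1 = 43, the same result as the old formula.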

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161707
Approved by: https://github.com/leslie-fang-intel, https://github.com/CaoE, https://github.com/mlazos
thenumberouscode 2025-09-12 03:16:23 +00:00 committed by PyTorch MergeBot
parent 7eb92b076f
commit c140bf217f
3 changed files with 28 additions and 2 deletions

@@ -3063,7 +3063,7 @@ Tensor slice(
   }
   auto storage_offset = self.storage_offset() + start_val * strides[dim];
   auto len = end_val - start_val;
-  sizes[dim] = (len + step - 1) / step; // round-up
+  sizes[dim] = (len / step) + (len % step != 0); // safely round-up
   strides[dim] *= step;
   Tensor result;

@@ -4309,6 +4309,31 @@ class CommonTemplate:
         self.assertEqual(torch.compile(fn)(x1, y), fn(x1, y))
         self.assertEqual(torch.compile(fn)(x2, y), fn(x2, y))

+    def test_slice_copy(self):
+        class Model(nn.Module):
+            def __init__(self, start=449, step=(2**63 - 1)):
+                super().__init__()
+                self.start = start
+                self.step = step
+
+            def forward(self, x: torch.Tensor):
+                sliced = torch.slice_copy(
+                    x, dim=0, start=self.start, end=None, step=self.step
+                )
+                return torch.reciprocal(sliced)
+
+        with config.patch({"implicit_fallbacks": True}):
+            # bad case
+            self.common(
+                Model(),
+                (torch.randn(875),),
+            )
+            # normal case
+            self.common(
+                Model(step=10),
+                (torch.randn(875),),
+            )
+
     def test_slice1(self):
         def fn(a):
             return (

@@ -759,7 +759,8 @@ def slice_forward(
     storage_offset = self.storage_offset() + start_val * strides[dim]
     len = end_val - start_val
-    sizes[dim] = (len + step - 1) // step
+    # safely round-up for corresponding c++ impl
+    sizes[dim] = (len // step) + (1 if len % step != 0 else 0)
     strides[dim] *= step
     if self.is_quantized: