mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[indexing] Prevent integer overflow from large step values in C++ (#161707)
Fixes https://github.com/pytorch/pytorch/issues/160868 hmmm, I found an existing fix PR after I've finished this one. For reference, the old PR was https://github.com/pytorch/pytorch/pull/147433/files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161707 Approved by: https://github.com/leslie-fang-intel, https://github.com/CaoE, https://github.com/mlazos
This commit is contained in:
parent
7eb92b076f
commit
c140bf217f
|
|
@ -3063,7 +3063,7 @@ Tensor slice(
|
|||
}
|
||||
auto storage_offset = self.storage_offset() + start_val * strides[dim];
|
||||
auto len = end_val - start_val;
|
||||
sizes[dim] = (len + step - 1) / step; // round-up
|
||||
sizes[dim] = (len / step) + (len % step != 0); // safely round-up
|
||||
strides[dim] *= step;
|
||||
|
||||
Tensor result;
|
||||
|
|
|
|||
|
|
@ -4309,6 +4309,31 @@ class CommonTemplate:
|
|||
self.assertEqual(torch.compile(fn)(x1, y), fn(x1, y))
|
||||
self.assertEqual(torch.compile(fn)(x2, y), fn(x2, y))
|
||||
|
||||
def test_slice_copy(self):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, start=449, step=(2**63 - 1)):
|
||||
super().__init__()
|
||||
self.start = start
|
||||
self.step = step
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
sliced = torch.slice_copy(
|
||||
x, dim=0, start=self.start, end=None, step=self.step
|
||||
)
|
||||
return torch.reciprocal(sliced)
|
||||
|
||||
with config.patch({"implicit_fallbacks": True}):
|
||||
# bad case
|
||||
self.common(
|
||||
Model(),
|
||||
(torch.randn(875),),
|
||||
)
|
||||
# normal case
|
||||
self.common(
|
||||
Model(step=10),
|
||||
(torch.randn(875),),
|
||||
)
|
||||
|
||||
def test_slice1(self):
|
||||
def fn(a):
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -759,7 +759,8 @@ def slice_forward(
|
|||
|
||||
storage_offset = self.storage_offset() + start_val * strides[dim]
|
||||
len = end_val - start_val
|
||||
sizes[dim] = (len + step - 1) // step
|
||||
# safely round-up for corresponding c++ impl
|
||||
sizes[dim] = (len // step) + (1 if len % step != 0 else 0)
|
||||
strides[dim] *= step
|
||||
|
||||
if self.is_quantized:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user