remove check_is_size from test_misc.py (#164667)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164667
Approved by: https://github.com/angelayi
ghstack dependencies: #164664, #164665
commit ef7e2ca77e (parent cdaaf3e4a3)
Author: Laith Sakka
Date: 2025-10-06 16:44:37 -07:00
Committed by: PyTorch MergeBot
2 changed files with 7 additions and 15 deletions
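
In short, the commit replaces size-like hints with explicit value refinements: wherever a test called `torch._check_is_size(u)` on an unbacked SymInt, it now states the fact it actually needs via `torch._check(u >= 0)`, or drops the call entirely where no hint is required. A minimal standalone sketch of the pattern (the function `f` and the example input are illustrative, not taken from the diff):

import torch

# item() on a tensor traces to an unbacked SymInt under this config,
# mirroring the capture_scalar_outputs patches used in the tests below.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(x):
    u = x.item()
    # Before: torch._check_is_size(u)
    # After: the same non-negativity fact as an explicit check.
    torch._check(u >= 0)
    return torch.zeros(u)

f(torch.tensor(3))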

test/dynamo/test_misc.py

@@ -1253,7 +1253,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
     def test_bound_shape_checks(self):
         def f1(x, y):
             b = x.item()
-            torch._check_is_size(b)
+            torch._check(b >= 0)
             torch._check(b < y.shape[0])
             return y[:b]
@@ -1276,7 +1276,6 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         @torch.compile(fullgraph=True)
         def f(x):
             y = x.item()
-            torch._check_is_size(y)
             r = torch.arange(y, dtype=torch.float32)
             if r.size(0) == y:
@@ -1323,13 +1322,13 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     # Translation validation changes the exception type, don't run with it
     @torch.fx.experimental._config.patch(translation_validation=False)
-    def test_torch_check_is_size(self):
+    def test_torch_check_nonnegative(self):
         cnts = torch._dynamo.testing.CompileCounter()

         @torch.compile(backend=cnts, fullgraph=True)
         def f(x):
             y = x.item()
-            torch._check_is_size(y)
+            torch._check(y >= 0)
             # Cannot conditional on unbacked SymInt
             if y == 0:
                 assert False
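
The "Cannot conditional on unbacked SymInt" comment is the property this renamed test asserts: a plain Python branch on an unbacked value raises a data-dependent error instead of silently guarding. A condensed sketch of that failure mode (names and input are illustrative):

import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(x):
    y = x.item()
    torch._check(y >= 0)
    # y has no backing runtime value at trace time, so Dynamo cannot
    # decide this branch and raises a data-dependent guard error.
    if y == 0:
        return x + 1
    return x - 1

try:
    f(torch.tensor(5))
except Exception as e:
    print(type(e).__name__)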
@@ -8071,7 +8070,6 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         @torch.compile(fullgraph=True)
         def f(x):
             u0, u1 = x.tolist()
-            torch._check_is_size(u0)
             # The condition should fold to true.
             if ((u0 + 10) * (u0 + 10)) % (u0 + 10) == 0:
                 return torch.tensor(True)
@@ -9226,14 +9224,10 @@ def ___make_guard_fn():
     @torch._dynamo.config.patch(
         capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
     )
-    def test_unbacked_symint(self):
+    def test_unbacked_symint_split(self):
         @torch.compile(backend="eager")
         def f(lengths, values):
             sizes = lengths.tolist()
-            for s in sizes:
-                torch._check_is_size(s)
-                torch._check(s >= 2)
-                torch._check(s <= 100)
             return torch.split(values, sizes)

         f(torch.tensor([2, 3, 4]), torch.randn(9))
@@ -11319,15 +11313,12 @@ fn
         self.assertEqual(len(c2), 0)

     @torch._dynamo.config.patch(capture_scalar_outputs=True)
-    def test_guard_size_oblivious_simplification(self):
+    def test_check_simplification(self):
         @torch.compile(backend="eager", fullgraph=True)
         def fn(x):
             u0, u1 = x.tolist()
-            torch._check_is_size(u0)
-            torch._check_is_size(u1)
-            torch._check((2 * u0) % (u0 + u1) == 0)
             torch._check((2 * u0) // (u0 + u1) != 0)
-            if guard_size_oblivious((2 * u0) // (u0 + u1) == 0):
+            if (2 * u0) // (u0 + u1) == 0:
                 return torch.tensor(True)
             else:
                 return torch.tensor(False)
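
This last hunk is the motivating case: once `torch._check` records `(2 * u0) // (u0 + u1) != 0`, the compiler can statically evaluate the branch condition to False instead of raising the data-dependent error shown above, so `guard_size_oblivious` is no longer needed. A condensed standalone version of the renamed test (input `[20, 10]` chosen so the check holds at runtime):

import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    u0, u1 = x.tolist()
    torch._check((2 * u0) // (u0 + u1) != 0)
    # The recorded fact lets this condition simplify to a constant
    # False; no guard on the unbacked values is introduced.
    if (2 * u0) // (u0 + u1) == 0:
        return torch.tensor(True)
    return torch.tensor(False)

print(fn(torch.tensor([20, 10])))  # tensor(False)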

torch/__init__.py

@@ -1692,6 +1692,7 @@ def _check(cond, message=None):  # noqa: F811
     _check_with(RuntimeError, cond, message)  # pyrefly: ignore # bad-argument-type


+# TODO add deprecation annotation
 def _check_is_size(i, message=None, *, max=None):
     """Checks that a given integer is a valid size (i.e., is non-negative).
     You should use this over ``_check(i >= 0)`` because it can prevent
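
The only change outside the tests is the `# TODO add deprecation annotation` marker: `_check_is_size` stays for backward compatibility, with a deprecation notice still to come. One way such an annotation could look, as a hypothetical sketch only (PyTorch may choose a different mechanism; `typing_extensions.deprecated` is the PEP 702 helper):

from typing_extensions import deprecated

@deprecated("torch._check_is_size is deprecated; use torch._check(i >= 0)")
def _check_is_size(i, message=None, *, max=None):
    # Hypothetical: body would stay unchanged, only the annotation is new.
    ...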