mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
remove check_is_size from test_misc.py (#164667)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164667 Approved by: https://github.com/angelayi ghstack dependencies: #164664, #164665
This commit is contained in:
parent
cdaaf3e4a3
commit
ef7e2ca77e
|
|
@ -1253,7 +1253,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||
def test_bound_shape_checks(self):
|
||||
def f1(x, y):
|
||||
b = x.item()
|
||||
torch._check_is_size(b)
|
||||
torch._check(b >= 0)
|
||||
torch._check(b < y.shape[0])
|
||||
return y[:b]
|
||||
|
||||
|
|
@ -1276,7 +1276,6 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||
@torch.compile(fullgraph=True)
|
||||
def f(x):
|
||||
y = x.item()
|
||||
torch._check_is_size(y)
|
||||
r = torch.arange(y, dtype=torch.float32)
|
||||
|
||||
if r.size(0) == y:
|
||||
|
|
@ -1323,13 +1322,13 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
# Translation validation changes the exception type, don't run with it
|
||||
@torch.fx.experimental._config.patch(translation_validation=False)
|
||||
def test_torch_check_is_size(self):
|
||||
def test_torch_check_nonnegative(self):
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f(x):
|
||||
y = x.item()
|
||||
torch._check_is_size(y)
|
||||
torch._check(y >= 0)
|
||||
# Cannot conditional on unbacked SymInt
|
||||
if y == 0:
|
||||
assert False
|
||||
|
|
@ -8071,7 +8070,6 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||
@torch.compile(fullgraph=True)
|
||||
def f(x):
|
||||
u0, u1 = x.tolist()
|
||||
torch._check_is_size(u0)
|
||||
# The condition should fold to true.
|
||||
if ((u0 + 10) * (u0 + 10)) % (u0 + 10) == 0:
|
||||
return torch.tensor(True)
|
||||
|
|
@ -9226,14 +9224,10 @@ def ___make_guard_fn():
|
|||
@torch._dynamo.config.patch(
|
||||
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
|
||||
)
|
||||
def test_unbacked_symint(self):
|
||||
def test_unbacked_symint_split(self):
|
||||
@torch.compile(backend="eager")
|
||||
def f(lengths, values):
|
||||
sizes = lengths.tolist()
|
||||
for s in sizes:
|
||||
torch._check_is_size(s)
|
||||
torch._check(s >= 2)
|
||||
torch._check(s <= 100)
|
||||
return torch.split(values, sizes)
|
||||
|
||||
f(torch.tensor([2, 3, 4]), torch.randn(9))
|
||||
|
|
@ -11319,15 +11313,12 @@ fn
|
|||
self.assertEqual(len(c2), 0)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_guard_size_oblivious_simplification(self):
|
||||
def test_check_simplification(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(x):
|
||||
u0, u1 = x.tolist()
|
||||
torch._check_is_size(u0)
|
||||
torch._check_is_size(u1)
|
||||
torch._check((2 * u0) % (u0 + u1) == 0)
|
||||
torch._check((2 * u0) // (u0 + u1) != 0)
|
||||
if guard_size_oblivious((2 * u0) // (u0 + u1) == 0):
|
||||
if (2 * u0) // (u0 + u1) == 0:
|
||||
return torch.tensor(True)
|
||||
else:
|
||||
return torch.tensor(False)
|
||||
|
|
|
|||
|
|
@ -1692,6 +1692,7 @@ def _check(cond, message=None): # noqa: F811
|
|||
_check_with(RuntimeError, cond, message) # pyrefly: ignore # bad-argument-type
|
||||
|
||||
|
||||
# TODO add deprecation annotation
|
||||
def _check_is_size(i, message=None, *, max=None):
|
||||
"""Checks that a given integer is a valid size (i.e., is non-negative).
|
||||
You should use this over ``_check(i >= 0)`` because it can prevent
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user