[Inductor] Add input value checking to randint meta function (#147191)

Fixes #147070

Adding value checking for the range to the meta function, similar to which in the CUDA/CPU aten op.

Test with
```
PYTORCH_TEST_WITH_DYNAMO=1 pytest test/test_tensor_creation_ops.py -k test_randint_inference
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147191
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
This commit is contained in:
Ding, Yi1 2025-02-25 02:18:16 +00:00 committed by PyTorch MergeBot
parent c644f4c5fe
commit dacdc9782b
2 changed files with 25 additions and 0 deletions

View File

@ -3487,6 +3487,22 @@ class TestRandomTensorCreation(TestCase):
self.assertIs(torch.int64, torch.randint(*args, size=size, out=out).dtype) self.assertIs(torch.int64, torch.randint(*args, size=size, out=out).dtype)
self.assertIs(torch.int64, torch.randint(*args, size=size, out=out, dtype=torch.int64).dtype) self.assertIs(torch.int64, torch.randint(*args, size=size, out=out, dtype=torch.int64).dtype)
self.assertRaisesRegex(RuntimeError,
"random_ expects 'from' to be less than 'to', but got from=0 >= to=0",
lambda: torch.randint(0, size=size))
self.assertRaisesRegex(RuntimeError,
"random_ expects 'from' to be less than 'to', but got from=-1 >= to=-2",
lambda: torch.randint(-1, -2, size=size))
self.assertRaisesRegex(TypeError,
r"randint\(\): argument 'high' \(position 1\) must be int, not float",
lambda: torch.randint(.5, size=size))
self.assertRaisesRegex(RuntimeError,
"from is out of bounds for",
lambda: torch.randint(-32769, 0, size=size, dtype=torch.int16))
self.assertRaisesRegex(RuntimeError,
"from is out of bounds for",
lambda: torch.randint(-1, 1, size=size, dtype=torch.uint32))
# TODO: this test should be updated # TODO: this test should be updated
@onlyCPU @onlyCPU
def test_randint(self, device): def test_randint(self, device):

View File

@ -411,6 +411,11 @@ def meta_randint(
device=None, device=None,
pin_memory=None, pin_memory=None,
): ):
low = 0
torch._check(
high > low,
lambda: f"random_ expects 'from' to be less than 'to', but got from={low} >= to={high}",
)
return torch.empty( return torch.empty(
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
) )
@ -428,6 +433,10 @@ def meta_randint_low(
device=None, device=None,
pin_memory=None, pin_memory=None,
): ):
torch._check(
high > low,
lambda: f"random_ expects 'from' to be less than 'to', but got from={low} >= to={high}",
)
return torch.empty( return torch.empty(
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
) )