mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[Inductor] Add input value checking to randint meta function (#147191)
Fixes #147070 Adding value checking for the range to the meta function, similar to which in the CUDA/CPU aten op. Test with ``` PYTORCH_TEST_WITH_DYNAMO=1 pytest test/test_tensor_creation_ops.py -k test_randint_inference ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/147191 Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
This commit is contained in:
parent
c644f4c5fe
commit
dacdc9782b
|
|
@ -3487,6 +3487,22 @@ class TestRandomTensorCreation(TestCase):
|
|||
self.assertIs(torch.int64, torch.randint(*args, size=size, out=out).dtype)
|
||||
self.assertIs(torch.int64, torch.randint(*args, size=size, out=out, dtype=torch.int64).dtype)
|
||||
|
||||
self.assertRaisesRegex(RuntimeError,
|
||||
"random_ expects 'from' to be less than 'to', but got from=0 >= to=0",
|
||||
lambda: torch.randint(0, size=size))
|
||||
self.assertRaisesRegex(RuntimeError,
|
||||
"random_ expects 'from' to be less than 'to', but got from=-1 >= to=-2",
|
||||
lambda: torch.randint(-1, -2, size=size))
|
||||
self.assertRaisesRegex(TypeError,
|
||||
r"randint\(\): argument 'high' \(position 1\) must be int, not float",
|
||||
lambda: torch.randint(.5, size=size))
|
||||
self.assertRaisesRegex(RuntimeError,
|
||||
"from is out of bounds for",
|
||||
lambda: torch.randint(-32769, 0, size=size, dtype=torch.int16))
|
||||
self.assertRaisesRegex(RuntimeError,
|
||||
"from is out of bounds for",
|
||||
lambda: torch.randint(-1, 1, size=size, dtype=torch.uint32))
|
||||
|
||||
# TODO: this test should be updated
|
||||
@onlyCPU
|
||||
def test_randint(self, device):
|
||||
|
|
|
|||
|
|
@ -411,6 +411,11 @@ def meta_randint(
|
|||
device=None,
|
||||
pin_memory=None,
|
||||
):
|
||||
low = 0
|
||||
torch._check(
|
||||
high > low,
|
||||
lambda: f"random_ expects 'from' to be less than 'to', but got from={low} >= to={high}",
|
||||
)
|
||||
return torch.empty(
|
||||
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
|
||||
)
|
||||
|
|
@ -428,6 +433,10 @@ def meta_randint_low(
|
|||
device=None,
|
||||
pin_memory=None,
|
||||
):
|
||||
torch._check(
|
||||
high > low,
|
||||
lambda: f"random_ expects 'from' to be less than 'to', but got from={low} >= to={high}",
|
||||
)
|
||||
return torch.empty(
|
||||
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user