pytorch/test/torch_np/test_random.py
lezcano a9dca53438 NumPy support in torch.compile (#106211)
RFC: https://github.com/pytorch/rfcs/pull/54
First commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/

We have already been using this in core for the last few months as an external dependency. This PR pulls all of it into core.
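For illustration, here is a minimal sketch of what this enables after the merge (not code from this PR): `torch.compile` can trace a function written against the NumPy API by mapping it onto `torch._numpy` and running it on PyTorch tensors.

```python
import numpy as np
import torch

@torch.compile
def numpy_fn(x, y):
    # Plain NumPy code: under torch.compile this is traced through
    # torch._numpy and executed on PyTorch tensors.
    return np.sum(x * y, axis=0)

x = np.random.randn(3, 4)
y = np.random.randn(3, 4)
out = numpy_fn(x, y)  # a regular np.ndarray, computed via PyTorch
```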

In the subsequent commits, I do a number of things, in this order:
- Fix a few small issues
- Make the tests that this PR adds pass
- Bend backwards until lintrunner passes
- Remove the optional dependency on `torch_np` and simply rely on the upstreamed code
- Fix a number of dynamo tests that were passing before (they were not testing anything, I think) and are not passing now.

Missing from this PR (but not blocking):
- Add a flag that deactivates tracing NumPy functions and simply graph-breaks instead. There used to be one, but it stopped working after the merge, so I removed it. @lezcano to investigate.
- https://github.com/pytorch/pytorch/pull/106431#issuecomment-1667079543. @voznesenskym to submit a fix after we merge.

All the tests in `test/torch_np` take about 75s to run.

This was joint work by @ev-br, @rgommers, @honno, and me. I did not create this PR via ghstack (which would have been convenient) because it is a collaboration, and ghstack doesn't allow for shared contributions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106211
Approved by: https://github.com/ezyang
2023-08-11 00:39:32 +00:00

# Owner(s): ["module: dynamo"]
"""Light smoke test switching between numpy to pytorch random streams.
"""
import pytest
import torch._numpy as tnp
from torch._numpy.testing import assert_equal
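
# The tests below exercise tnp.random.USE_NUMPY_RANDOM: when it is True,
# random calls are routed through NumPy's global RandomState (so the stream
# matches NumPy's exactly); when it is False, they draw from PyTorch's RNG.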


def test_uniform():
    # Smoke test: drawing variates should not raise.
    r = tnp.random.uniform(0, 1, size=10)


def test_shuffle():
    # Smoke test: in-place shuffle of a torch._numpy array should not raise.
    x = tnp.arange(10)
    tnp.random.shuffle(x)


def test_numpy_global():
    tnp.random.USE_NUMPY_RANDOM = True
    tnp.random.seed(12345)
    x = tnp.random.uniform(0, 1, size=11)

    # Check that the stream is identical to NumPy's.
    import numpy as _np

    _np.random.seed(12345)
    x_np = _np.random.uniform(0, 1, size=11)
    assert_equal(x, tnp.asarray(x_np))

    # Switch to the PyTorch stream; the variates differ.
    tnp.random.USE_NUMPY_RANDOM = False
    tnp.random.seed(12345)
    x_1 = tnp.random.uniform(0, 1, size=11)
    assert not (x_1 == x).all()


def test_wrong_global():
    # Any non-boolean value of USE_NUMPY_RANDOM is rejected.
    try:
        oldstate = tnp.random.USE_NUMPY_RANDOM
        tnp.random.USE_NUMPY_RANDOM = "oops"
        with pytest.raises(ValueError):
            tnp.random.rand()
    finally:
        # Restore the flag so other tests are unaffected.
        tnp.random.USE_NUMPY_RANDOM = oldstate