pytorch/test/torch_np/test_random.py
Evgeni Burovski 92c49e2168 MAINT/TST: pytorch-ify torch._numpy tests (added tests only, not vendored) (#109593)
1. Inherit from TestCase
2. Use pytorch parametrization
3. Use unittest.expectedFailure to mark xfails

All this to make pytest-less invocation work:

$ python test/torch_np/test_basic.py
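
For reference, the converted tests follow roughly this pattern (a minimal sketch with made-up test names, not code from this PR):

```python
import unittest

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


@instantiate_parametrized_tests
class TestExample(TestCase):
    # pytorch parametrization instead of pytest.mark.parametrize
    @parametrize("n", [1, 2, 3])
    def test_square_nonnegative(self, n):
        assert n * n >= 0

    # unittest.expectedFailure instead of pytest.mark.xfail
    @unittest.expectedFailure
    def test_known_failure(self):
        raise AssertionError("not implemented yet")


if __name__ == "__main__":
    run_tests()  # runs without pytest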

Furthermore, tests can now be run under dynamo, and we see the first errors:

```
$ PYTORCH_TEST_WITH_DYNAMO=1 python test/torch_np/test_basic.py -k test_toscalar_list_func
.E.
======================================================================
ERROR: test_toscalar_list_func_<function shape at 0x7f9b83a4fc10>_np_func_<function shape at 0x7f9a8dd38af0> (__main__.TestOneArrToScalar)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/ev-br/repos/pytorch/torch/testing/_internal/common_utils.py", line 356, in instantiated_test
    test(self, **param_kwargs)
  File "test/torch_np/test_basic.py", line 232, in test_toscalar_list
    @parametrize("func, np_func", one_arg_scalar_funcs)
  File "/home/ev-br/repos/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ev-br/repos/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ev-br/repos/pytorch/torch/_dynamo/eval_frame.py", line 406, in _fn
    return fn(*args, **kwargs)
  File "/home/ev-br/repos/pytorch/torch/fx/graph_module.py", line 726, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/ev-br/repos/pytorch/torch/fx/graph_module.py", line 305, in __call__
    raise e
  File "/home/ev-br/repos/pytorch/torch/fx/graph_module.py", line 292, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/ev-br/repos/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ev-br/repos/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.2", line 5, in forward
    shape = torch._numpy._funcs_impl.shape([[1, 2, 3], [4, 5, 6]])
  File "/home/ev-br/repos/pytorch/torch/_numpy/_funcs_impl.py", line 655, in shape
    return tuple(a.shape)
AttributeError: 'list' object has no attribute 'shape'

----------------------------------------------------------------------
Ran 3 tests in 0.915s

FAILED (errors=1)
```
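
The traceback shows the root cause: the fx graph produced under dynamo calls the private `torch._numpy._funcs_impl.shape` with the raw list, bypassing the argument normalization the public eager API performs, so `tuple(a.shape)` fails. A hedged two-call sketch (assuming the public `tnp.shape` accepts list input in eager mode):

```python
import torch._numpy as tnp
import torch._numpy._funcs_impl as impl

tnp.shape([[1, 2, 3], [4, 5, 6]])   # eager: list is normalized first, returns (2, 3)
impl.shape([[1, 2, 3], [4, 5, 6]])  # AttributeError: 'list' object has no attribute 'shape'
```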

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109593
Approved by: https://github.com/lezcano
2023-09-23 18:18:50 +00:00


# Owner(s): ["module: dynamo"]

"""Light smoke test for switching between numpy and pytorch random streams.
"""
from contextlib import contextmanager
from functools import partial

import numpy as _np
import pytest

import torch._dynamo.config as config
import torch._numpy as tnp
from torch._numpy.testing import assert_equal
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TestCase,
)


@contextmanager
def control_stream(use_numpy=False):
    # Flip the dynamo config flag that selects the random stream, and
    # restore the previous setting on exit.
    oldstate = config.use_numpy_random_stream
    config.use_numpy_random_stream = use_numpy
    try:
        yield
    finally:
        config.use_numpy_random_stream = oldstate
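
# Usage sketch (illustrative, not part of the original file): inside the
# block, variates follow numpy's global RandomState; outside, pytorch's
# generator.
#
#   with control_stream(use_numpy=True):
#       tnp.random.seed(0)
#       x = tnp.random.uniform()   # matches numpy's stream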


@instantiate_parametrized_tests
class TestScalarReturn(TestCase):
    @parametrize("use_numpy", [True, False])
    @parametrize(
        "func",
        [
            tnp.random.normal,
            tnp.random.rand,
            partial(tnp.random.randint, 0, 5),
            tnp.random.randn,
            subtest(tnp.random.random, name="random_random"),
            subtest(tnp.random.random_sample, name="random_sample"),
            tnp.random.sample,
            tnp.random.uniform,
        ],
    )
    def test_rndm_scalar(self, func, use_numpy):
        # default `size` means a python scalar return
        with control_stream(use_numpy):
            r = func()
        assert isinstance(r, (int, float))

    @parametrize("use_numpy", [True, False])
    @parametrize(
        "func",
        [
            tnp.random.normal,
            tnp.random.rand,
            partial(tnp.random.randint, 0, 5),
            tnp.random.randn,
            subtest(tnp.random.random, name="random_random"),
            subtest(tnp.random.random_sample, name="random_sample"),
            tnp.random.sample,
            tnp.random.uniform,
        ],
    )
    def test_rndm_array(self, func, use_numpy):
        with control_stream(use_numpy):
            if func in (tnp.random.rand, tnp.random.randn):
                r = func(10)
            else:
                r = func(size=10)
        assert isinstance(r, tnp.ndarray)


@instantiate_parametrized_tests
class TestShuffle(TestCase):
    @parametrize("use_numpy", [True, False])
    def test_1d(self, use_numpy):
        ax = tnp.asarray([1, 2, 3, 4, 5, 6])
        ox = ax.copy()

        tnp.random.seed(1234)
        tnp.random.shuffle(ax)

        assert isinstance(ax, tnp.ndarray)
        assert not (ax == ox).all()

    @parametrize("use_numpy", [True, False])
    def test_2d(self, use_numpy):
        # np.shuffle only shuffles the first axis
        ax = tnp.asarray([[1, 2, 3], [4, 5, 6]])
        ox = ax.copy()

        tnp.random.seed(1234)
        tnp.random.shuffle(ax)

        assert isinstance(ax, tnp.ndarray)
        assert not (ax == ox).all()

    @parametrize("use_numpy", [True, False])
    def test_shuffle_list(self, use_numpy):
        # on eager, we refuse to shuffle lists;
        # under dynamo, we always fall back to numpy.
        # NB: this means that the random stream is different for
        # shuffling a list or an array when USE_NUMPY_STREAM == False
        x = [1, 2, 3]
        with pytest.raises(NotImplementedError):
            tnp.random.shuffle(x)


@instantiate_parametrized_tests
class TestChoice(TestCase):
    @parametrize("use_numpy", [True, False])
    def test_choice(self, use_numpy):
        kwds = dict(size=3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])

        with control_stream(use_numpy):
            tnp.random.seed(12345)
            x = tnp.random.choice(5, **kwds)

            tnp.random.seed(12345)
            x_1 = tnp.random.choice(tnp.arange(5), **kwds)

            assert_equal(x, x_1)


class TestNumpyGlobal(TestCase):
    def test_numpy_global(self):
        with control_stream(use_numpy=True):
            tnp.random.seed(12345)
            x = tnp.random.uniform(0, 1, size=11)

        # check that the stream is identical to numpy's
        _np.random.seed(12345)
        x_np = _np.random.uniform(0, 1, size=11)
        assert_equal(x, tnp.asarray(x_np))

        # switch to the pytorch stream, variates differ
        with control_stream(use_numpy=False):
            tnp.random.seed(12345)
            x_1 = tnp.random.uniform(0, 1, size=11)
        assert not (x_1 == x).all()


if __name__ == "__main__":
    run_tests()