mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Marginally improve pytest collection for top-level test files (#53617)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53617 I'm trying to make `pytest test/*.py` work--right now, it fails during test collection. This removes a few of the easier to fix pytest collection problems one way or another. I have two remaining problems which is that the default dtype is trashed on entry to test_torch.py and test_cuda.py, I'll try to fix those in a follow up. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D26918377 Pulled By: ezyang fbshipit-source-id: 42069786882657e1e3ee974acb3ec48115f16210
This commit is contained in:
parent
6e020a4844
commit
70733f2e67
|
|
@ -8,10 +8,14 @@ import torch
|
|||
import torch.backends.cudnn
|
||||
import torch.utils.cpp_extension
|
||||
|
||||
import pytest
|
||||
|
||||
# TODO: Rewrite these tests so that they can be collected via pytest without
|
||||
# using run_test.py
|
||||
try:
|
||||
import torch_test_cpp_extension.cpp as cpp_extension
|
||||
import torch_test_cpp_extension.msnpu as msnpu_extension
|
||||
import torch_test_cpp_extension.rng as rng_extension
|
||||
cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
|
||||
msnpu_extension = pytest.importorskip("torch_test_cpp_extension.msnpu")
|
||||
rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"test_cpp_extensions_aot.py cannot be invoked directly. Run "
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ class TestOpenMP_ParallelFor(TestCase):
|
|||
side_dim = 80
|
||||
x = torch.randn([batch, channels, side_dim, side_dim], device=device)
|
||||
model = Network()
|
||||
ncores = min(5, psutil.cpu_count(logical=False))
|
||||
|
||||
def func(self, runs):
|
||||
p = psutil.Process()
|
||||
|
|
@ -61,7 +60,8 @@ class TestOpenMP_ParallelFor(TestCase):
|
|||
def test_n_threads(self):
|
||||
"""Make sure there is no memory leak with many threads
|
||||
"""
|
||||
torch.set_num_threads(self.ncores)
|
||||
ncores = min(5, psutil.cpu_count(logical=False))
|
||||
torch.set_num_threads(ncores)
|
||||
self.func_rss(300)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -315,9 +315,16 @@ def generate_tensor_like_torch_implementations():
|
|||
torch_vars = vars(torch)
|
||||
untested_funcs = []
|
||||
testing_overrides = get_testing_overrides()
|
||||
# test/test_cpp_api_parity.py monkeypatches torch.nn to have a new
|
||||
# function sample_functional. Depending on what order you run pytest
|
||||
# collection, this may trigger the error here. This is a hack to fix
|
||||
# the problem. A more proper fix is to make the "not tested" check
|
||||
# a test on its own, and to make sure the monkeypatch is only installed
|
||||
# for the span of the relevant test (and deleted afterwards)
|
||||
testing_ignore = {"sample_functional"}
|
||||
for namespace, funcs in get_overridable_functions().items():
|
||||
for func in funcs:
|
||||
if func not in testing_overrides:
|
||||
if func not in testing_overrides and func.__name__ not in testing_ignore:
|
||||
untested_funcs.append("{}.{}".format(namespace, func.__name__))
|
||||
msg = (
|
||||
"The following functions are not tested for __torch_function__ "
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user