mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Marginally improve pytest collection for top-level test files (#53617)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53617 I'm trying to make `pytest test/*.py` work--right now, it fails during test collection. This removes a few of the easier to fix pytest collection problems one way or another. I have two remaining problems which is that the default dtype is trashed on entry to test_torch.py and test_cuda.py, I'll try to fix those in a follow up. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D26918377 Pulled By: ezyang fbshipit-source-id: 42069786882657e1e3ee974acb3ec48115f16210
This commit is contained in:
parent
6e020a4844
commit
70733f2e67
|
|
@ -8,10 +8,14 @@ import torch
|
||||||
import torch.backends.cudnn
|
import torch.backends.cudnn
|
||||||
import torch.utils.cpp_extension
|
import torch.utils.cpp_extension
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# TODO: Rewrite these tests so that they can be collected via pytest without
|
||||||
|
# using run_test.py
|
||||||
try:
|
try:
|
||||||
import torch_test_cpp_extension.cpp as cpp_extension
|
cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
|
||||||
import torch_test_cpp_extension.msnpu as msnpu_extension
|
msnpu_extension = pytest.importorskip("torch_test_cpp_extension.msnpu")
|
||||||
import torch_test_cpp_extension.rng as rng_extension
|
rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"test_cpp_extensions_aot.py cannot be invoked directly. Run "
|
"test_cpp_extensions_aot.py cannot be invoked directly. Run "
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,6 @@ class TestOpenMP_ParallelFor(TestCase):
|
||||||
side_dim = 80
|
side_dim = 80
|
||||||
x = torch.randn([batch, channels, side_dim, side_dim], device=device)
|
x = torch.randn([batch, channels, side_dim, side_dim], device=device)
|
||||||
model = Network()
|
model = Network()
|
||||||
ncores = min(5, psutil.cpu_count(logical=False))
|
|
||||||
|
|
||||||
def func(self, runs):
|
def func(self, runs):
|
||||||
p = psutil.Process()
|
p = psutil.Process()
|
||||||
|
|
@ -61,7 +60,8 @@ class TestOpenMP_ParallelFor(TestCase):
|
||||||
def test_n_threads(self):
|
def test_n_threads(self):
|
||||||
"""Make sure there is no memory leak with many threads
|
"""Make sure there is no memory leak with many threads
|
||||||
"""
|
"""
|
||||||
torch.set_num_threads(self.ncores)
|
ncores = min(5, psutil.cpu_count(logical=False))
|
||||||
|
torch.set_num_threads(ncores)
|
||||||
self.func_rss(300)
|
self.func_rss(300)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
||||||
|
|
@ -315,9 +315,16 @@ def generate_tensor_like_torch_implementations():
|
||||||
torch_vars = vars(torch)
|
torch_vars = vars(torch)
|
||||||
untested_funcs = []
|
untested_funcs = []
|
||||||
testing_overrides = get_testing_overrides()
|
testing_overrides = get_testing_overrides()
|
||||||
|
# test/test_cpp_api_parity.py monkeypatches torch.nn to have a new
|
||||||
|
# function sample_functional. Depending on what order you run pytest
|
||||||
|
# collection, this may trigger the error here. This is a hack to fix
|
||||||
|
# the problem. A more proper fix is to make the "not tested" check
|
||||||
|
# a test on its own, and to make sure the monkeypatch is only installed
|
||||||
|
# for the span of the relevant test (and deleted afterwards)
|
||||||
|
testing_ignore = {"sample_functional"}
|
||||||
for namespace, funcs in get_overridable_functions().items():
|
for namespace, funcs in get_overridable_functions().items():
|
||||||
for func in funcs:
|
for func in funcs:
|
||||||
if func not in testing_overrides:
|
if func not in testing_overrides and func.__name__ not in testing_ignore:
|
||||||
untested_funcs.append("{}.{}".format(namespace, func.__name__))
|
untested_funcs.append("{}.{}".format(namespace, func.__name__))
|
||||||
msg = (
|
msg = (
|
||||||
"The following functions are not tested for __torch_function__ "
|
"The following functions are not tested for __torch_function__ "
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user