mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[BE] @serialTest decorator must be called (#157388)
Otherwise it turns test into a trivial one(that always succeeds), as following example demonstrates
```python
import torch
from torch.testing._internal.common_utils import serialTest, run_tests, TestCase
class MegaTest(TestCase):
@serialTest
def test_foo(self):
if hasattr(self.test_foo, "pytestmark"):
print("foo has attr and it is", self.test_foo.pytestmark)
print("foo")
@serialTest()
def test_bar(self):
if hasattr(self.test_bar, "pytestmark"):
print("bar has attr and it is", self.test_bar.pytestmark)
print("bar")
if __name__ == "__main__":
run_tests()
```
That will print
```
test_bar (__main__.MegaTest.test_bar) ... bar has attr and it is [Mark(name='serial', args=(), kwargs={})]
bar
ok
test_foo (__main__.MegaTest.test_foo) ... ok
----------------------------------------------------------------------
Ran 2 tests in 0.013s
```
Added assert that arg is boolean in the decorator to prevent such silent skips in the future
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157388
Approved by: https://github.com/clee2000
This commit is contained in:
parent
eaf32fffb7
commit
5e636d664a
|
|
@ -3939,7 +3939,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
opt_model(17, (12,), out2)
|
||||
|
||||
@requires_cuda
|
||||
@serialTest
|
||||
@serialTest()
|
||||
def test_mem_leak_guards(self):
|
||||
def gn(x0, x):
|
||||
return x0 * x
|
||||
|
|
|
|||
|
|
@ -174,7 +174,7 @@ class TestCuda(TestCase):
|
|||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
@serialTest
|
||||
@serialTest()
|
||||
def test_host_memory_stats(self):
|
||||
# Helper functions
|
||||
def empty_stats():
|
||||
|
|
@ -4292,7 +4292,7 @@ class TestCudaMallocAsync(TestCase):
|
|||
finally:
|
||||
torch.cuda.memory._record_memory_history(None)
|
||||
|
||||
@serialTest
|
||||
@serialTest()
|
||||
def test_max_split_expandable(self):
|
||||
try:
|
||||
torch.cuda.memory.empty_cache()
|
||||
|
|
@ -4328,7 +4328,7 @@ class TestCudaMallocAsync(TestCase):
|
|||
finally:
|
||||
torch.cuda.memory.set_per_process_memory_fraction(orig)
|
||||
|
||||
@serialTest
|
||||
@serialTest()
|
||||
def test_garbage_collect_expandable(self):
|
||||
try:
|
||||
torch.cuda.memory.empty_cache()
|
||||
|
|
|
|||
|
|
@ -9336,7 +9336,7 @@ class TestSDPA(TestCaseMPS):
|
|||
)
|
||||
self._compare_tensors(y.cpu(), y_ref)
|
||||
|
||||
@serialTest
|
||||
@serialTest()
|
||||
def test_sdpa_fp32_no_memory_leak(self):
|
||||
def get_mps_memory_usage():
|
||||
return (torch.mps.current_allocated_memory() / (1024 * 1024),
|
||||
|
|
|
|||
|
|
@ -1696,6 +1696,10 @@ def serialTest(condition=True):
|
|||
"""
|
||||
Decorator for running tests serially. Requires pytest
|
||||
"""
|
||||
# If one apply decorator directly condition will be callable
|
||||
# And test will essentially be essentially skipped, which is undesirable
|
||||
assert type(condition) is bool
|
||||
|
||||
def decorator(fn):
|
||||
if has_pytest and condition:
|
||||
return pytest.mark.serial(fn)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user