pytorch/test/jit/test_warn.py
Anthony Barbier bf7e290854 Add __main__ guards to jit tests (#154725)
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs.

In jit tests:

- Add and use a common raise_on_run_directly method for when a user runs a test file directly which should not be run this way. Print the file which the user should have run.
- Raise a RuntimeError on tests which have been disabled (not run)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725
Approved by: https://github.com/clee2000
2025-06-16 10:28:45 +00:00

148 lines
3.6 KiB
Python

# Owner(s): ["oncall: jit"]
import io
import os
import sys
import warnings
from contextlib import redirect_stderr
import torch
from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
class TestWarn(JitTestCase):
    """Tests for ``warnings.warn`` calls made from scripted functions.

    Every test compiles one or more functions with ``torch.jit.script``, runs
    them while ``sys.stderr`` is redirected into an in-memory buffer, and then
    uses ``FileCheck`` to assert how many ``UserWarning`` lines were captured.
    """

    def _check_warning_counts(self, captured, *expected):
        """Run one ``FileCheck`` over ``captured`` with the given expectations.

        Args:
            captured: Text collected from stderr while the scripted code ran.
            *expected: ``(text, count)`` pairs. Each pair is added to a single
                ``FileCheck`` via ``check_count(..., exactly=True)``, chained
                in the order the pairs are given.
        """
        checker = FileCheck()
        for text, count in expected:
            checker = checker.check_count(str=text, count=count, exactly=True)
        checker.run(captured)

    def test_warn(self):
        # One warn statement, one call: exactly one warning is expected.
        @torch.jit.script
        def fn():
            warnings.warn("I am warning you")

        buf = io.StringIO()
        with redirect_stderr(buf):
            fn()

        self._check_warning_counts(
            buf.getvalue(), ("UserWarning: I am warning you", 1)
        )

    def test_warn_only_once(self):
        # The same warn statement executed ten times in a loop is expected to
        # be reported a single time.
        @torch.jit.script
        def fn():
            for _ in range(10):
                warnings.warn("I am warning you")

        buf = io.StringIO()
        with redirect_stderr(buf):
            fn()

        self._check_warning_counts(
            buf.getvalue(), ("UserWarning: I am warning you", 1)
        )

    def test_warn_only_once_in_loop_func(self):
        # Same as above, but the warn statement lives in a helper function
        # that the scripted function calls from its loop.
        def w():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            for _ in range(10):
                w()

        buf = io.StringIO()
        with redirect_stderr(buf):
            fn()

        self._check_warning_counts(
            buf.getvalue(), ("UserWarning: I am warning you", 1)
        )

    def test_warn_once_per_func(self):
        # Two distinct helpers emit the same message; each one is expected to
        # produce its own warning, so two lines in total.
        def w1():
            warnings.warn("I am warning you")

        def w2():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            w1()
            w2()

        buf = io.StringIO()
        with redirect_stderr(buf):
            fn()

        self._check_warning_counts(
            buf.getvalue(), ("UserWarning: I am warning you", 2)
        )

    def test_warn_once_per_func_in_loop(self):
        # Two distinct helpers called ten times each from a loop: still one
        # warning per helper is expected, two in total.
        def w1():
            warnings.warn("I am warning you")

        def w2():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            for _ in range(10):
                w1()
                w2()

        buf = io.StringIO()
        with redirect_stderr(buf):
            fn()

        self._check_warning_counts(
            buf.getvalue(), ("UserWarning: I am warning you", 2)
        )

    def test_warn_multiple_calls_multiple_warnings(self):
        # Calling the scripted function twice is expected to warn twice: the
        # "warn once" behaviour does not carry over between separate calls.
        @torch.jit.script
        def fn():
            warnings.warn("I am warning you")

        buf = io.StringIO()
        with redirect_stderr(buf):
            fn()
            fn()

        self._check_warning_counts(
            buf.getvalue(), ("UserWarning: I am warning you", 2)
        )

    def test_warn_multiple_calls_same_func_diff_stack(self):
        # One shared helper reached from two different scripted callers: each
        # caller is expected to produce its own warning, foo's before bar's.
        def warn(caller: str):
            warnings.warn("I am warning you from " + caller)

        @torch.jit.script
        def foo():
            warn("foo")

        @torch.jit.script
        def bar():
            warn("bar")

        buf = io.StringIO()
        with redirect_stderr(buf):
            foo()
            bar()

        self._check_warning_counts(
            buf.getvalue(),
            ("UserWarning: I am warning you from foo", 1),
            ("UserWarning: I am warning you from bar", 1),
        )
if __name__ == "__main__":
    # This file is not meant to be executed on its own; the shared helper is
    # given the path of the test file that should be run instead.
    raise_on_run_directly("test/test_jit.py")