mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
- To test whether or not a multiline string matches some expected value, you can use assertExpected. This tests that the string matches the content stored in a file whose name is based on the name of the test (and an optional subname parameter you can pass if you want to call assertExpected multiple times). - Suppose you make a change that modifies the output in a big way. Instead of manually going through and updating each test, you run python test/test_jit.py --accept. This updates all of the expected outputs. You can then review them one by one and make sure your changes make sense. We can add more features later (e.g., munging the output to make it more stable, more sanity checking), but this is just to get us started testing. One thing to watch out for is that accept tests on intermediate representation can be a bit wobbly: it is *extremely* important that people be able to read the IR. It may be worth introducing niceties to the printer in order to ensure this is the case. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
28 lines
767 B
Python
28 lines
767 B
Python
import torch
|
|
import torch.jit
|
|
from torch.autograd import Variable
|
|
from common import TestCase, run_tests
|
|
|
|
|
|
class TestJit(TestCase):
    """Tests that exercise the tracer and check the printed trace."""

    # Traces sigmoid(tanh(x * (x + y))) over two single-element inputs and
    # checks the string form of the resulting trace with assertExpected.
    def test_simple(self):
        lhs = Variable(torch.Tensor([0.4]), requires_grad=True)
        rhs = Variable(torch.Tensor([0.7]), requires_grad=True)
        traced_inputs = (lhs, rhs)

        # Everything computed between enter and exit is what gets traced.
        torch._C._tracer_enter(traced_inputs)
        out = torch.sigmoid(torch.tanh(lhs * (lhs + rhs)))
        graph = torch._C._tracer_exit((out,))

        # TODO: Do something more automated here
        printed = str(graph)
        self.assertExpected(printed)
        return

        # Re-enable this when the interpreter is back
        # (intentionally unreachable until then because of the return above)
        replayed = out._execution_engine.run_forward(graph, traced_inputs)
        self.assertEqual(out, replayed)

        # TODO: test that backwards works correctly
|
|
|
|
# When this file is executed directly (rather than imported), hand control
# to the run_tests helper imported from the common module.
if __name__ == '__main__':
    run_tests()
|