pytorch/test/custom_operator/test_custom_ops.py
Peter Goldsborough a0d4106c07 Integrate custom op tests with CI (#10611)
Summary:
This PR is stacked on https://github.com/pytorch/pytorch/pull/10610, and only adds changes in one file `.jenkins/pytorch/test.sh`, where we now build the custom op tests and run them.

I'd also like to use this PR to discuss whether the [`TorchConfig.cmake`](https://github.com/pytorch/pytorch/blob/master/cmake/TorchConfig.cmake.in) I made is robust enough (we will also see in the CI). orionr, Yangqing, dzhulgakov: what do you think?

Also ezyang for CI changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10611

Differential Revision: D9597627

Pulled By: goldsborough

fbshipit-source-id: f5af8164c076894f448cef7e5b356a6b3159f8b3
2018-09-10 15:40:21 -07:00

55 lines
1.8 KiB
Python

import argparse
import os.path
import tempfile
import unittest
import torch
from model import Model, get_custom_op_library_path
class TestCustomOperators(unittest.TestCase):
def setUp(self):
self.library_path = get_custom_op_library_path()
torch.ops.load_library(self.library_path)
def test_custom_library_is_loaded(self):
self.assertIn(self.library_path, torch.ops.loaded_libraries)
def test_calling_custom_op(self):
output = torch.ops.custom.op(torch.ones(5), 2.0, 3)
self.assertEqual(type(output), list)
self.assertEqual(len(output), 3)
for tensor in output:
self.assertTrue(tensor.allclose(torch.ones(5) * 2))
output = torch.ops.custom.op_with_defaults(torch.ones(5))
self.assertEqual(type(output), list)
self.assertEqual(len(output), 1)
self.assertTrue(output[0].allclose(torch.ones(5)))
def test_calling_custom_op_inside_script_module(self):
model = Model()
output = model.forward(torch.ones(5))
self.assertTrue(output.allclose(torch.ones(5) + 1))
def test_saving_and_loading_script_module_with_custom_op(self):
model = Model()
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
# opens the file, and it cannot be opened multiple times in Windows. To support Windows,
# close the file after creation and try to remove it manually.
file = tempfile.NamedTemporaryFile(delete=False)
try:
file.close()
model.save(file.name)
loaded = torch.jit.load(file.name)
finally:
os.unlink(file.name)
output = loaded.forward(torch.ones(5))
self.assertTrue(output.allclose(torch.ones(5) + 1))
if __name__ == "__main__":
    # Run every test in this module when the file is executed as a script.
    unittest.main()