pytorch/test/custom_operator/test_custom_ops.py
Edward Yang 173f224570 Turn on F401: Unused import warning. (#18598)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598
ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18598 Turn on F401: Unused import warning.**

This was requested by someone at Facebook; this lint is turned
on for Facebook by default.  "Sure, why not."

I had to noqa a number of imports in __init__.  Hypothetically
we're supposed to use __all__ in this case, but I was too lazy
to fix it.  Left for future work.

Be careful!  flake8-2 and flake8-3 behave differently with
respect to import resolution for # type: comments.  flake8-3 will
report such an import as unused; flake8-2 will not.  For now, I just
noqa'd all these sites.

All the changes were done by hand.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D14687478

fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
2019-03-30 09:01:17 -07:00

61 lines
2.0 KiB
Python

import os.path
import tempfile
import unittest
import torch
from torch import ops
from model import Model, get_custom_op_library_path
class TestCustomOperators(unittest.TestCase):
    """Tests for the custom operator library exposed through ``torch.ops``.

    The shared library path comes from ``get_custom_op_library_path`` and the
    TorchScript module under test from ``Model``; both are imported from the
    local ``model`` module.
    """

    def setUp(self):
        # Load the custom op library before each test so that
        # ``ops.custom.*`` resolves no matter which test runs first.
        self.library_path = get_custom_op_library_path()
        ops.load_library(self.library_path)

    def test_custom_library_is_loaded(self):
        """The library loaded in setUp is recorded in ``ops.loaded_libraries``."""
        self.assertIn(self.library_path, ops.loaded_libraries)

    def test_calling_custom_op_string(self):
        """``custom::op2`` accepts two strings and returns a comparable value.

        The assertions expect a negative result for ("abc", "def") and zero
        for two identical strings.
        """
        output = ops.custom.op2("abc", "def")
        self.assertLess(output, 0)
        output = ops.custom.op2("abc", "abc")
        self.assertEqual(output, 0)

    def test_calling_custom_op(self):
        """``custom::op`` and ``custom::op_with_defaults`` return lists of tensors."""
        output = ops.custom.op(torch.ones(5), 2.0, 3)
        # assertIsInstance is the idiomatic type assertion (instead of
        # comparing type(output) == list) and reports a clearer failure.
        self.assertIsInstance(output, list)
        self.assertEqual(len(output), 3)
        # The expected tensor does not depend on the loop variable; build it once.
        expected = torch.ones(5) * 2
        for tensor in output:
            self.assertTrue(tensor.allclose(expected))

        output = ops.custom.op_with_defaults(torch.ones(5))
        self.assertIsInstance(output, list)
        self.assertEqual(len(output), 1)
        self.assertTrue(output[0].allclose(torch.ones(5)))

    def test_calling_custom_op_inside_script_module(self):
        """A script module that calls the custom op produces ``input + 1``."""
        model = Model()
        output = model.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))

    def test_saving_and_loading_script_module_with_custom_op(self):
        """A script module using the custom op survives a save/load round trip."""
        model = Model()
        # Ideally NamedTemporaryFile would clean up after itself, but on
        # Windows a file it holds open cannot be opened a second time by name.
        # So create it with delete=False, close it immediately, and remove it
        # manually once the module has been reloaded.
        # ``temp_file`` (not ``file``) avoids shadowing the Python 2 builtin.
        temp_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            temp_file.close()
            model.save(temp_file.name)
            loaded = torch.jit.load(temp_file.name)
        finally:
            os.unlink(temp_file.name)
        output = loaded.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))
if __name__ == "__main__":
    # When executed as a script, hand control to unittest's command-line
    # runner; module="__main__" (its default) makes it collect this file's tests.
    unittest.main(module="__main__")