pytorch/test/custom_operator/test_custom_ops.py
rzou e309d6fa1c Better unsupported op error message (#117770)
Previously, if someone wrote a Python abstract impl but didn't import
the module containing it, we would raise an error message suggesting
that the user needed to add an abstract impl for the operator.

In addition to that message, we now also suggest that the user try
importing the module associated with the operator in the pystub (it's
not guaranteed that an abstract impl exists there), to avoid confusion.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117770
Approved by: https://github.com/ydwu4, https://github.com/williamwen42
2024-01-23 15:05:16 +00:00

144 lines
5.2 KiB
Python

# Owner(s): ["module: unknown"]
import os.path
import sys
import tempfile
import torch
from torch import ops
from model import Model, get_custom_op_library_path
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
import unittest
# Import the "pointwise" module through torch.ops.import_module at module
# load time, i.e. before any test below runs.
# NOTE(review): presumably "pointwise" is the pystub module that the
# custom.tan / custom.cos tests expect to see named in their error
# messages — confirm against the custom-op library's registrations.
torch.ops.import_module("pointwise")
class TestCustomOperators(TestCase):
def setUp(self):
self.library_path = get_custom_op_library_path()
ops.load_library(self.library_path)
def test_custom_library_is_loaded(self):
self.assertIn(self.library_path, ops.loaded_libraries)
def test_op_with_no_abstract_impl_pystub(self):
x = torch.randn(3, device='meta')
with self.assertRaisesRegex(RuntimeError, "pointwise"):
torch.ops.custom.tan(x)
def test_op_with_incorrect_abstract_impl_pystub(self):
x = torch.randn(3, device='meta')
with self.assertRaisesRegex(RuntimeError, "pointwise"):
torch.ops.custom.cos(x)
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
def test_dynamo_pystub_suggestion(self):
x = torch.randn(3)
@torch.compile(backend="eager", fullgraph=True)
def f(x):
return torch.ops.custom.asin(x)
with self.assertRaisesRegex(RuntimeError, r'unsupported operator: .* \(you may need to `import nonexistent`'):
f(x)
def test_abstract_impl_pystub_faketensor(self):
from functorch import make_fx
x = torch.randn(3, device='cpu')
self.assertNotIn("my_custom_ops", sys.modules.keys())
with self.assertRaises(torch._subclasses.fake_tensor.UnsupportedOperatorException):
gm = make_fx(torch.ops.custom.nonzero.default, tracing_mode="symbolic")(x)
torch.ops.import_module("my_custom_ops")
gm = make_fx(torch.ops.custom.nonzero.default, tracing_mode="symbolic")(x)
self.assertExpectedInline("""\
def forward(self, arg0_1):
nonzero = torch.ops.custom.nonzero.default(arg0_1); arg0_1 = None
return nonzero
""".strip(), gm.code.strip())
def test_abstract_impl_pystub_meta(self):
x = torch.randn(3, device="meta")
self.assertNotIn("my_custom_ops2", sys.modules.keys())
with self.assertRaisesRegex(NotImplementedError, r"import the 'my_custom_ops2'"):
y = torch.ops.custom.sin.default(x)
torch.ops.import_module("my_custom_ops2")
y = torch.ops.custom.sin.default(x)
def test_calling_custom_op_string(self):
output = ops.custom.op2("abc", "def")
self.assertLess(output, 0)
output = ops.custom.op2("abc", "abc")
self.assertEqual(output, 0)
def test_calling_custom_op(self):
output = ops.custom.op(torch.ones(5), 2.0, 3)
self.assertEqual(type(output), list)
self.assertEqual(len(output), 3)
for tensor in output:
self.assertTrue(tensor.allclose(torch.ones(5) * 2))
output = ops.custom.op_with_defaults(torch.ones(5))
self.assertEqual(type(output), list)
self.assertEqual(len(output), 1)
self.assertTrue(output[0].allclose(torch.ones(5)))
def test_calling_custom_op_with_autograd(self):
x = torch.randn((5, 5), requires_grad=True)
y = torch.randn((5, 5), requires_grad=True)
output = ops.custom.op_with_autograd(x, 2, y)
self.assertTrue(output.allclose(x + 2 * y + x * y))
go = torch.ones((), requires_grad=True)
output.sum().backward(go, False, True)
grad = torch.ones(5, 5)
self.assertEqual(x.grad, y + grad)
self.assertEqual(y.grad, x + grad * 2)
# Test with optional arg.
x.grad.zero_()
y.grad.zero_()
z = torch.randn((5, 5), requires_grad=True)
output = ops.custom.op_with_autograd(x, 2, y, z)
self.assertTrue(output.allclose(x + 2 * y + x * y + z))
go = torch.ones((), requires_grad=True)
output.sum().backward(go, False, True)
self.assertEqual(x.grad, y + grad)
self.assertEqual(y.grad, x + grad * 2)
self.assertEqual(z.grad, grad)
def test_calling_custom_op_with_autograd_in_nograd_mode(self):
with torch.no_grad():
x = torch.randn((5, 5), requires_grad=True)
y = torch.randn((5, 5), requires_grad=True)
output = ops.custom.op_with_autograd(x, 2, y)
self.assertTrue(output.allclose(x + 2 * y + x * y))
def test_calling_custom_op_inside_script_module(self):
model = Model()
output = model.forward(torch.ones(5))
self.assertTrue(output.allclose(torch.ones(5) + 1))
def test_saving_and_loading_script_module_with_custom_op(self):
model = Model()
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
# opens the file, and it cannot be opened multiple times in Windows. To support Windows,
# close the file after creation and try to remove it manually.
file = tempfile.NamedTemporaryFile(delete=False)
try:
file.close()
model.save(file.name)
loaded = torch.jit.load(file.name)
finally:
os.unlink(file.name)
output = loaded.forward(torch.ones(5))
self.assertTrue(output.allclose(torch.ones(5) + 1))
if __name__ == "__main__":
    # Only when executed directly (not imported): hand control to
    # ``run_tests`` from torch.testing._internal.common_utils.
    run_tests()