Previously, if someone wrote a Python abstract impl but didn't import the module it lives in, we would raise an error message suggesting that the user add an abstract impl for the operator. In addition, we now suggest that the user try importing the module associated with the operator in the pystub (since it is not guaranteed that an abstract impl exists), to avoid confusion.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117770
Approved by: https://github.com/ydwu4, https://github.com/williamwen42
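For context, a minimal sketch of the mechanism under test follows. The operator name mylib::foo, the module name my_pystub_module, and the function foo_abstract are hypothetical illustrations, not names from this PR; on the C++ side, the real operators' registrations name the pystub module. An abstract impl is registered from Python with torch.library.impl_abstract, and until the module containing it is imported, meta- and fake-tensor calls to the operator fail with the kind of suggestion the tests below exercise:

    # my_pystub_module.py -- hypothetical module named in the operator's pystub.
    import torch

    @torch.library.impl_abstract("mylib::foo")
    def foo_abstract(x):
        # Abstract impls compute only output metadata (shape/dtype/device);
        # they must not read or write real tensor data.
        return torch.empty_like(x)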
# Owner(s): ["module: unknown"]

import os.path
import sys
import tempfile
import unittest

import torch
from torch import ops
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, run_tests

from model import Model, get_custom_op_library_path

torch.ops.import_module("pointwise")

class TestCustomOperators(TestCase):
    def setUp(self):
        self.library_path = get_custom_op_library_path()
        ops.load_library(self.library_path)

    def test_custom_library_is_loaded(self):
        self.assertIn(self.library_path, ops.loaded_libraries)

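    # The next two tests cover ops whose pystub points at the "pointwise"
    # module: one with no abstract impl registered there and one with a
    # mismatched registration. In both cases the error should name the module.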
    def test_op_with_no_abstract_impl_pystub(self):
        x = torch.randn(3, device="meta")
        with self.assertRaisesRegex(RuntimeError, "pointwise"):
            torch.ops.custom.tan(x)

    def test_op_with_incorrect_abstract_impl_pystub(self):
        x = torch.randn(3, device="meta")
        with self.assertRaisesRegex(RuntimeError, "pointwise"):
            torch.ops.custom.cos(x)

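    # torch.compile hits the same path: the "unsupported operator" error
    # should suggest importing the op's pystub module ("nonexistent" for
    # custom.asin).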
    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_dynamo_pystub_suggestion(self):
        x = torch.randn(3)

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            return torch.ops.custom.asin(x)

        with self.assertRaisesRegex(
            RuntimeError,
            r"unsupported operator: .* \(you may need to `import nonexistent`",
        ):
            f(x)

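    # FakeTensor tracing reports the op as unsupported until the pystub
    # module is imported; afterwards make_fx succeeds and produces the
    # expected graph.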
    def test_abstract_impl_pystub_faketensor(self):
        from functorch import make_fx

        x = torch.randn(3, device="cpu")
        self.assertNotIn("my_custom_ops", sys.modules.keys())

        with self.assertRaises(
            torch._subclasses.fake_tensor.UnsupportedOperatorException
        ):
            gm = make_fx(torch.ops.custom.nonzero.default, tracing_mode="symbolic")(x)

        torch.ops.import_module("my_custom_ops")
        gm = make_fx(torch.ops.custom.nonzero.default, tracing_mode="symbolic")(x)
        self.assertExpectedInline("""\
def forward(self, arg0_1):
    nonzero = torch.ops.custom.nonzero.default(arg0_1); arg0_1 = None
    return nonzero
""".strip(), gm.code.strip())

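    # Same flow on the meta device: the NotImplementedError should point the
    # user at importing 'my_custom_ops2', after which the call succeeds.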
    def test_abstract_impl_pystub_meta(self):
        x = torch.randn(3, device="meta")
        self.assertNotIn("my_custom_ops2", sys.modules.keys())
        with self.assertRaisesRegex(NotImplementedError, r"import the 'my_custom_ops2'"):
            y = torch.ops.custom.sin.default(x)
        torch.ops.import_module("my_custom_ops2")
        y = torch.ops.custom.sin.default(x)

    def test_calling_custom_op_string(self):
        output = ops.custom.op2("abc", "def")
        self.assertLess(output, 0)
        output = ops.custom.op2("abc", "abc")
        self.assertEqual(output, 0)

    def test_calling_custom_op(self):
        output = ops.custom.op(torch.ones(5), 2.0, 3)
        self.assertEqual(type(output), list)
        self.assertEqual(len(output), 3)
        for tensor in output:
            self.assertTrue(tensor.allclose(torch.ones(5) * 2))

        output = ops.custom.op_with_defaults(torch.ones(5))
        self.assertEqual(type(output), list)
        self.assertEqual(len(output), 1)
        self.assertTrue(output[0].allclose(torch.ones(5)))

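    # op_with_autograd(x, 2, y) computes x + 2 * y + x * y (per the assertions
    # below), so the expected gradients are dx = 1 + y and dy = 2 + x; `grad`
    # is an all-ones tensor standing in for the constant terms.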
    def test_calling_custom_op_with_autograd(self):
        x = torch.randn((5, 5), requires_grad=True)
        y = torch.randn((5, 5), requires_grad=True)
        output = ops.custom.op_with_autograd(x, 2, y)
        self.assertTrue(output.allclose(x + 2 * y + x * y))

        go = torch.ones((), requires_grad=True)
        output.sum().backward(go, False, True)
        grad = torch.ones(5, 5)

        self.assertEqual(x.grad, y + grad)
        self.assertEqual(y.grad, x + grad * 2)

        # Test with optional arg.
        x.grad.zero_()
        y.grad.zero_()
        z = torch.randn((5, 5), requires_grad=True)
        output = ops.custom.op_with_autograd(x, 2, y, z)
        self.assertTrue(output.allclose(x + 2 * y + x * y + z))

        go = torch.ones((), requires_grad=True)
        output.sum().backward(go, False, True)
        self.assertEqual(x.grad, y + grad)
        self.assertEqual(y.grad, x + grad * 2)
        self.assertEqual(z.grad, grad)

    def test_calling_custom_op_with_autograd_in_nograd_mode(self):
        with torch.no_grad():
            x = torch.randn((5, 5), requires_grad=True)
            y = torch.randn((5, 5), requires_grad=True)
            output = ops.custom.op_with_autograd(x, 2, y)
            self.assertTrue(output.allclose(x + 2 * y + x * y))

    def test_calling_custom_op_inside_script_module(self):
        model = Model()
        output = model.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))

    def test_saving_and_loading_script_module_with_custom_op(self):
        model = Model()
        # Ideally we would like to not have to manually delete the file, but
        # NamedTemporaryFile opens the file, and it cannot be opened multiple
        # times on Windows. To support Windows, close the file after creation
        # and try to remove it manually.
        file = tempfile.NamedTemporaryFile(delete=False)
        try:
            file.close()
            model.save(file.name)
            loaded = torch.jit.load(file.name)
        finally:
            os.unlink(file.name)

        output = loaded.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))


if __name__ == "__main__":
    run_tests()