mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add inductor test for torchbind symint (#149980)
Summary: add test Test Plan: ``` buck run //caffe2/test:test_export -- -r test_compile_custom_obj_unbacked_symint ``` Differential Revision: D71843179 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149980 Approved by: https://github.com/BoyuanFeng
This commit is contained in:
parent
a0253d2840
commit
b2088f1afe
|
|
@ -79,6 +79,22 @@ class TestTorchbind(TestCase):
|
|||
new_res = compiled(*inputs)
|
||||
self.assertTrue(torch.allclose(orig_res, new_res))
|
||||
|
||||
def test_torchbind_compile_symint(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(2, 3)
|
||||
|
||||
def forward(self, x):
|
||||
a = torch.ops._TorchScriptTesting.takes_foo_tensor_return(self.attr, x)
|
||||
return a
|
||||
|
||||
m = M()
|
||||
inputs = (torch.ones(2, 3),)
|
||||
orig_res = m(*inputs)
|
||||
new_res = torch.compile(m, backend="inductor")(*inputs)
|
||||
self.assertTrue(torch.allclose(orig_res, new_res))
|
||||
|
||||
def test_torchbind_compile(self):
|
||||
_, inputs, orig_res, mod = self.get_exported_model()
|
||||
new_res = torch.compile(mod, backend="inductor")(*inputs)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user