mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add inductor test for torchbind symint (#149980)
Summary: add test Test Plan: ``` buck run //caffe2/test:test_export -- -r test_compile_custom_obj_unbacked_symint ``` Differential Revision: D71843179 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149980 Approved by: https://github.com/BoyuanFeng
This commit is contained in:
parent
a0253d2840
commit
b2088f1afe
|
|
@ -79,6 +79,22 @@ class TestTorchbind(TestCase):
|
||||||
new_res = compiled(*inputs)
|
new_res = compiled(*inputs)
|
||||||
self.assertTrue(torch.allclose(orig_res, new_res))
|
self.assertTrue(torch.allclose(orig_res, new_res))
|
||||||
|
|
||||||
|
def test_torchbind_compile_symint(self):
|
||||||
|
class M(torch.nn.Module):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.attr = torch.classes._TorchScriptTesting._Foo(2, 3)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
a = torch.ops._TorchScriptTesting.takes_foo_tensor_return(self.attr, x)
|
||||||
|
return a
|
||||||
|
|
||||||
|
m = M()
|
||||||
|
inputs = (torch.ones(2, 3),)
|
||||||
|
orig_res = m(*inputs)
|
||||||
|
new_res = torch.compile(m, backend="inductor")(*inputs)
|
||||||
|
self.assertTrue(torch.allclose(orig_res, new_res))
|
||||||
|
|
||||||
def test_torchbind_compile(self):
|
def test_torchbind_compile(self):
|
||||||
_, inputs, orig_res, mod = self.get_exported_model()
|
_, inputs, orig_res, mod = self.get_exported_model()
|
||||||
new_res = torch.compile(mod, backend="inductor")(*inputs)
|
new_res = torch.compile(mod, backend="inductor")(*inputs)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user