mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129763
Approved by: https://github.com/jansel
56 lines
1.7 KiB
Python
56 lines
1.7 KiB
Python
# Owner(s): ["module: functorch"]
|
|
import torch
|
|
import torch._dynamo
|
|
import torch._functorch
|
|
import torch._inductor
|
|
import torch._inductor.decomposition
|
|
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
|
|
from torch._inductor.test_case import run_tests, TestCase
|
|
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
|
|
|
|
|
|
class TestTorchbind(TestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
init_torchbind_implementations()
|
|
|
|
def get_exported_model(self):
|
|
"""
|
|
Returns the ExportedProgram, example inputs, and result from calling the
|
|
eager model with those inputs
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
|
self.b = torch.randn(2, 3)
|
|
|
|
def forward(self, x):
|
|
x = x + self.b
|
|
a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
|
|
y = a[0] + a[1]
|
|
b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
|
|
return x + b
|
|
|
|
m = M()
|
|
inputs = (torch.ones(2, 3),)
|
|
orig_res = m(*inputs)
|
|
|
|
# We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
|
|
with enable_torchbind_tracing():
|
|
ep = torch.export.export(m, inputs, strict=False)
|
|
|
|
return ep, inputs, orig_res
|
|
|
|
def test_torchbind_inductor(self):
|
|
ep, inputs, orig_res = self.get_exported_model()
|
|
compiled = torch._inductor.compile(ep.module(), inputs)
|
|
|
|
new_res = compiled(*inputs)
|
|
self.assertTrue(torch.allclose(orig_res, new_res))
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not imported): hand control to Inductor's test runner.
    run_tests()
|