# Owner(s): ["module: inductor"]
import torch

from torch import _dynamo as dynamo, _inductor as inductor
from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor.utils import gen_gm_and_inputs
from torch.fx import symbolic_trace
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.inductor_utils import HAS_CPU
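
# Small helper modules exercised by the tests below; they differ only in
# input/output structure (bare tensor return, tuple return, dict of lists).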
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(10, 10)
        self.b = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.a(x))
        x = torch.sigmoid(self.b(x))
        return x

class MyModule2(MyModule):
    def forward(self, x):  # takes a dict of lists
        a, b = x["key"]
        return {"result": super().forward(a) + b}

class MyModule3(MyModule):
    def forward(self, x):
        return (super().forward(x),)

class TestStandaloneInductor(TestCase):
    """
    These tests check that you can call TorchInductor directly without
    going through TorchDynamo.
    """
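
    # Trace MyModule3 with torch.fx symbolic_trace and hand the resulting
    # GraphModule straight to inductor.compile, bypassing TorchDynamo.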
    def test_inductor_via_fx(self):
        mod = MyModule3().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        mod_opt = inductor.compile(symbolic_trace(mod), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)
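
    # Same pipeline, but the traced module returns a bare tensor rather
    # than a tuple, checking non-tuple graph outputs.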
    def test_inductor_via_fx_tensor_return(self):
        mod = MyModule().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        mod_opt = inductor.compile(symbolic_trace(mod), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)
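
    # The traced module takes a dict of lists and returns a dict, checking
    # that structured (non-tensor) inputs and outputs survive compilation.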
    def test_inductor_via_fx_dict_input(self):
        mod = MyModule2().eval()
        inp = {"key": [torch.randn(10), torch.randn(10)]}
        correct = mod(inp)
        mod_opt = inductor.compile(symbolic_trace(mod), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)
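
    # make_fx traces down to ATen ops, so here inductor consumes an
    # ATen-level graph instead of a torch-level one.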
    def test_inductor_via_make_fx(self):
        mod = MyModule().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        mod_opt = inductor.compile(make_fx(mod)(inp), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)
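
    # No tracing step at all: inductor.compile is handed the nn.Module
    # directly (which must return a list/tuple in this case).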
    def test_inductor_via_bare_module(self):
        mod = MyModule3().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        # no FX graph at all (mod must return list/tuple in this case)
        mod_opt = inductor.compile(mod, [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)
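
    # dynamo.export with aten_graph=True yields an ATen-level GraphModule
    # plus guards; only the graph is passed on to inductor.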
    def test_inductor_via_export1(self):
        mod = MyModule3().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        gm, guards = dynamo.export(mod, inp, aten_graph=True)
        mod_opt = inductor.compile(gm, [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)
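
    # Export again without aten_graph, so inductor receives a torch-level
    # graph; this also re-exercises the dict-of-lists input.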
    def test_inductor_via_export2(self):
        mod = MyModule2().eval()
        inp = {"key": [torch.randn(10), torch.randn(10)]}
        correct = mod(inp)
        gm, guards = dynamo.export(mod, inp)
        mod_opt = inductor.compile(gm, [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)
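
    # native_layer_norm returns multiple tensors (output, mean, rstd);
    # gen_gm_and_inputs wraps the single op in a GraphModule, so inductor
    # must handle a multi-output graph.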
    def test_inductor_via_op_with_multiple_outputs(self):
        x1 = torch.randn((2, 512, 128))
        x2 = [128]
        x3 = torch.randn(128)
        x4 = torch.randn((128,))
        x5 = 1e-6
        mod, inp = gen_gm_and_inputs(
            torch.ops.aten.native_layer_norm.default, (x1, x2, x3, x4, x5), {}
        )
        mod_opt = inductor.compile(mod, inp)
        self.assertEqual(mod(*inp), mod_opt(*inp))

if __name__ == "__main__":
    if HAS_CPU:
        run_tests()