pytorch/test/dynamo/test_backends.py
Bert Maher d54fcd571a [dynamo] Skip tests that are broken in fbcode (#101217)
Some tests don't work in fbcode, for some reason.  Skip these until we
can figure them out.

Differential Revision: [D45791340](https://our.internmc.facebook.com/intern/diff/D45791340/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101217
Approved by: https://github.com/davidberard98
2023-05-12 14:13:14 +00:00

221 lines
7.0 KiB
Python

# Owner(s): ["module: dynamo"]
import functools
import unittest
import torch
import torch._dynamo
import torch._dynamo.backends.ipex
import torch._dynamo.test_case
from torch._dynamo.backends.ipex import has_ipex
from torch._dynamo.backends.onnxrt import has_onnxruntime
from torch._dynamo.backends.tvm import has_tvm
from torch._dynamo.testing import same
from torch.testing._internal.common_utils import IS_FBCODE
from torch.testing._internal.inductor_utils import HAS_CUDA
# Decorator factory used as ``@requires_cuda()``: wraps ``unittest.skipIf`` so
# the decorated test is skipped with reason "requires cuda" when HAS_CUDA is
# false.
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
class Seq(torch.nn.Module):
    """Small model: Linear(10, 10) -> ReLU -> Linear(10, 10) -> Sigmoid."""

    def __init__(self):
        super().__init__()
        # Build the stages in order first, then hand them to Sequential.
        stages = [
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.Sigmoid(),
        ]
        self.layers = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        out = self.layers(x)
        return out
class Conv_Bn_Relu(torch.nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d (eps=0.001) and ReLU.

    Extra keyword arguments are forwarded to ``torch.nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels,
            out_channels,
            bias=False,
            **kwargs,
        )
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply convolution, batch norm and ReLU to ``x``, in that order."""
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return self.relu(normalized)
class TestOptimizations(torch._dynamo.test_case.TestCase):
    """Tests for the example inputs passed to backends and for named backends."""

    def test_example_inputs(self):
        """The captured graph can be run on ``example_inputs`` at compile time."""

        def fn(a, bc, d):
            b, c = bc
            return a / d - b / c

        def compiler_fn(graph, example_inputs):
            nonlocal r1
            # Execute the graph inside the backend and keep its first output.
            r1 = graph(*example_inputs)[0]
            return graph.forward

        a = torch.empty(2).fill_(1)
        b = torch.empty(2).fill_(2)
        c = torch.empty(2).fill_(3)
        d = 4
        r1 = None
        r2 = fn(a, (b, c), d)
        opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn)
        r3 = opt_fn(a, (b, c), d)

        self.assertIsNotNone(r1)
        # Only metadata is compared: r1 against the eager result (r2) and the
        # compiled result (r3).
        for other in (r2, r3):
            self.assertEqual(r1.size(), other.size())
            self.assertEqual(r1.stride(), other.stride())
            self.assertEqual(r1.dtype, other.dtype)

    def test_example_inputs_runtime_use(self):
        """A backend wrapper around ``graph.forward`` sees the runtime outputs."""

        def fn(a, bc, d):
            b, c = bc
            return a / d - b / c

        def compiler_fn(graph, example_inputs):
            def fwd(*args):
                nonlocal r1
                r = graph.forward(*args)
                # Record the first output produced by the actual call.
                r1 = r[0]
                return r

            return fwd

        a = torch.empty(2).fill_(1)
        b = torch.empty(2).fill_(2)
        c = torch.empty(2).fill_(3)
        d = 4
        r1 = None
        r2 = fn(a, (b, c), d)
        opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn)
        r3 = opt_fn(a, (b, c), d)

        self.assertIsNotNone(r1)
        self.assertTrue(same(r1, r2))
        self.assertTrue(same(r1, r3))

    @staticmethod
    def _ipex_model_and_input():
        """Return the eval-mode channels_last model and input shared by the ipex tests."""
        model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1)
        model = model.to(memory_format=torch.channels_last)
        model = model.eval()
        example_input = torch.randn(8, 3, 64, 64).contiguous(
            memory_format=torch.channels_last
        )
        return model, example_input

    @unittest.skipIf(not has_ipex(), "requires ipex")
    def test_ipex_fp32(self):
        """The ``ipex`` backend matches eager output and keeps float32."""
        model, example_input = self._ipex_model_and_input()
        r1 = model(example_input)
        for dynamic_shapes in [True, False]:
            # subTest labels a failure with the mode that produced it and lets
            # the other mode still run.
            with self.subTest(dynamic_shapes=dynamic_shapes):
                torch._dynamo.reset()
                opt_model = torch._dynamo.optimize("ipex", dynamic=dynamic_shapes)(
                    model
                )
                with torch.no_grad():
                    for _ in range(3):
                        r2 = opt_model(example_input)
                self.assertTrue(same(r1, r2))
                self.assertEqual(r2.dtype, torch.float32)

    @unittest.skipIf(not has_ipex(), "requires ipex")
    def test_ipex_bf16(self):
        """Under CPU autocast the ``ipex`` backend returns bfloat16 within tol=0.1."""
        model, example_input = self._ipex_model_and_input()
        r1 = model(example_input)
        for dynamic_shapes in [True, False]:
            # subTest labels a failure with the mode that produced it and lets
            # the other mode still run.
            with self.subTest(dynamic_shapes=dynamic_shapes):
                torch._dynamo.reset()
                opt_model = torch._dynamo.optimize("ipex", dynamic=dynamic_shapes)(
                    model
                )
                with torch.no_grad(), torch.cpu.amp.autocast():
                    for _ in range(3):
                        r2 = opt_model(example_input)
                self.assertTrue(same(r1, r2.float(), tol=0.1))
                self.assertEqual(r2.dtype, torch.bfloat16)

    def _check_backend_works(self, backend):
        """Compile ``Seq`` with ``backend`` and compare against the uncompiled output.

        Args:
            backend: value passed as ``backend=`` to ``torch.compile``.
        """
        model = Seq().eval()
        example_input = torch.randn(2, 10)
        r1 = model(example_input)
        r2 = torch.compile(model, backend=backend)(example_input)
        self.assertTrue(same(r1, r2.float(), tol=0.01))

    def test_eager(self):
        self._check_backend_works("eager")

    def test_torchscript(self):
        self._check_backend_works("ts")

    def test_aot_eager(self):
        self._check_backend_works("aot_eager")

    def test_aot_eager_decomp_partition(self):
        self._check_backend_works("aot_eager_decomp_partition")

    def test_aot_ts(self):
        self._check_backend_works("aot_ts")

    @requires_cuda()
    def test_aot_cudagraphs(self):
        self._check_backend_works("cudagraphs")

    @requires_cuda()
    def test_aot_ts_nvfuser(self):
        self._check_backend_works("aot_ts_nvfuser")

    @requires_cuda()
    @unittest.skipIf(IS_FBCODE, "BackendCompilerError")
    def test_nvprims_nvfuser(self):
        self._check_backend_works("nvprims_nvfuser")

    @requires_cuda()
    @unittest.skipIf(IS_FBCODE, "BackendCompilerError")
    def test_nvprims_aten(self):
        self._check_backend_works("nvprims_aten")

    @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime")
    def test_onnxrt(self):
        self._check_backend_works("onnxrt")

    @unittest.skipIf(not has_tvm(), "requires tvm")
    def test_tvm(self):
        self._check_backend_works("tvm")

    def test_list_backends(self):
        """``list_backends`` lists "inductor" and omits "eager" unless ``exclude_tags=[]``."""
        self.assertIn("inductor", torch._dynamo.list_backends())
        self.assertIn("inductor", torch._dynamo.list_backends(exclude_tags=None))
        self.assertNotIn("eager", torch._dynamo.list_backends())
        self.assertNotIn("eager", torch._dynamo.list_backends(exclude_tags=["debug"]))
        self.assertIn("eager", torch._dynamo.list_backends(exclude_tags=[]))
class NormalizeIRTests(torch._dynamo.test_case.TestCase):
def test_inplace_normalize(self):
def fn(a, b):
x = torch.cos(a)
x += b
return torch.sin(x)
a = torch.randn(10)
b = torch.randn(10).to(torch.float64)
ref = fn(a, b)
optimized_fn = torch._dynamo.optimize("aot_eager")(fn)
res = optimized_fn(a, b)
self.assertTrue(same(ref, res))
class MPSNotSupportedTest(torch._dynamo.test_case.TestCase):
    """Expects compiling an MPS model with the inductor backend to raise."""

    @unittest.skipIf(not torch.backends.mps.is_available(), "requires mps")
    def test_mps_not_supported(self):
        """Compiling and calling ``Seq`` on "mps" with inductor raises RuntimeError."""
        mps_model = Seq().to("mps")
        mps_input = torch.randn(1, 10).to("mps")
        with self.assertRaises(RuntimeError):
            torch.compile(mps_model, backend="inductor")(mps_input)
if __name__ == "__main__":
    # Script entry point: hand control to the Dynamo test runner.
    from torch._dynamo import test_case as dynamo_test_case

    dynamo_test_case.run_tests()