# Owner(s): ["module: fx"]

#
# Tests the graph pickler by running all the inductor tests with the FX graph
# pickled and unpickled (FxCompileMode.SERIALIZE) during compilation.
#

import contextlib
import importlib
import os
import sys
from unittest.mock import patch

import torch
import torch.library
from torch._dynamo.testing import make_test_cls_with_patches
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from inductor.test_torchinductor import (  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
    check_model,
    CommonTemplate,
    copy_tests,
)


# Fail fast at import time if the optional "filelock" package is missing;
# run_tests below also declares it via needs="filelock".
importlib.import_module("filelock")

# Per-test overrides for the copied tests, passed to copy_tests().
# Entries xfail by default; set is_skip=True to skip instead.
test_failures = {}


def make_test_cls(cls, xfail_prop="_expected_failure_graph_pickler"):
    """Wrap test class *cls* so its tests run with graph pickling enabled.

    Delegates to ``make_test_cls_with_patches`` with the patch
    ``torch._inductor.compile_fx.fx_compile_mode = FxCompileMode.SERIALIZE``.
    The "GraphPickler" / "_graph_pickler" arguments are presumably the
    generated class-name prefix and test-name suffix (per that helper's
    signature; confirm there).

    Args:
        cls: The template test class to derive from.
        xfail_prop: Name of the attribute used to mark individual tests as
            expected failures for this variant (forwarded unchanged).

    Returns:
        Whatever ``make_test_cls_with_patches`` returns (a patched test class).
    """
    return make_test_cls_with_patches(
        cls,
        "GraphPickler",
        "_graph_pickler",
        (
            torch._inductor.compile_fx,
            "fx_compile_mode",
            torch._inductor.compile_fx.FxCompileMode.SERIALIZE,
        ),
        xfail_prop=xfail_prop,
    )


GraphPicklerCommonTemplate = make_test_cls(CommonTemplate)


if HAS_CPU:

    class GraphPicklerCpuTests(TestCase):
        """CPU instantiation of the graph-pickler variant of CommonTemplate."""

        common = check_model
        device = "cpu"

    copy_tests(GraphPicklerCommonTemplate, GraphPicklerCpuTests, "cpu", test_failures)


class TestGraphPickler(TestCase):
    """Direct graph-pickler tests run under FxCompileMode.SERIALIZE."""

    def setUp(self):
        """Reset dynamo and force SERIALIZE compile mode for this test."""
        torch._dynamo.reset()
        super().setUp()

        # Holds the fx_compile_mode patch; unwound in tearDown.
        self._stack = contextlib.ExitStack()
        self._stack.enter_context(
            patch(
                "torch._inductor.compile_fx.fx_compile_mode",
                torch._inductor.compile_fx.FxCompileMode.SERIALIZE,
            )
        )

    def tearDown(self):
        """Undo setUp in reverse order, running every step even if one raises.

        Previously a failure while closing the patch stack skipped the base
        tearDown and the dynamo reset, leaking compiled state into later tests.
        """
        with contextlib.ExitStack() as cleanup:
            # Callbacks run LIFO: close patches, then base tearDown, then reset.
            cleanup.callback(torch._dynamo.reset)
            cleanup.callback(super().tearDown)
            cleanup.callback(self._stack.close)

    def test_simple(self):
        # Make sure that compiling works when we pass the input + output from
        # fx_codegen_and_compile() through serde.

        def fn(a, b):
            return a + b

        check_model(self, fn, (torch.tensor([False, True]), torch.tensor([True, True])))


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068
    if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN:
        run_tests(needs="filelock")