# Owner(s): ["module: fx"] # # Tests the graph pickler by using pickling on all the inductor tests. # import contextlib import importlib import os import sys from unittest.mock import patch import torch import torch.library from torch._dynamo.testing import make_test_cls_with_patches from torch._inductor.test_case import TestCase from torch.testing._internal.inductor_utils import HAS_CPU # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library check_model, CommonTemplate, copy_tests, TestFailure, ) importlib.import_module("filelock") # xfail by default, set is_skip=True to skip test_failures = { # TypeError: cannot pickle 'generator' object "test_layer_norm_graph_pickler": TestFailure(("cpu"), is_skip=True), } def make_test_cls(cls, xfail_prop="_expected_failure_graph_pickler"): return make_test_cls_with_patches( cls, "GraphPickler", "_graph_pickler", ( torch._inductor.compile_fx, "fx_compile_mode", torch._inductor.compile_fx.FxCompileMode.SERIALIZE, ), xfail_prop=xfail_prop, ) GraphPicklerCommonTemplate = make_test_cls(CommonTemplate) if HAS_CPU: class GraphPicklerCpuTests(TestCase): common = check_model device = "cpu" copy_tests(GraphPicklerCommonTemplate, GraphPicklerCpuTests, "cpu", test_failures) class TestGraphPickler(TestCase): def setUp(self): torch._dynamo.reset() TestCase.setUp(self) self._stack = contextlib.ExitStack() self._stack.enter_context( patch( "torch._inductor.compile_fx.fx_compile_mode", torch._inductor.compile_fx.FxCompileMode.SERIALIZE, ) ) def tearDown(self): self._stack.close() TestCase.tearDown(self) torch._dynamo.reset() def test_simple(self): # Make sure that compiling works when we pass the input + output from # fx_codegen_and_compile() through serde. def fn(a, b): return a + b check_model(self, fn, (torch.tensor([False, True]), torch.tensor([True, True]))) if __name__ == "__main__": raise RuntimeError( "This test is not currently used and should be " "enabled in discover_tests.py if required." )