diff --git a/mypy.ini b/mypy.ini index deab2aeaee3..a7d4acea957 100644 --- a/mypy.ini +++ b/mypy.ini @@ -56,9 +56,6 @@ ignore_errors = True [mypy-torch.testing._internal.codegen.*] ignore_errors = True -[mypy-torch.testing._internal.jit_utils.*] -ignore_errors = True - [mypy-torch.testing._internal.autocast_test_lists.*] ignore_errors = True diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 29995760e00..0d48ea710fd 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -178,6 +178,10 @@ def _jit_get_trigger_value(trigger_name: str) -> _int: ... # Defined in torch/csrc/jit/python/script_init.cpp ResolutionCallback = Callable[[str], Callable[..., Any]] +def _create_function_from_graph(qualname: str, graph: Graph) -> Graph: ... +def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ... +def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ... +def _jit_clear_class_registry() -> None: ... def _jit_set_emit_hooks(ModuleHook: Optional[Callable], FunctionHook: Optional[Callable]) -> None: ... def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ... def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ... @@ -395,7 +399,9 @@ class AggregationType(Enum): AVG = 1 class FileCheck(object): - # TODO + # TODO (add more FileCheck signature) + def check_source_highlighted(self, highlight: str) -> 'FileCheck': ... + def run(self, test_string: str) -> None: ... ... # Defined in torch/csrc/jit/python/init.cpp @@ -416,6 +422,11 @@ class PyTorchFileWriter(object): def write_end_of_file(self) -> None: ... ... +def _jit_get_inline_everything_mode() -> _bool: ... +def _jit_set_inline_everything_mode(enabled: _bool) -> None: ... +def _jit_pass_dce(Graph) -> None: ... +def _jit_pass_lint(Graph) -> None: ... + # Defined in torch/csrc/jit/python/python_custome_class.cpp def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ... diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index 542c182b036..732260573ec 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -24,6 +24,7 @@ from contextlib import contextmanager from functools import reduce from itertools import chain from torch._six import StringIO +from typing import Any, Dict import inspect import io @@ -148,14 +149,14 @@ class JitTestCase(TestCase): self.assertEqual(len(set(archive.namelist())), len(archive.namelist())) files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist())) # unwrap all the code files into strings - code_files = filter(lambda x: x.endswith('.py'), files) - code_files = map(lambda f: archive.open(f), code_files) - code_files = map(lambda file: "".join([line.decode() for line in file]), code_files) + code_files_str = filter(lambda x: x.endswith('.py'), files) + code_files_stream = map(lambda f: archive.open(f), code_files_str) + code_files = map(lambda file: "".join([line.decode() for line in file]), code_files_stream) # unpickled all the debug files - debug_files = filter(lambda f: f.endswith('.debug_pkl'), files) - debug_files = map(lambda f: archive.open(f), debug_files) - debug_files = map(lambda f: pickle.load(f), debug_files) + debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files) + debug_files_stream = map(lambda f: archive.open(f), debug_files_str) + debug_files = map(lambda f: pickle.load(f), debug_files_stream) return code_files, debug_files # disable the hook while we parse code, otherwise we will re-enter the hook @@ -336,11 +337,15 @@ class JitTestCase(TestCase): def get_frame_vars(self, frames_up): frame = inspect.currentframe() + if not frame: + raise RuntimeError("failed to inspect frame") i = 0 while i < frames_up + 1: frame = frame.f_back + if not frame: + raise RuntimeError("failed to get frame") i += 1 - defined_vars = {} + defined_vars: Dict[str, Any] = {} defined_vars.update(frame.f_locals) defined_vars.update(frame.f_globals) return defined_vars @@ -408,7 +413,7 @@ class JitTestCase(TestCase): # outputs frame = self.get_frame_vars(frames_up) - the_locals = {} + the_locals: Dict[str, Any] = {} execWrapper(script, glob=frame, loc=the_locals) frame.update(the_locals)