Enabled torch.testing._internal.jit_utils.* typechecking. (#44985)

Summary:
Fixes #{issue number}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44985

Reviewed By: malfet

Differential Revision: D23794444

Pulled By: kauterry

fbshipit-source-id: 9893cc91780338a8223904fb574efa77fa3ab2b9
This commit is contained in:
Kaushik Ram Sadagopan 2020-09-21 01:14:57 -07:00 committed by Facebook GitHub Bot
parent 9f67176b82
commit 4810365576
3 changed files with 25 additions and 12 deletions

View File

@ -56,9 +56,6 @@ ignore_errors = True
[mypy-torch.testing._internal.codegen.*]
ignore_errors = True
[mypy-torch.testing._internal.jit_utils.*]
ignore_errors = True
[mypy-torch.testing._internal.autocast_test_lists.*]
ignore_errors = True

View File

@ -178,6 +178,10 @@ def _jit_get_trigger_value(trigger_name: str) -> _int: ...
# Defined in torch/csrc/jit/python/script_init.cpp
ResolutionCallback = Callable[[str], Callable[..., Any]]
def _create_function_from_graph(qualname: str, graph: Graph) -> Graph: ...
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
def _jit_clear_class_registry() -> None: ...
def _jit_set_emit_hooks(ModuleHook: Optional[Callable], FunctionHook: Optional[Callable]) -> None: ...
def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...
def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ...
@ -395,7 +399,9 @@ class AggregationType(Enum):
AVG = 1
class FileCheck(object):
# TODO
# TODO (add more FileCheck signature)
def check_source_highlighted(self, highlight: str) -> 'FileCheck': ...
def run(self, test_string: str) -> None: ...
...
# Defined in torch/csrc/jit/python/init.cpp
@ -416,6 +422,11 @@ class PyTorchFileWriter(object):
def write_end_of_file(self) -> None: ...
...
def _jit_get_inline_everything_mode() -> _bool: ...
def _jit_set_inline_everything_mode(enabled: _bool) -> None: ...
def _jit_pass_dce(Graph) -> None: ...
def _jit_pass_lint(Graph) -> None: ...
# Defined in torch/csrc/jit/python/python_custome_class.cpp
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...

View File

@ -24,6 +24,7 @@ from contextlib import contextmanager
from functools import reduce
from itertools import chain
from torch._six import StringIO
from typing import Any, Dict
import inspect
import io
@ -148,14 +149,14 @@ class JitTestCase(TestCase):
self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
# unwrap all the code files into strings
code_files = filter(lambda x: x.endswith('.py'), files)
code_files = map(lambda f: archive.open(f), code_files)
code_files = map(lambda file: "".join([line.decode() for line in file]), code_files)
code_files_str = filter(lambda x: x.endswith('.py'), files)
code_files_stream = map(lambda f: archive.open(f), code_files_str)
code_files = map(lambda file: "".join([line.decode() for line in file]), code_files_stream)
# unpickled all the debug files
debug_files = filter(lambda f: f.endswith('.debug_pkl'), files)
debug_files = map(lambda f: archive.open(f), debug_files)
debug_files = map(lambda f: pickle.load(f), debug_files)
debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files)
debug_files_stream = map(lambda f: archive.open(f), debug_files_str)
debug_files = map(lambda f: pickle.load(f), debug_files_stream)
return code_files, debug_files
# disable the hook while we parse code, otherwise we will re-enter the hook
@ -336,11 +337,15 @@ class JitTestCase(TestCase):
def get_frame_vars(self, frames_up):
frame = inspect.currentframe()
if not frame:
raise RuntimeError("failed to inspect frame")
i = 0
while i < frames_up + 1:
frame = frame.f_back
if not frame:
raise RuntimeError("failed to get frame")
i += 1
defined_vars = {}
defined_vars: Dict[str, Any] = {}
defined_vars.update(frame.f_locals)
defined_vars.update(frame.f_globals)
return defined_vars
@ -408,7 +413,7 @@ class JitTestCase(TestCase):
# outputs
frame = self.get_frame_vars(frames_up)
the_locals = {}
the_locals: Dict[str, Any] = {}
execWrapper(script, glob=frame, loc=the_locals)
frame.update(the_locals)