mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enabled torch.testing._internal.jit_utils.* typechecking. (#44985)
Summary:
Fixes #{issue number}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44985
Reviewed By: malfet
Differential Revision: D23794444
Pulled By: kauterry
fbshipit-source-id: 9893cc91780338a8223904fb574efa77fa3ab2b9
This commit is contained in:
parent
9f67176b82
commit
4810365576
3
mypy.ini
3
mypy.ini
|
|
@ -56,9 +56,6 @@ ignore_errors = True
|
|||
[mypy-torch.testing._internal.codegen.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.testing._internal.jit_utils.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.testing._internal.autocast_test_lists.*]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
|
|
@ -178,6 +178,10 @@ def _jit_get_trigger_value(trigger_name: str) -> _int: ...
|
|||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
ResolutionCallback = Callable[[str], Callable[..., Any]]
|
||||
|
||||
def _create_function_from_graph(qualname: str, graph: Graph) -> Graph: ...
|
||||
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
|
||||
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
|
||||
def _jit_clear_class_registry() -> None: ...
|
||||
def _jit_set_emit_hooks(ModuleHook: Optional[Callable], FunctionHook: Optional[Callable]) -> None: ...
|
||||
def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...
|
||||
def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ...
|
||||
|
|
@ -395,7 +399,9 @@ class AggregationType(Enum):
|
|||
AVG = 1
|
||||
|
||||
class FileCheck(object):
|
||||
# TODO
|
||||
# TODO (add more FileCheck signature)
|
||||
def check_source_highlighted(self, highlight: str) -> 'FileCheck': ...
|
||||
def run(self, test_string: str) -> None: ...
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/jit/python/init.cpp
|
||||
|
|
@ -416,6 +422,11 @@ class PyTorchFileWriter(object):
|
|||
def write_end_of_file(self) -> None: ...
|
||||
...
|
||||
|
||||
def _jit_get_inline_everything_mode() -> _bool: ...
|
||||
def _jit_set_inline_everything_mode(enabled: _bool) -> None: ...
|
||||
def _jit_pass_dce(Graph) -> None: ...
|
||||
def _jit_pass_lint(Graph) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/python_custome_class.cpp
|
||||
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from contextlib import contextmanager
|
|||
from functools import reduce
|
||||
from itertools import chain
|
||||
from torch._six import StringIO
|
||||
from typing import Any, Dict
|
||||
|
||||
import inspect
|
||||
import io
|
||||
|
|
@ -148,14 +149,14 @@ class JitTestCase(TestCase):
|
|||
self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
|
||||
files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
|
||||
# unwrap all the code files into strings
|
||||
code_files = filter(lambda x: x.endswith('.py'), files)
|
||||
code_files = map(lambda f: archive.open(f), code_files)
|
||||
code_files = map(lambda file: "".join([line.decode() for line in file]), code_files)
|
||||
code_files_str = filter(lambda x: x.endswith('.py'), files)
|
||||
code_files_stream = map(lambda f: archive.open(f), code_files_str)
|
||||
code_files = map(lambda file: "".join([line.decode() for line in file]), code_files_stream)
|
||||
|
||||
# unpickled all the debug files
|
||||
debug_files = filter(lambda f: f.endswith('.debug_pkl'), files)
|
||||
debug_files = map(lambda f: archive.open(f), debug_files)
|
||||
debug_files = map(lambda f: pickle.load(f), debug_files)
|
||||
debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files)
|
||||
debug_files_stream = map(lambda f: archive.open(f), debug_files_str)
|
||||
debug_files = map(lambda f: pickle.load(f), debug_files_stream)
|
||||
return code_files, debug_files
|
||||
|
||||
# disable the hook while we parse code, otherwise we will re-enter the hook
|
||||
|
|
@ -336,11 +337,15 @@ class JitTestCase(TestCase):
|
|||
|
||||
def get_frame_vars(self, frames_up):
|
||||
frame = inspect.currentframe()
|
||||
if not frame:
|
||||
raise RuntimeError("failed to inspect frame")
|
||||
i = 0
|
||||
while i < frames_up + 1:
|
||||
frame = frame.f_back
|
||||
if not frame:
|
||||
raise RuntimeError("failed to get frame")
|
||||
i += 1
|
||||
defined_vars = {}
|
||||
defined_vars: Dict[str, Any] = {}
|
||||
defined_vars.update(frame.f_locals)
|
||||
defined_vars.update(frame.f_globals)
|
||||
return defined_vars
|
||||
|
|
@ -408,7 +413,7 @@ class JitTestCase(TestCase):
|
|||
# outputs
|
||||
|
||||
frame = self.get_frame_vars(frames_up)
|
||||
the_locals = {}
|
||||
the_locals: Dict[str, Any] = {}
|
||||
execWrapper(script, glob=frame, loc=the_locals)
|
||||
frame.update(the_locals)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user