enable mypy check for jit_metaprogramming_utils (#44752)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/42969
enable mypy check for jit_metaprogramming_utils.py and fixed all errors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44752

Reviewed By: walterddr

Differential Revision: D23741285

Pulled By: qxu-fb

fbshipit-source-id: 21e36ca5d25c8682fb93b806e416b9e1db76f71e
This commit is contained in:
qxu 2020-09-16 15:33:53 -07:00 committed by Facebook GitHub Bot
parent 3f5bb2bade
commit 09a84071a3
2 changed files with 12 additions and 10 deletions

View File

@ -83,9 +83,6 @@ ignore_errors = True
[mypy-torch.testing._internal.common_methods_invocations.*]
ignore_errors = True
[mypy-torch.testing._internal.jit_metaprogramming_utils.*]
ignore_errors = True
[mypy-torch.testing._internal.common_nn.*]
ignore_errors = True

View File

@ -1,3 +1,5 @@
from typing import List
# Torch
from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401
from torch.testing._internal.common_methods_invocations import non_differentiable, create_input, \
@ -242,7 +244,7 @@ def get_call(method_name, func_type, args, kwargs):
elif func_type == 'nn_functional':
call = 'torch.nn.functional.{}({})'.format(method_name, argument_str)
else:
raise 'Unsupported function type'
raise TypeError('Unsupported function type')
return call
@ -254,9 +256,9 @@ def get_constant(x):
return x
def get_script_args(args):
formals = []
tensors = []
actuals = []
formals: List[str] = []
tensors: List[torch.Tensor] = []
actuals: List[str] = []
for arg in args:
if isinstance(arg, torch.Tensor):
name = 'i{}'.format(len(formals))
@ -286,7 +288,8 @@ def create_script_fn(self, method_name, func_type, output_process_fn):
fn, tensors = gen_script_fn_and_args(method_name, func_type, *args, **kwargs)
self.assertExportImport(fn.graph, tensors)
output = output_process_fn(fn(*tensors))
script_fn.last_graph = fn.graph_for(*tensors)
# skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
script_fn.last_graph = fn.graph_for(*tensors) # type: ignore[attr-defined]
return output
return script_fn
@ -312,7 +315,8 @@ def create_traced_fn(self, fn):
traced = torch.jit.trace(fn_tensors, inputs_tensors, check_trace=False)
self.assertExportImport(traced.graph, inputs_tensors)
output = traced(*inputs_tensors)
traced_fn.last_graph = traced.graph_for(*inputs_tensors)
# skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
traced_fn.last_graph = traced.graph_for(*inputs_tensors) # type: ignore[attr-defined]
return output
return traced_fn
@ -450,7 +454,8 @@ def create_script_module(self, nn_module, constructor_args, *args, **kwargs):
if self:
self.assertExportImportModule(module, tensors)
module(*args)
create_script_module.last_graph = module.graph
# skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
create_script_module.last_graph = module.graph # type: ignore[attr-defined]
return module
return script_module