mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
enable mypy check for jit_metaprogramming_utils (#44752)
Summary: Fixes https://github.com/pytorch/pytorch/issues/42969. Enabled the mypy check for jit_metaprogramming_utils.py and fixed all resulting errors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/44752 Reviewed By: walterddr Differential Revision: D23741285 Pulled By: qxu-fb fbshipit-source-id: 21e36ca5d25c8682fb93b806e416b9e1db76f71e
This commit is contained in:
parent
3f5bb2bade
commit
09a84071a3
3
mypy.ini
3
mypy.ini
|
|
@@ -83,9 +83,6 @@ ignore_errors = True
|
|||
[mypy-torch.testing._internal.common_methods_invocations.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.testing._internal.jit_metaprogramming_utils.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.testing._internal.common_nn.*]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,3 +1,5 @@
|
|||
from typing import List
|
||||
|
||||
# Torch
|
||||
from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401
|
||||
from torch.testing._internal.common_methods_invocations import non_differentiable, create_input, \
|
||||
|
|
@@ -242,7 +244,7 @@ def get_call(method_name, func_type, args, kwargs):
|
|||
elif func_type == 'nn_functional':
|
||||
call = 'torch.nn.functional.{}({})'.format(method_name, argument_str)
|
||||
else:
|
||||
raise 'Unsupported function type'
|
||||
raise TypeError('Unsupported function type')
|
||||
|
||||
return call
|
||||
|
||||
|
|
@@ -254,9 +256,9 @@ def get_constant(x):
|
|||
return x
|
||||
|
||||
def get_script_args(args):
|
||||
formals = []
|
||||
tensors = []
|
||||
actuals = []
|
||||
formals: List[str] = []
|
||||
tensors: List[torch.Tensor] = []
|
||||
actuals: List[str] = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
name = 'i{}'.format(len(formals))
|
||||
|
|
@@ -286,7 +288,8 @@ def create_script_fn(self, method_name, func_type, output_process_fn):
|
|||
fn, tensors = gen_script_fn_and_args(method_name, func_type, *args, **kwargs)
|
||||
self.assertExportImport(fn.graph, tensors)
|
||||
output = output_process_fn(fn(*tensors))
|
||||
script_fn.last_graph = fn.graph_for(*tensors)
|
||||
# skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
|
||||
script_fn.last_graph = fn.graph_for(*tensors) # type: ignore[attr-defined]
|
||||
return output
|
||||
return script_fn
|
||||
|
||||
|
|
@@ -312,7 +315,8 @@ def create_traced_fn(self, fn):
|
|||
traced = torch.jit.trace(fn_tensors, inputs_tensors, check_trace=False)
|
||||
self.assertExportImport(traced.graph, inputs_tensors)
|
||||
output = traced(*inputs_tensors)
|
||||
traced_fn.last_graph = traced.graph_for(*inputs_tensors)
|
||||
# skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
|
||||
traced_fn.last_graph = traced.graph_for(*inputs_tensors) # type: ignore[attr-defined]
|
||||
return output
|
||||
return traced_fn
|
||||
|
||||
|
|
@@ -450,7 +454,8 @@ def create_script_module(self, nn_module, constructor_args, *args, **kwargs):
|
|||
if self:
|
||||
self.assertExportImportModule(module, tensors)
|
||||
module(*args)
|
||||
create_script_module.last_graph = module.graph
|
||||
# skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
|
||||
create_script_module.last_graph = module.graph # type: ignore[attr-defined]
|
||||
return module
|
||||
return script_module
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user