diff --git a/mypy.ini b/mypy.ini
index f4b37f15a82..0b9f5497162 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -143,30 +143,6 @@ ignore_errors = True
 [mypy-torch.nn.intrinsic.qat.modules.conv_fused]
 ignore_errors = True
 
-[mypy-torch.onnx.operators]
-ignore_errors = True
-
-[mypy-torch.onnx.symbolic_opset8]
-ignore_errors = True
-
-[mypy-torch.onnx.symbolic_opset9]
-ignore_errors = True
-
-[mypy-torch.onnx.symbolic_opset11]
-ignore_errors = True
-
-[mypy-torch.onnx.symbolic_caffe2]
-ignore_errors = True
-
-[mypy-torch.onnx.symbolic_helper]
-ignore_errors = True
-
-[mypy-torch.onnx.symbolic_registry]
-ignore_errors = True
-
-[mypy-torch.onnx.utils]
-ignore_errors = True
-
 [mypy-torch.multiprocessing.pool]
 ignore_errors = True
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index cbb5b2452e2..a7f1f1b91c9 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -165,7 +165,10 @@ def wait(fut: Future) -> Any: ...
 def _collect_all(futures: List[Future]) -> Future: ...
 def unify_type_list(types: List[JitType]) -> JitType: ...
 
-def _freeze_module(module: ScriptModule, preserved_attrs: List[str], freeze_interfaces: _bool = True) -> ScriptModule: ...
+def _freeze_module(module: ScriptModule,
+                   preserved_attrs: List[str] = [],
+                   freeze_interfaces: _bool = True,
+                   preserveParameters: _bool = True) -> ScriptModule: ...
 def _is_tracing() -> _bool: ...
 def _jit_init() -> _bool: ...
 def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
@@ -217,6 +220,8 @@ def _jit_get_trigger_value(trigger_name: str) -> _int: ...
 # Defined in torch/csrc/jit/python/script_init.cpp
 ResolutionCallback = Callable[[str], Callable[..., Any]]
 
+# Defined in torch/csrc/jit/python/script_init.cpp
+# and torch/csrc/jit/python/init.cpp
 def _create_function_from_graph(qualname: str, graph: Graph) -> Graph: ...
 def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
 def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
@@ -246,6 +251,55 @@ def _resolve_type_from_object(obj: Any, range: SourceRange, rcb: ResolutionCallb
 def _create_module_with_type(ty: JitType) -> ScriptModule: ...
 def _run_emit_module_hook(m: ScriptModule): ...
 def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def, new_name: str) -> Def: ...
+
+def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
+def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ...
+def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, opset_version: _int) -> None: ...
+def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], onnx_shape_inference: _bool = False) -> None: ...
+def _jit_pass_fixup_onnx_loop_node_inputs(n: Node) -> None: ...
+def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph) -> None: ...
+def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
+def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
+def _jit_pass_peephole(graph: Graph, addmm_fusion_enabled: _bool) -> None: ...
+def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
+def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
+def _jit_pass_onnx_prepare_inplace_ops_for_onnx(graph: Graph) -> None: ...
+def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...
+def _jit_pass_onnx_remove_print(graph: Graph) -> None: ...
+def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ...
+def _jit_pass_onnx_unpack_quantized_weights(
+    graph: Graph,
+    paramsDict: Dict[str, IValue]
+) -> Dict[str, IValue]: ...
+def _jit_pass_onnx_quantization_insert_permutes(
+    graph: Graph,
+    paramsDict: Dict[str, IValue]
+) -> Dict[str, IValue]: ...
+def _jit_pass_custom_pattern_based_rewrite_graph(pattern: str, fused_node_name: str, graph: Graph) -> None: ...
+def _jit_onnx_list_model_parameters(module: ScriptModule) -> Tuple[ScriptModule, List[IValue]]: ...
+def _jit_pass_erase_number_types(graph: Graph) -> None: ...
+def _jit_pass_onnx(graph: Graph, _jit_pass_onnx: _onnx.OperatorExportTypes) -> Graph: ...
+def _jit_pass_onnx_scalar_type_analysis(graph: Graph) -> None: ...
+def _jit_pass_onnx_peephole(graph: Graph, opset_version: _int, fixed_batch_size: _bool) -> None: ...
+def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ...
+def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
+def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
+def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
+def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
+def _jit_pass_onnx_constant_fold(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> Dict[str, IValue]: ...
+def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
+def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
+def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ...
+def _jit_decay_packed_param_input_types(graph: Graph) -> None: ...
+def _jit_pass_onnx_node_shape_type_inference(n: Node, opset_version: _int) -> None: ...
+def _jit_pass_onnx_block(
+    old_block: Block,
+    new_block: Block,
+    operator_export_type: _onnx.OperatorExportTypes,
+    env: Dict[Value, Value]
+) -> None: ...
+def _jit_pass_fixup_onnx_controlflow_node(n: Node, opset_version: _int) -> Node: ...
+
 def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ...
 def _jit_script_compile_overload(
     qualname: str,
@@ -281,8 +335,18 @@ def import_ir_module_from_buffer(
     extra_files: Dict[str, Any]
 ) -> ScriptModule: ...
 
+def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ...
+def _check_onnx_proto(proto: str) -> None: ...
+def _propagate_and_assign_input_shapes(
+    graph: Graph,
+    inputs: Tuple[Tensor, ...],
+    with_grad: _bool,
+    propagate: _bool
+) -> Graph: ...
+
 # Defined in torch/torch/csrc/jit/ir/ir.h
 class Graph:
+    def eraseInput(self, i: _int) -> None: ...
     ...
 
 # Defined in torch/csrc/jit/ir/ir.h
@@ -366,8 +430,8 @@ class ScriptFunction:
     def qualified_name(self) -> str: ...
 
 class ScriptMethod:
+    graph: Graph
     ...
-
 class ModuleDict:
     def __init__(self, mod: ScriptModule) -> None: ...
     def items(self) -> List[Tuple[str, Any]]: ...
@@ -378,6 +442,10 @@ class ParameterDict:
 class BufferDict:
     def __init__(self, mod: ScriptModule) -> None: ...
 
+# Defined in torch/csrc/jit/api/module.h
+class Module:
+    ...
+
 # Defined in torch/csrc/Module.cpp
 def _initExtension(shm_manager_path: str) -> None: ...  # THPModule_initExtension
 def _autograd_init() -> _bool: ...  # THPAutograd_initExtension
diff --git a/torch/_C/_onnx.pyi b/torch/_C/_onnx.pyi
index 51f16566ce6..7ab3cd9c567 100644
--- a/torch/_C/_onnx.pyi
+++ b/torch/_C/_onnx.pyi
@@ -29,6 +29,7 @@ class OperatorExportTypes(Enum):
    ONNX_ATEN = ...
    ONNX_ATEN_FALLBACK = ...
    RAW = ...
+   ONNX_FALLTHROUGH = ...

 class TrainingMode(Enum):
     EVAL = ...
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 5e9430f995f..10250baf131 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -2,6 +2,7 @@ import torch
 import warnings
 from sys import maxsize as maxsize
+from typing import Set
 
 import torch.onnx
 # This import monkey-patches graph manipulation methods on Graph, used for the
@@ -125,7 +126,7 @@ def parse_args(*arg_descriptors):
         def wrapper(g, *args, **kwargs):
             # some args may be optional, so the length may be smaller
             assert len(arg_descriptors) >= len(args)
-            args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
+            args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]  # type: ignore
             # only support _outputs in kwargs
             assert len(kwargs) <= 1
             if len(kwargs) == 1:
@@ -232,18 +233,18 @@ def _select_helper(g, self, dim, index, apply_reshape=True):
 
 def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
     if _export_onnx_opset_version <= 9:
-        from torch.onnx.symbolic_opset9 import _slice
-        return _slice(g, input, axes, starts, ends)
+        from torch.onnx.symbolic_opset9 import _slice as _slice9
+        return _slice9(g, input, axes, starts, ends)
     else:
-        from torch.onnx.symbolic_opset10 import _slice
-        return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
+        from torch.onnx.symbolic_opset10 import _slice as _slice10
+        return _slice10(g, input, axes, starts, ends, steps, dynamic_slice)
 
 def _hardtanh_helper(g, input, min_val, max_val):
     if _export_onnx_opset_version <= 10:
         from torch.onnx.symbolic_opset9 import hardtanh
         return hardtanh(g, input, min_val, max_val)
     else:
-        from torch.onnx.symbolic_opset11 import hardtanh
+        from torch.onnx.symbolic_opset11 import hardtanh  # type: ignore[no-redef]
         return hardtanh(g, input, min_val, max_val)
 
 def _is_fp(value):
@@ -380,7 +381,7 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode , align_
             size = g.op("Concat", *size, axis_i=0)
         scale_factor = _interpolate_size_to_scales(g, input, size, dim)
     else:
-        return _unimplemented("Both size and scales are None in __interpolate")
+        return _unimplemented("interpolate", "Both size and scales are None in __interpolate")
     return scale_factor, mode
 
 
@@ -388,7 +389,7 @@ def _unbind_helper(g, self, dim, _outputs):
     if _export_onnx_opset_version <= 9:
         from torch.onnx.symbolic_opset9 import unbind
     else:
-        from torch.onnx.symbolic_opset11 import unbind
+        from torch.onnx.symbolic_opset11 import unbind  # type: ignore[no-redef]
     return unbind(g, self, dim, _outputs)
 
 
@@ -396,7 +397,8 @@ def _scatter_helper(g, self, dim, index, src):
     if _export_onnx_opset_version <= 10:
         from torch.onnx.symbolic_opset9 import scatter
     else:
-        from torch.onnx.symbolic_opset11 import scatter
+        # for mypy, scatter was imported two lines above
+        from torch.onnx.symbolic_opset11 import scatter  # type: ignore
     return scatter(g, self, dim, index, src)
 
 
@@ -444,7 +446,8 @@ def _index_fill_reshape_helper(g, self, dim, index):
     if _export_onnx_opset_version <= 10:
         from torch.onnx.symbolic_opset9 import scatter
     else:
-        from torch.onnx.symbolic_opset11 import scatter
+        # for mypy, scatter was imported two lines above
+        from torch.onnx.symbolic_opset11 import scatter  # type: ignore
 
     if self.type().dim() is None:
         return _unimplemented("index_fill", "input rank not accesible")
@@ -632,4 +635,4 @@ scalar_type_to_onnx = [
 
 # Global set to store the list of quantized operators in the network.
 # This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX.
-_quantized_ops = set()
+_quantized_ops: Set[int] = set()
diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py
index c0c1d48ebec..e4023dab232 100644
--- a/torch/onnx/symbolic_opset8.py
+++ b/torch/onnx/symbolic_opset8.py
@@ -4,7 +4,7 @@ import torch.onnx.symbolic_helper as sym_help
 import torch.onnx.symbolic_opset9 as sym_opset9
 
 from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type
-from torch.onnx.symbolic_opset9 import _cast_Float
+from torch.onnx.symbolic_opset9 import _cast_Float  # type: ignore
 
 import warnings
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index e395ce5c703..8630f48a62a 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -13,6 +13,8 @@ from functools import wraps
 import torch.onnx.symbolic_helper as sym_help
 from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
 
+from typing import Optional
+
 import numpy
 import math
 import warnings
@@ -311,7 +313,7 @@ def _maybe_cast_reduce_op_input(g, self):
     if dtype is not None:
         # pytorch reduce-ops cast all other integral types to int64
         if not sym_help._is_fp(self) and not (dtype == 'Long'):
-            self = _cast_Long(g, self, False)
+            self = _cast_Long(g, self, False)  # type: ignore
     return self
 
 
@@ -2092,7 +2094,7 @@ def _pack_padded_sequence(g, input, lengths, batch_first):
     # It's really only necessary because those operators expand to something that
     # only works with int32 types in Caffe2...
     if lengths.type().scalarType() != 'Int':
-        lengths = _cast_Int(g, lengths, False)
+        lengths = _cast_Int(g, lengths, False)  # type: ignore
     return g.op("prim::PackPadded", input, lengths, outputs=2)
 
 
@@ -2436,7 +2438,7 @@ def arange(g, *args):
 
 
 def masked_fill(g, self, mask, value):
-    mask = _cast_Bool(g, mask, False)
+    mask = _cast_Bool(g, mask, False)  # type: ignore
     value = sym_help._maybe_get_scalar(value)
     return g.op('Where', mask, sym_help._if_scalar_type_as(g, value, self), self)
 
@@ -2734,6 +2736,7 @@ def as_strided(g, self, sizes, strides, offset=None):
     sizes = sym_help._maybe_get_const(sizes, 'is')
     rank = len(strides)
     self_1d = g.op("Reshape", self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
+    ind: Optional[torch.Tensor]
     if not sym_help._is_value(sizes):
         ind = torch.tensor([0], dtype=torch.long)
         for i, (size, stride) in enumerate(zip(sizes, strides)):
diff --git a/torch/onnx/symbolic_registry.py b/torch/onnx/symbolic_registry.py
index 48114d6c472..c059e8f2eb3 100644
--- a/torch/onnx/symbolic_registry.py
+++ b/torch/onnx/symbolic_registry.py
@@ -1,6 +1,7 @@
 import warnings
 import importlib
 from inspect import getmembers, isfunction
+from typing import Dict, Tuple, Any, Union
 
 # The symbolic registry "_registry" is a dictionary that maps operators
 # (for a specific domain and opset version) to their symbolic functions.
@@ -8,9 +9,9 @@ from inspect import getmembers, isfunction
 # The keys are tuples (domain, version), (where domain is a string, and version is an int),
 # and the operator's name (string).
 # The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic
-_registry = {}
+_registry: Dict[Tuple[str, int], Dict] = {}
 
-_symbolic_versions = {}
+_symbolic_versions: Dict[Union[int, str], Any] = {}
 from torch.onnx.symbolic_helper import _onnx_stable_opsets
 for opset_version in _onnx_stable_opsets:
     module = importlib.import_module('torch.onnx.symbolic_opset{}'.format(opset_version))
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 5c41306b9ee..3fe19a56c12 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -18,6 +18,7 @@ from torch._six import string_classes
 from torch.jit import _unique_state_dict
 from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode
 from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _check_onnx_proto
+from typing import Union, Tuple, List
 
 
 # the flag to tell the user whether it's in the middle of ONNX export or not
@@ -76,7 +77,7 @@ def export(model, args, f, export_params=True, verbose=False, training=None,
     if aten or export_raw_ir:
         assert operator_export_type is None
         assert aten ^ export_raw_ir
-        operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW
+        operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
     elif operator_export_type is None:
         if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
             operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
@@ -351,6 +352,7 @@ def _trace_and_get_graph_from_model(model, args):
 
 def _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes):
     torch_out = None
+    params: Union[List, Tuple]
     if isinstance(model, torch.jit.ScriptModule):
         try:
             graph = model.forward.graph
@@ -442,7 +444,7 @@ def _model_to_graph(model, args, verbose=False,
         param_names = input_and_param_names[len(input_and_param_names) - len(params):]
         params_dict = dict(zip(param_names, params))
 
-    if training is None or training == TrainingMode.EVAL or (training == TrainingMode.PRESERVE and not is_originally_training):
+    if training is None or training == TrainingMode.EVAL:
         params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
 
     if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
@@ -476,7 +478,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
     if aten or export_raw_ir:
         assert operator_export_type is None
         assert aten ^ export_raw_ir
-        operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW
+        operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
     elif operator_export_type is None:
         operator_export_type = OperatorExportTypes.ONNX
     return _export_to_pretty_string(model, args, f, export_params, verbose, training,
@@ -1051,6 +1053,10 @@ def _graph_constant(g, value, dims, type, *args, **kwargs):
         dims = [1]
         isscalar = True
     type = type.lower()
+    tensor: Union[torch.CharTensor, torch.ShortTensor,
+                  torch.IntTensor, torch.LongTensor,
+                  torch.HalfTensor, torch.FloatTensor,
+                  torch.DoubleTensor]
     if type == "char":
         tensor = torch.CharTensor(*dims)
     elif type == "short":
@@ -1068,7 +1074,7 @@ def _graph_constant(g, value, dims, type, *args, **kwargs):
     else:
         raise ValueError("Unknown type, type should be one of the following strings: "
                          "char, short, int, long, half, float, double")
-    tensor.fill_(value)
+    tensor.fill_(value)  # type: ignore
     if isscalar:
         return g.op("Constant", *args, value_z=tensor, **kwargs)
     return g.op("Constant",
                 *args, value_t=tensor, **kwargs)
@@ -1141,8 +1147,8 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
         dynamic_axes[key] = value_dict
 
 
-torch._C.Graph.op = _graph_op
-torch._C.Graph.at = _graph_at
-torch._C.Block.op = _block_op
-torch._C.Graph.constant = _graph_constant
-torch._C.Node.__getitem__ = _node_getitem
+torch._C.Graph.op = _graph_op  # type: ignore
+torch._C.Graph.at = _graph_at  # type: ignore
+torch._C.Block.op = _block_op  # type: ignore
+torch._C.Graph.constant = _graph_constant  # type: ignore
+torch._C.Node.__getitem__ = _node_getitem  # type: ignore
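
Usage illustration (a minimal sketch, not part of the patch): the `ONNX_FALLTHROUGH` member newly declared in the `torch/_C/_onnx.pyi` stub is the export mode in which operators without an ONNX equivalent are emitted as-is into the exported graph instead of raising an error. The toy model and output path below are placeholders.

```python
import torch
import torch.nn as nn

# Placeholder model; any traceable nn.Module works here.
model = nn.Linear(4, 2)
dummy_input = torch.randn(1, 4)

# Ops lacking an ONNX mapping are passed through into the exported
# graph rather than failing the export.
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",  # placeholder output path
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
)
```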