mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Torch onnx (#48980)
Summary: Fixes https://github.com/pytorch/pytorch/issues/45215 This is a follow up PR of https://github.com/pytorch/pytorch/issues/45258 and https://github.com/pytorch/pytorch/issues/48782 Pull Request resolved: https://github.com/pytorch/pytorch/pull/48980 Reviewed By: zhangguanheng66 Differential Revision: D25399823 Pulled By: ezyang fbshipit-source-id: 798055f4abbbffecdfab0325884193c81addecec
This commit is contained in:
parent
5450614cf6
commit
34cc77a811
24
mypy.ini
24
mypy.ini
|
|
@ -143,30 +143,6 @@ ignore_errors = True
|
||||||
[mypy-torch.nn.intrinsic.qat.modules.conv_fused]
|
[mypy-torch.nn.intrinsic.qat.modules.conv_fused]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
[mypy-torch.onnx.operators]
|
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-torch.onnx.symbolic_opset8]
|
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-torch.onnx.symbolic_opset9]
|
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-torch.onnx.symbolic_opset11]
|
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-torch.onnx.symbolic_caffe2]
|
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-torch.onnx.symbolic_helper]
|
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-torch.onnx.symbolic_registry]
|
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-torch.onnx.utils]
|
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-torch.multiprocessing.pool]
|
[mypy-torch.multiprocessing.pool]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -165,7 +165,10 @@ def wait(fut: Future) -> Any: ...
|
||||||
def _collect_all(futures: List[Future]) -> Future: ...
|
def _collect_all(futures: List[Future]) -> Future: ...
|
||||||
|
|
||||||
def unify_type_list(types: List[JitType]) -> JitType: ...
|
def unify_type_list(types: List[JitType]) -> JitType: ...
|
||||||
def _freeze_module(module: ScriptModule, preserved_attrs: List[str], freeze_interfaces: _bool = True) -> ScriptModule: ...
|
def _freeze_module(module: ScriptModule,
|
||||||
|
preserved_attrs: List[str] = [],
|
||||||
|
freeze_interfaces: _bool = True,
|
||||||
|
preserveParameters: _bool = True) -> ScriptModule: ...
|
||||||
def _is_tracing() -> _bool: ...
|
def _is_tracing() -> _bool: ...
|
||||||
def _jit_init() -> _bool: ...
|
def _jit_init() -> _bool: ...
|
||||||
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
|
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
|
||||||
|
|
@ -217,6 +220,8 @@ def _jit_get_trigger_value(trigger_name: str) -> _int: ...
|
||||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||||
ResolutionCallback = Callable[[str], Callable[..., Any]]
|
ResolutionCallback = Callable[[str], Callable[..., Any]]
|
||||||
|
|
||||||
|
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||||
|
# and torch/csrc/jit/python/init.cpp
|
||||||
def _create_function_from_graph(qualname: str, graph: Graph) -> Graph: ...
|
def _create_function_from_graph(qualname: str, graph: Graph) -> Graph: ...
|
||||||
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
|
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
|
||||||
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
|
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
|
||||||
|
|
@ -246,6 +251,55 @@ def _resolve_type_from_object(obj: Any, range: SourceRange, rcb: ResolutionCallb
|
||||||
def _create_module_with_type(ty: JitType) -> ScriptModule: ...
|
def _create_module_with_type(ty: JitType) -> ScriptModule: ...
|
||||||
def _run_emit_module_hook(m: ScriptModule): ...
|
def _run_emit_module_hook(m: ScriptModule): ...
|
||||||
def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def, new_name: str) -> Def: ...
|
def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def, new_name: str) -> Def: ...
|
||||||
|
|
||||||
|
def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ...
|
||||||
|
def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, opset_version: _int) -> None: ...
|
||||||
|
def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], onnx_shape_inference: _bool = False) -> None: ...
|
||||||
|
def _jit_pass_fixup_onnx_loop_node_inputs(n: Node) -> None: ...
|
||||||
|
def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_peephole(graph: Graph, addmm_fusion_enabled: _bool) -> None: ...
|
||||||
|
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_onnx_prepare_inplace_ops_for_onnx(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_onnx_remove_print(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_onnx_unpack_quantized_weights(
|
||||||
|
graph: Graph,
|
||||||
|
paramsDict: Dict[str, IValue]
|
||||||
|
) -> Dict[str, IValue]: ...
|
||||||
|
def _jit_pass_onnx_quantization_insert_permutes(
|
||||||
|
graph: Graph,
|
||||||
|
paramsDict: Dict[str, IValue]
|
||||||
|
) -> Dict[str, IValue]: ...
|
||||||
|
def _jit_pass_custom_pattern_based_rewrite_graph(pattern: str, fused_node_name: str, graph: Graph) -> None: ...
|
||||||
|
def _jit_onnx_list_model_parameters(module: ScriptModule) -> Tuple[ScriptModule, List[IValue]]: ...
|
||||||
|
def _jit_pass_erase_number_types(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_onnx(graph: Graph, _jit_pass_onnx: _onnx.OperatorExportTypes) -> Graph: ...
|
||||||
|
def _jit_pass_onnx_scalar_type_analysis(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_onnx_peephole(graph: Graph, opset_version: _int, fixed_batch_size: _bool) -> None: ...
|
||||||
|
def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
|
||||||
|
def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
|
||||||
|
def _jit_pass_onnx_constant_fold(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> Dict[str, IValue]: ...
|
||||||
|
def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
|
||||||
|
def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ...
|
||||||
|
def _jit_decay_packed_param_input_types(graph: Graph) -> None: ...
|
||||||
|
def _jit_pass_onnx_node_shape_type_inference(n: Node, opset_version: _int) -> None: ...
|
||||||
|
def _jit_pass_onnx_block(
|
||||||
|
old_block: Block,
|
||||||
|
new_block: Block,
|
||||||
|
operator_export_type: _onnx.OperatorExportTypes,
|
||||||
|
env: Dict[Value, Value]
|
||||||
|
) -> None: ...
|
||||||
|
def _jit_pass_fixup_onnx_controlflow_node(n: Node, opset_version: _int) -> Node: ...
|
||||||
|
|
||||||
def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ...
|
def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ...
|
||||||
def _jit_script_compile_overload(
|
def _jit_script_compile_overload(
|
||||||
qualname: str,
|
qualname: str,
|
||||||
|
|
@ -281,8 +335,18 @@ def import_ir_module_from_buffer(
|
||||||
extra_files: Dict[str, Any]
|
extra_files: Dict[str, Any]
|
||||||
) -> ScriptModule: ...
|
) -> ScriptModule: ...
|
||||||
|
|
||||||
|
def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ...
|
||||||
|
def _check_onnx_proto(proto: str) -> None: ...
|
||||||
|
def _propagate_and_assign_input_shapes(
|
||||||
|
graph: Graph,
|
||||||
|
inputs: Tuple[Tensor, ...],
|
||||||
|
with_grad: _bool,
|
||||||
|
propagate: _bool
|
||||||
|
) -> Graph: ...
|
||||||
|
|
||||||
# Defined in torch/torch/csrc/jit/ir/ir.h
|
# Defined in torch/torch/csrc/jit/ir/ir.h
|
||||||
class Graph:
|
class Graph:
|
||||||
|
def eraseInput(self, i: _int) -> None: ...
|
||||||
...
|
...
|
||||||
|
|
||||||
# Defined in torch/csrc/jit/ir/ir.h
|
# Defined in torch/csrc/jit/ir/ir.h
|
||||||
|
|
@ -366,8 +430,8 @@ class ScriptFunction:
|
||||||
def qualified_name(self) -> str: ...
|
def qualified_name(self) -> str: ...
|
||||||
|
|
||||||
class ScriptMethod:
|
class ScriptMethod:
|
||||||
|
graph: Graph
|
||||||
...
|
...
|
||||||
|
|
||||||
class ModuleDict:
|
class ModuleDict:
|
||||||
def __init__(self, mod: ScriptModule) -> None: ...
|
def __init__(self, mod: ScriptModule) -> None: ...
|
||||||
def items(self) -> List[Tuple[str, Any]]: ...
|
def items(self) -> List[Tuple[str, Any]]: ...
|
||||||
|
|
@ -378,6 +442,10 @@ class ParameterDict:
|
||||||
class BufferDict:
|
class BufferDict:
|
||||||
def __init__(self, mod: ScriptModule) -> None: ...
|
def __init__(self, mod: ScriptModule) -> None: ...
|
||||||
|
|
||||||
|
# Defined in torch/csrc/jit/api/module.h
|
||||||
|
class Module:
|
||||||
|
...
|
||||||
|
|
||||||
# Defined in torch/csrc/Module.cpp
|
# Defined in torch/csrc/Module.cpp
|
||||||
def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension
|
def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension
|
||||||
def _autograd_init() -> _bool: ... # THPAutograd_initExtension
|
def _autograd_init() -> _bool: ... # THPAutograd_initExtension
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ class OperatorExportTypes(Enum):
|
||||||
ONNX_ATEN = ...
|
ONNX_ATEN = ...
|
||||||
ONNX_ATEN_FALLBACK = ...
|
ONNX_ATEN_FALLBACK = ...
|
||||||
RAW = ...
|
RAW = ...
|
||||||
|
ONNX_FALLTHROUGH = ...
|
||||||
|
|
||||||
class TrainingMode(Enum):
|
class TrainingMode(Enum):
|
||||||
EVAL = ...
|
EVAL = ...
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
import torch
|
import torch
|
||||||
import warnings
|
import warnings
|
||||||
from sys import maxsize as maxsize
|
from sys import maxsize as maxsize
|
||||||
|
from typing import Set
|
||||||
|
|
||||||
import torch.onnx
|
import torch.onnx
|
||||||
# This import monkey-patches graph manipulation methods on Graph, used for the
|
# This import monkey-patches graph manipulation methods on Graph, used for the
|
||||||
|
|
@ -125,7 +126,7 @@ def parse_args(*arg_descriptors):
|
||||||
def wrapper(g, *args, **kwargs):
|
def wrapper(g, *args, **kwargs):
|
||||||
# some args may be optional, so the length may be smaller
|
# some args may be optional, so the length may be smaller
|
||||||
assert len(arg_descriptors) >= len(args)
|
assert len(arg_descriptors) >= len(args)
|
||||||
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
|
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)] # type: ignore
|
||||||
# only support _outputs in kwargs
|
# only support _outputs in kwargs
|
||||||
assert len(kwargs) <= 1
|
assert len(kwargs) <= 1
|
||||||
if len(kwargs) == 1:
|
if len(kwargs) == 1:
|
||||||
|
|
@ -232,18 +233,18 @@ def _select_helper(g, self, dim, index, apply_reshape=True):
|
||||||
|
|
||||||
def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
|
def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
|
||||||
if _export_onnx_opset_version <= 9:
|
if _export_onnx_opset_version <= 9:
|
||||||
from torch.onnx.symbolic_opset9 import _slice
|
from torch.onnx.symbolic_opset9 import _slice as _slice9
|
||||||
return _slice(g, input, axes, starts, ends)
|
return _slice9(g, input, axes, starts, ends)
|
||||||
else:
|
else:
|
||||||
from torch.onnx.symbolic_opset10 import _slice
|
from torch.onnx.symbolic_opset10 import _slice as _slice10
|
||||||
return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
|
return _slice10(g, input, axes, starts, ends, steps, dynamic_slice)
|
||||||
|
|
||||||
def _hardtanh_helper(g, input, min_val, max_val):
|
def _hardtanh_helper(g, input, min_val, max_val):
|
||||||
if _export_onnx_opset_version <= 10:
|
if _export_onnx_opset_version <= 10:
|
||||||
from torch.onnx.symbolic_opset9 import hardtanh
|
from torch.onnx.symbolic_opset9 import hardtanh
|
||||||
return hardtanh(g, input, min_val, max_val)
|
return hardtanh(g, input, min_val, max_val)
|
||||||
else:
|
else:
|
||||||
from torch.onnx.symbolic_opset11 import hardtanh
|
from torch.onnx.symbolic_opset11 import hardtanh # type: ignore[no-redef]
|
||||||
return hardtanh(g, input, min_val, max_val)
|
return hardtanh(g, input, min_val, max_val)
|
||||||
|
|
||||||
def _is_fp(value):
|
def _is_fp(value):
|
||||||
|
|
@ -380,7 +381,7 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode , align_
|
||||||
size = g.op("Concat", *size, axis_i=0)
|
size = g.op("Concat", *size, axis_i=0)
|
||||||
scale_factor = _interpolate_size_to_scales(g, input, size, dim)
|
scale_factor = _interpolate_size_to_scales(g, input, size, dim)
|
||||||
else:
|
else:
|
||||||
return _unimplemented("Both size and scales are None in __interpolate")
|
return _unimplemented("interpolate", "Both size and scales are None in __interpolate")
|
||||||
return scale_factor, mode
|
return scale_factor, mode
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -388,7 +389,7 @@ def _unbind_helper(g, self, dim, _outputs):
|
||||||
if _export_onnx_opset_version <= 9:
|
if _export_onnx_opset_version <= 9:
|
||||||
from torch.onnx.symbolic_opset9 import unbind
|
from torch.onnx.symbolic_opset9 import unbind
|
||||||
else:
|
else:
|
||||||
from torch.onnx.symbolic_opset11 import unbind
|
from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef]
|
||||||
return unbind(g, self, dim, _outputs)
|
return unbind(g, self, dim, _outputs)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -396,7 +397,8 @@ def _scatter_helper(g, self, dim, index, src):
|
||||||
if _export_onnx_opset_version <= 10:
|
if _export_onnx_opset_version <= 10:
|
||||||
from torch.onnx.symbolic_opset9 import scatter
|
from torch.onnx.symbolic_opset9 import scatter
|
||||||
else:
|
else:
|
||||||
from torch.onnx.symbolic_opset11 import scatter
|
# for mypy, scatter was imported two lines above
|
||||||
|
from torch.onnx.symbolic_opset11 import scatter # type: ignore
|
||||||
return scatter(g, self, dim, index, src)
|
return scatter(g, self, dim, index, src)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -444,7 +446,8 @@ def _index_fill_reshape_helper(g, self, dim, index):
|
||||||
if _export_onnx_opset_version <= 10:
|
if _export_onnx_opset_version <= 10:
|
||||||
from torch.onnx.symbolic_opset9 import scatter
|
from torch.onnx.symbolic_opset9 import scatter
|
||||||
else:
|
else:
|
||||||
from torch.onnx.symbolic_opset11 import scatter
|
# for mypy, scatter was imported two lines above
|
||||||
|
from torch.onnx.symbolic_opset11 import scatter # type: ignore
|
||||||
|
|
||||||
if self.type().dim() is None:
|
if self.type().dim() is None:
|
||||||
return _unimplemented("index_fill", "input rank not accesible")
|
return _unimplemented("index_fill", "input rank not accesible")
|
||||||
|
|
@ -632,4 +635,4 @@ scalar_type_to_onnx = [
|
||||||
|
|
||||||
# Global set to store the list of quantized operators in the network.
|
# Global set to store the list of quantized operators in the network.
|
||||||
# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX.
|
# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX.
|
||||||
_quantized_ops = set()
|
_quantized_ops: Set[int] = set()
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import torch.onnx.symbolic_helper as sym_help
|
||||||
import torch.onnx.symbolic_opset9 as sym_opset9
|
import torch.onnx.symbolic_opset9 as sym_opset9
|
||||||
|
|
||||||
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type
|
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type
|
||||||
from torch.onnx.symbolic_opset9 import _cast_Float
|
from torch.onnx.symbolic_opset9 import _cast_Float # type: ignore
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,8 @@ from functools import wraps
|
||||||
import torch.onnx.symbolic_helper as sym_help
|
import torch.onnx.symbolic_helper as sym_help
|
||||||
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
|
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
|
|
@ -311,7 +313,7 @@ def _maybe_cast_reduce_op_input(g, self):
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
# pytorch reduce-ops cast all other integral types to int64
|
# pytorch reduce-ops cast all other integral types to int64
|
||||||
if not sym_help._is_fp(self) and not (dtype == 'Long'):
|
if not sym_help._is_fp(self) and not (dtype == 'Long'):
|
||||||
self = _cast_Long(g, self, False)
|
self = _cast_Long(g, self, False) # type: ignore
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2092,7 +2094,7 @@ def _pack_padded_sequence(g, input, lengths, batch_first):
|
||||||
# It's really only necessary because those operators expand to something that
|
# It's really only necessary because those operators expand to something that
|
||||||
# only works with int32 types in Caffe2...
|
# only works with int32 types in Caffe2...
|
||||||
if lengths.type().scalarType() != 'Int':
|
if lengths.type().scalarType() != 'Int':
|
||||||
lengths = _cast_Int(g, lengths, False)
|
lengths = _cast_Int(g, lengths, False) # type: ignore
|
||||||
return g.op("prim::PackPadded", input, lengths, outputs=2)
|
return g.op("prim::PackPadded", input, lengths, outputs=2)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2436,7 +2438,7 @@ def arange(g, *args):
|
||||||
|
|
||||||
|
|
||||||
def masked_fill(g, self, mask, value):
|
def masked_fill(g, self, mask, value):
|
||||||
mask = _cast_Bool(g, mask, False)
|
mask = _cast_Bool(g, mask, False) # type: ignore
|
||||||
value = sym_help._maybe_get_scalar(value)
|
value = sym_help._maybe_get_scalar(value)
|
||||||
return g.op('Where', mask, sym_help._if_scalar_type_as(g, value, self), self)
|
return g.op('Where', mask, sym_help._if_scalar_type_as(g, value, self), self)
|
||||||
|
|
||||||
|
|
@ -2734,6 +2736,7 @@ def as_strided(g, self, sizes, strides, offset=None):
|
||||||
sizes = sym_help._maybe_get_const(sizes, 'is')
|
sizes = sym_help._maybe_get_const(sizes, 'is')
|
||||||
rank = len(strides)
|
rank = len(strides)
|
||||||
self_1d = g.op("Reshape", self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
|
self_1d = g.op("Reshape", self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
|
||||||
|
ind: Optional[torch.Tensor]
|
||||||
if not sym_help._is_value(sizes):
|
if not sym_help._is_value(sizes):
|
||||||
ind = torch.tensor([0], dtype=torch.long)
|
ind = torch.tensor([0], dtype=torch.long)
|
||||||
for i, (size, stride) in enumerate(zip(sizes, strides)):
|
for i, (size, stride) in enumerate(zip(sizes, strides)):
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import warnings
|
import warnings
|
||||||
import importlib
|
import importlib
|
||||||
from inspect import getmembers, isfunction
|
from inspect import getmembers, isfunction
|
||||||
|
from typing import Dict, Tuple, Any, Union
|
||||||
|
|
||||||
# The symbolic registry "_registry" is a dictionary that maps operators
|
# The symbolic registry "_registry" is a dictionary that maps operators
|
||||||
# (for a specific domain and opset version) to their symbolic functions.
|
# (for a specific domain and opset version) to their symbolic functions.
|
||||||
|
|
@ -8,9 +9,9 @@ from inspect import getmembers, isfunction
|
||||||
# The keys are tuples (domain, version), (where domain is a string, and version is an int),
|
# The keys are tuples (domain, version), (where domain is a string, and version is an int),
|
||||||
# and the operator's name (string).
|
# and the operator's name (string).
|
||||||
# The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic
|
# The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic
|
||||||
_registry = {}
|
_registry: Dict[Tuple[str, int], Dict] = {}
|
||||||
|
|
||||||
_symbolic_versions = {}
|
_symbolic_versions: Dict[Union[int, str], Any] = {}
|
||||||
from torch.onnx.symbolic_helper import _onnx_stable_opsets
|
from torch.onnx.symbolic_helper import _onnx_stable_opsets
|
||||||
for opset_version in _onnx_stable_opsets:
|
for opset_version in _onnx_stable_opsets:
|
||||||
module = importlib.import_module('torch.onnx.symbolic_opset{}'.format(opset_version))
|
module = importlib.import_module('torch.onnx.symbolic_opset{}'.format(opset_version))
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from torch._six import string_classes
|
||||||
from torch.jit import _unique_state_dict
|
from torch.jit import _unique_state_dict
|
||||||
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode
|
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode
|
||||||
from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _check_onnx_proto
|
from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _check_onnx_proto
|
||||||
|
from typing import Union, Tuple, List
|
||||||
|
|
||||||
|
|
||||||
# the flag to tell the user whether it's in the middle of ONNX export or not
|
# the flag to tell the user whether it's in the middle of ONNX export or not
|
||||||
|
|
@ -76,7 +77,7 @@ def export(model, args, f, export_params=True, verbose=False, training=None,
|
||||||
if aten or export_raw_ir:
|
if aten or export_raw_ir:
|
||||||
assert operator_export_type is None
|
assert operator_export_type is None
|
||||||
assert aten ^ export_raw_ir
|
assert aten ^ export_raw_ir
|
||||||
operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW
|
operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
|
||||||
elif operator_export_type is None:
|
elif operator_export_type is None:
|
||||||
if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
|
if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
|
||||||
operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
|
operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||||
|
|
@ -351,6 +352,7 @@ def _trace_and_get_graph_from_model(model, args):
|
||||||
|
|
||||||
def _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes):
|
def _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes):
|
||||||
torch_out = None
|
torch_out = None
|
||||||
|
params: Union[List, Tuple]
|
||||||
if isinstance(model, torch.jit.ScriptModule):
|
if isinstance(model, torch.jit.ScriptModule):
|
||||||
try:
|
try:
|
||||||
graph = model.forward.graph
|
graph = model.forward.graph
|
||||||
|
|
@ -442,7 +444,7 @@ def _model_to_graph(model, args, verbose=False,
|
||||||
param_names = input_and_param_names[len(input_and_param_names) - len(params):]
|
param_names = input_and_param_names[len(input_and_param_names) - len(params):]
|
||||||
params_dict = dict(zip(param_names, params))
|
params_dict = dict(zip(param_names, params))
|
||||||
|
|
||||||
if training is None or training == TrainingMode.EVAL or (training == TrainingMode.PRESERVE and not is_originally_training):
|
if training is None or training == TrainingMode.EVAL:
|
||||||
params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
|
params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
|
||||||
|
|
||||||
if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
|
if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
|
||||||
|
|
@ -476,7 +478,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
|
||||||
if aten or export_raw_ir:
|
if aten or export_raw_ir:
|
||||||
assert operator_export_type is None
|
assert operator_export_type is None
|
||||||
assert aten ^ export_raw_ir
|
assert aten ^ export_raw_ir
|
||||||
operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW
|
operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
|
||||||
elif operator_export_type is None:
|
elif operator_export_type is None:
|
||||||
operator_export_type = OperatorExportTypes.ONNX
|
operator_export_type = OperatorExportTypes.ONNX
|
||||||
return _export_to_pretty_string(model, args, f, export_params, verbose, training,
|
return _export_to_pretty_string(model, args, f, export_params, verbose, training,
|
||||||
|
|
@ -1051,6 +1053,10 @@ def _graph_constant(g, value, dims, type, *args, **kwargs):
|
||||||
dims = [1]
|
dims = [1]
|
||||||
isscalar = True
|
isscalar = True
|
||||||
type = type.lower()
|
type = type.lower()
|
||||||
|
tensor: Union[torch.CharTensor, torch.ShortTensor,
|
||||||
|
torch.IntTensor, torch.LongTensor,
|
||||||
|
torch.HalfTensor, torch.FloatTensor,
|
||||||
|
torch.DoubleTensor]
|
||||||
if type == "char":
|
if type == "char":
|
||||||
tensor = torch.CharTensor(*dims)
|
tensor = torch.CharTensor(*dims)
|
||||||
elif type == "short":
|
elif type == "short":
|
||||||
|
|
@ -1068,7 +1074,7 @@ def _graph_constant(g, value, dims, type, *args, **kwargs):
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown type, type should be one of the following strings: "
|
raise ValueError("Unknown type, type should be one of the following strings: "
|
||||||
"char, short, int, long, half, float, double")
|
"char, short, int, long, half, float, double")
|
||||||
tensor.fill_(value)
|
tensor.fill_(value) # type: ignore
|
||||||
if isscalar:
|
if isscalar:
|
||||||
return g.op("Constant", *args, value_z=tensor, **kwargs)
|
return g.op("Constant", *args, value_z=tensor, **kwargs)
|
||||||
return g.op("Constant", *args, value_t=tensor, **kwargs)
|
return g.op("Constant", *args, value_t=tensor, **kwargs)
|
||||||
|
|
@ -1141,8 +1147,8 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
|
||||||
dynamic_axes[key] = value_dict
|
dynamic_axes[key] = value_dict
|
||||||
|
|
||||||
|
|
||||||
torch._C.Graph.op = _graph_op
|
torch._C.Graph.op = _graph_op # type: ignore
|
||||||
torch._C.Graph.at = _graph_at
|
torch._C.Graph.at = _graph_at # type: ignore
|
||||||
torch._C.Block.op = _block_op
|
torch._C.Block.op = _block_op # type: ignore
|
||||||
torch._C.Graph.constant = _graph_constant
|
torch._C.Graph.constant = _graph_constant # type: ignore
|
||||||
torch._C.Node.__getitem__ = _node_getitem
|
torch._C.Node.__getitem__ = _node_getitem # type: ignore
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user