mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Torch onnx (#48980)
Summary: Fixes https://github.com/pytorch/pytorch/issues/45215 This is a follow up PR of https://github.com/pytorch/pytorch/issues/45258 and https://github.com/pytorch/pytorch/issues/48782 Pull Request resolved: https://github.com/pytorch/pytorch/pull/48980 Reviewed By: zhangguanheng66 Differential Revision: D25399823 Pulled By: ezyang fbshipit-source-id: 798055f4abbbffecdfab0325884193c81addecec
This commit is contained in:
parent
5450614cf6
commit
34cc77a811
24
mypy.ini
24
mypy.ini
|
|
@ -143,30 +143,6 @@ ignore_errors = True
|
|||
[mypy-torch.nn.intrinsic.qat.modules.conv_fused]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.onnx.operators]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.onnx.symbolic_opset8]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.onnx.symbolic_opset9]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.onnx.symbolic_opset11]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.onnx.symbolic_caffe2]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.onnx.symbolic_helper]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.onnx.symbolic_registry]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.onnx.utils]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.multiprocessing.pool]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
|
|
@ -165,7 +165,10 @@ def wait(fut: Future) -> Any: ...
|
|||
def _collect_all(futures: List[Future]) -> Future: ...
|
||||
|
||||
def unify_type_list(types: List[JitType]) -> JitType: ...
|
||||
def _freeze_module(module: ScriptModule, preserved_attrs: List[str], freeze_interfaces: _bool = True) -> ScriptModule: ...
|
||||
def _freeze_module(module: ScriptModule,
|
||||
preserved_attrs: List[str] = [],
|
||||
freeze_interfaces: _bool = True,
|
||||
preserveParameters: _bool = True) -> ScriptModule: ...
|
||||
def _is_tracing() -> _bool: ...
|
||||
def _jit_init() -> _bool: ...
|
||||
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
|
||||
|
|
@ -217,6 +220,8 @@ def _jit_get_trigger_value(trigger_name: str) -> _int: ...
|
|||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
ResolutionCallback = Callable[[str], Callable[..., Any]]
|
||||
|
||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
# and torch/csrc/jit/python/init.cpp
|
||||
def _create_function_from_graph(qualname: str, graph: Graph) -> Graph: ...
|
||||
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
|
||||
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
|
||||
|
|
@ -246,6 +251,55 @@ def _resolve_type_from_object(obj: Any, range: SourceRange, rcb: ResolutionCallb
|
|||
def _create_module_with_type(ty: JitType) -> ScriptModule: ...
|
||||
def _run_emit_module_hook(m: ScriptModule): ...
|
||||
def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def, new_name: str) -> Def: ...
|
||||
|
||||
def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
|
||||
def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ...
|
||||
def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, opset_version: _int) -> None: ...
|
||||
def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], onnx_shape_inference: _bool = False) -> None: ...
|
||||
def _jit_pass_fixup_onnx_loop_node_inputs(n: Node) -> None: ...
|
||||
def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph) -> None: ...
|
||||
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
|
||||
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
|
||||
def _jit_pass_peephole(graph: Graph, addmm_fusion_enabled: _bool) -> None: ...
|
||||
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
|
||||
def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
|
||||
def _jit_pass_onnx_prepare_inplace_ops_for_onnx(graph: Graph) -> None: ...
|
||||
def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...
|
||||
def _jit_pass_onnx_remove_print(graph: Graph) -> None: ...
|
||||
def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ...
|
||||
def _jit_pass_onnx_unpack_quantized_weights(
|
||||
graph: Graph,
|
||||
paramsDict: Dict[str, IValue]
|
||||
) -> Dict[str, IValue]: ...
|
||||
def _jit_pass_onnx_quantization_insert_permutes(
|
||||
graph: Graph,
|
||||
paramsDict: Dict[str, IValue]
|
||||
) -> Dict[str, IValue]: ...
|
||||
def _jit_pass_custom_pattern_based_rewrite_graph(pattern: str, fused_node_name: str, graph: Graph) -> None: ...
|
||||
def _jit_onnx_list_model_parameters(module: ScriptModule) -> Tuple[ScriptModule, List[IValue]]: ...
|
||||
def _jit_pass_erase_number_types(graph: Graph) -> None: ...
|
||||
def _jit_pass_onnx(graph: Graph, _jit_pass_onnx: _onnx.OperatorExportTypes) -> Graph: ...
|
||||
def _jit_pass_onnx_scalar_type_analysis(graph: Graph) -> None: ...
|
||||
def _jit_pass_onnx_peephole(graph: Graph, opset_version: _int, fixed_batch_size: _bool) -> None: ...
|
||||
def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ...
|
||||
def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
|
||||
def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
|
||||
def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
|
||||
def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
|
||||
def _jit_pass_onnx_constant_fold(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> Dict[str, IValue]: ...
|
||||
def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
|
||||
def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
|
||||
def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ...
|
||||
def _jit_decay_packed_param_input_types(graph: Graph) -> None: ...
|
||||
def _jit_pass_onnx_node_shape_type_inference(n: Node, opset_version: _int) -> None: ...
|
||||
def _jit_pass_onnx_block(
|
||||
old_block: Block,
|
||||
new_block: Block,
|
||||
operator_export_type: _onnx.OperatorExportTypes,
|
||||
env: Dict[Value, Value]
|
||||
) -> None: ...
|
||||
def _jit_pass_fixup_onnx_controlflow_node(n: Node, opset_version: _int) -> Node: ...
|
||||
|
||||
def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ...
|
||||
def _jit_script_compile_overload(
|
||||
qualname: str,
|
||||
|
|
@ -281,8 +335,18 @@ def import_ir_module_from_buffer(
|
|||
extra_files: Dict[str, Any]
|
||||
) -> ScriptModule: ...
|
||||
|
||||
def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ...
|
||||
def _check_onnx_proto(proto: str) -> None: ...
|
||||
def _propagate_and_assign_input_shapes(
|
||||
graph: Graph,
|
||||
inputs: Tuple[Tensor, ...],
|
||||
with_grad: _bool,
|
||||
propagate: _bool
|
||||
) -> Graph: ...
|
||||
|
||||
# Defined in torch/torch/csrc/jit/ir/ir.h
|
||||
class Graph:
|
||||
def eraseInput(self, i: _int) -> None: ...
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/jit/ir/ir.h
|
||||
|
|
@ -366,8 +430,8 @@ class ScriptFunction:
|
|||
def qualified_name(self) -> str: ...
|
||||
|
||||
class ScriptMethod:
|
||||
graph: Graph
|
||||
...
|
||||
|
||||
class ModuleDict:
|
||||
def __init__(self, mod: ScriptModule) -> None: ...
|
||||
def items(self) -> List[Tuple[str, Any]]: ...
|
||||
|
|
@ -378,6 +442,10 @@ class ParameterDict:
|
|||
class BufferDict:
|
||||
def __init__(self, mod: ScriptModule) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/api/module.h
|
||||
class Module:
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/Module.cpp
|
||||
def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension
|
||||
def _autograd_init() -> _bool: ... # THPAutograd_initExtension
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ class OperatorExportTypes(Enum):
|
|||
ONNX_ATEN = ...
|
||||
ONNX_ATEN_FALLBACK = ...
|
||||
RAW = ...
|
||||
ONNX_FALLTHROUGH = ...
|
||||
|
||||
class TrainingMode(Enum):
|
||||
EVAL = ...
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
import torch
|
||||
import warnings
|
||||
from sys import maxsize as maxsize
|
||||
from typing import Set
|
||||
|
||||
import torch.onnx
|
||||
# This import monkey-patches graph manipulation methods on Graph, used for the
|
||||
|
|
@ -125,7 +126,7 @@ def parse_args(*arg_descriptors):
|
|||
def wrapper(g, *args, **kwargs):
|
||||
# some args may be optional, so the length may be smaller
|
||||
assert len(arg_descriptors) >= len(args)
|
||||
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
|
||||
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)] # type: ignore
|
||||
# only support _outputs in kwargs
|
||||
assert len(kwargs) <= 1
|
||||
if len(kwargs) == 1:
|
||||
|
|
@ -232,18 +233,18 @@ def _select_helper(g, self, dim, index, apply_reshape=True):
|
|||
|
||||
def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
|
||||
if _export_onnx_opset_version <= 9:
|
||||
from torch.onnx.symbolic_opset9 import _slice
|
||||
return _slice(g, input, axes, starts, ends)
|
||||
from torch.onnx.symbolic_opset9 import _slice as _slice9
|
||||
return _slice9(g, input, axes, starts, ends)
|
||||
else:
|
||||
from torch.onnx.symbolic_opset10 import _slice
|
||||
return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
|
||||
from torch.onnx.symbolic_opset10 import _slice as _slice10
|
||||
return _slice10(g, input, axes, starts, ends, steps, dynamic_slice)
|
||||
|
||||
def _hardtanh_helper(g, input, min_val, max_val):
|
||||
if _export_onnx_opset_version <= 10:
|
||||
from torch.onnx.symbolic_opset9 import hardtanh
|
||||
return hardtanh(g, input, min_val, max_val)
|
||||
else:
|
||||
from torch.onnx.symbolic_opset11 import hardtanh
|
||||
from torch.onnx.symbolic_opset11 import hardtanh # type: ignore[no-redef]
|
||||
return hardtanh(g, input, min_val, max_val)
|
||||
|
||||
def _is_fp(value):
|
||||
|
|
@ -380,7 +381,7 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode , align_
|
|||
size = g.op("Concat", *size, axis_i=0)
|
||||
scale_factor = _interpolate_size_to_scales(g, input, size, dim)
|
||||
else:
|
||||
return _unimplemented("Both size and scales are None in __interpolate")
|
||||
return _unimplemented("interpolate", "Both size and scales are None in __interpolate")
|
||||
return scale_factor, mode
|
||||
|
||||
|
||||
|
|
@ -388,7 +389,7 @@ def _unbind_helper(g, self, dim, _outputs):
|
|||
if _export_onnx_opset_version <= 9:
|
||||
from torch.onnx.symbolic_opset9 import unbind
|
||||
else:
|
||||
from torch.onnx.symbolic_opset11 import unbind
|
||||
from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef]
|
||||
return unbind(g, self, dim, _outputs)
|
||||
|
||||
|
||||
|
|
@ -396,7 +397,8 @@ def _scatter_helper(g, self, dim, index, src):
|
|||
if _export_onnx_opset_version <= 10:
|
||||
from torch.onnx.symbolic_opset9 import scatter
|
||||
else:
|
||||
from torch.onnx.symbolic_opset11 import scatter
|
||||
# for mypy, scatter was imported two lines above
|
||||
from torch.onnx.symbolic_opset11 import scatter # type: ignore
|
||||
return scatter(g, self, dim, index, src)
|
||||
|
||||
|
||||
|
|
@ -444,7 +446,8 @@ def _index_fill_reshape_helper(g, self, dim, index):
|
|||
if _export_onnx_opset_version <= 10:
|
||||
from torch.onnx.symbolic_opset9 import scatter
|
||||
else:
|
||||
from torch.onnx.symbolic_opset11 import scatter
|
||||
# for mypy, scatter was imported two lines above
|
||||
from torch.onnx.symbolic_opset11 import scatter # type: ignore
|
||||
|
||||
if self.type().dim() is None:
|
||||
return _unimplemented("index_fill", "input rank not accesible")
|
||||
|
|
@ -632,4 +635,4 @@ scalar_type_to_onnx = [
|
|||
|
||||
# Global set to store the list of quantized operators in the network.
|
||||
# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX.
|
||||
_quantized_ops = set()
|
||||
_quantized_ops: Set[int] = set()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import torch.onnx.symbolic_helper as sym_help
|
|||
import torch.onnx.symbolic_opset9 as sym_opset9
|
||||
|
||||
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type
|
||||
from torch.onnx.symbolic_opset9 import _cast_Float
|
||||
from torch.onnx.symbolic_opset9 import _cast_Float # type: ignore
|
||||
|
||||
import warnings
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from functools import wraps
|
|||
import torch.onnx.symbolic_helper as sym_help
|
||||
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy
|
||||
import math
|
||||
import warnings
|
||||
|
|
@ -311,7 +313,7 @@ def _maybe_cast_reduce_op_input(g, self):
|
|||
if dtype is not None:
|
||||
# pytorch reduce-ops cast all other integral types to int64
|
||||
if not sym_help._is_fp(self) and not (dtype == 'Long'):
|
||||
self = _cast_Long(g, self, False)
|
||||
self = _cast_Long(g, self, False) # type: ignore
|
||||
return self
|
||||
|
||||
|
||||
|
|
@ -2092,7 +2094,7 @@ def _pack_padded_sequence(g, input, lengths, batch_first):
|
|||
# It's really only necessary because those operators expand to something that
|
||||
# only works with int32 types in Caffe2...
|
||||
if lengths.type().scalarType() != 'Int':
|
||||
lengths = _cast_Int(g, lengths, False)
|
||||
lengths = _cast_Int(g, lengths, False) # type: ignore
|
||||
return g.op("prim::PackPadded", input, lengths, outputs=2)
|
||||
|
||||
|
||||
|
|
@ -2436,7 +2438,7 @@ def arange(g, *args):
|
|||
|
||||
|
||||
def masked_fill(g, self, mask, value):
|
||||
mask = _cast_Bool(g, mask, False)
|
||||
mask = _cast_Bool(g, mask, False) # type: ignore
|
||||
value = sym_help._maybe_get_scalar(value)
|
||||
return g.op('Where', mask, sym_help._if_scalar_type_as(g, value, self), self)
|
||||
|
||||
|
|
@ -2734,6 +2736,7 @@ def as_strided(g, self, sizes, strides, offset=None):
|
|||
sizes = sym_help._maybe_get_const(sizes, 'is')
|
||||
rank = len(strides)
|
||||
self_1d = g.op("Reshape", self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
|
||||
ind: Optional[torch.Tensor]
|
||||
if not sym_help._is_value(sizes):
|
||||
ind = torch.tensor([0], dtype=torch.long)
|
||||
for i, (size, stride) in enumerate(zip(sizes, strides)):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import warnings
|
||||
import importlib
|
||||
from inspect import getmembers, isfunction
|
||||
from typing import Dict, Tuple, Any, Union
|
||||
|
||||
# The symbolic registry "_registry" is a dictionary that maps operators
|
||||
# (for a specific domain and opset version) to their symbolic functions.
|
||||
|
|
@ -8,9 +9,9 @@ from inspect import getmembers, isfunction
|
|||
# The keys are tuples (domain, version), (where domain is a string, and version is an int),
|
||||
# and the operator's name (string).
|
||||
# The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic
|
||||
_registry = {}
|
||||
_registry: Dict[Tuple[str, int], Dict] = {}
|
||||
|
||||
_symbolic_versions = {}
|
||||
_symbolic_versions: Dict[Union[int, str], Any] = {}
|
||||
from torch.onnx.symbolic_helper import _onnx_stable_opsets
|
||||
for opset_version in _onnx_stable_opsets:
|
||||
module = importlib.import_module('torch.onnx.symbolic_opset{}'.format(opset_version))
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from torch._six import string_classes
|
|||
from torch.jit import _unique_state_dict
|
||||
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode
|
||||
from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _check_onnx_proto
|
||||
from typing import Union, Tuple, List
|
||||
|
||||
|
||||
# the flag to tell the user whether it's in the middle of ONNX export or not
|
||||
|
|
@ -76,7 +77,7 @@ def export(model, args, f, export_params=True, verbose=False, training=None,
|
|||
if aten or export_raw_ir:
|
||||
assert operator_export_type is None
|
||||
assert aten ^ export_raw_ir
|
||||
operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW
|
||||
operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
|
||||
elif operator_export_type is None:
|
||||
if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
|
||||
operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
|
||||
|
|
@ -351,6 +352,7 @@ def _trace_and_get_graph_from_model(model, args):
|
|||
|
||||
def _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes):
|
||||
torch_out = None
|
||||
params: Union[List, Tuple]
|
||||
if isinstance(model, torch.jit.ScriptModule):
|
||||
try:
|
||||
graph = model.forward.graph
|
||||
|
|
@ -442,7 +444,7 @@ def _model_to_graph(model, args, verbose=False,
|
|||
param_names = input_and_param_names[len(input_and_param_names) - len(params):]
|
||||
params_dict = dict(zip(param_names, params))
|
||||
|
||||
if training is None or training == TrainingMode.EVAL or (training == TrainingMode.PRESERVE and not is_originally_training):
|
||||
if training is None or training == TrainingMode.EVAL:
|
||||
params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
|
||||
|
||||
if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
|
||||
|
|
@ -476,7 +478,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
|
|||
if aten or export_raw_ir:
|
||||
assert operator_export_type is None
|
||||
assert aten ^ export_raw_ir
|
||||
operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW
|
||||
operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
|
||||
elif operator_export_type is None:
|
||||
operator_export_type = OperatorExportTypes.ONNX
|
||||
return _export_to_pretty_string(model, args, f, export_params, verbose, training,
|
||||
|
|
@ -1051,6 +1053,10 @@ def _graph_constant(g, value, dims, type, *args, **kwargs):
|
|||
dims = [1]
|
||||
isscalar = True
|
||||
type = type.lower()
|
||||
tensor: Union[torch.CharTensor, torch.ShortTensor,
|
||||
torch.IntTensor, torch.LongTensor,
|
||||
torch.HalfTensor, torch.FloatTensor,
|
||||
torch.DoubleTensor]
|
||||
if type == "char":
|
||||
tensor = torch.CharTensor(*dims)
|
||||
elif type == "short":
|
||||
|
|
@ -1068,7 +1074,7 @@ def _graph_constant(g, value, dims, type, *args, **kwargs):
|
|||
else:
|
||||
raise ValueError("Unknown type, type should be one of the following strings: "
|
||||
"char, short, int, long, half, float, double")
|
||||
tensor.fill_(value)
|
||||
tensor.fill_(value) # type: ignore
|
||||
if isscalar:
|
||||
return g.op("Constant", *args, value_z=tensor, **kwargs)
|
||||
return g.op("Constant", *args, value_t=tensor, **kwargs)
|
||||
|
|
@ -1141,8 +1147,8 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
|
|||
dynamic_axes[key] = value_dict
|
||||
|
||||
|
||||
torch._C.Graph.op = _graph_op
|
||||
torch._C.Graph.at = _graph_at
|
||||
torch._C.Block.op = _block_op
|
||||
torch._C.Graph.constant = _graph_constant
|
||||
torch._C.Node.__getitem__ = _node_getitem
|
||||
torch._C.Graph.op = _graph_op # type: ignore
|
||||
torch._C.Graph.at = _graph_at # type: ignore
|
||||
torch._C.Block.op = _block_op # type: ignore
|
||||
torch._C.Graph.constant = _graph_constant # type: ignore
|
||||
torch._C.Node.__getitem__ = _node_getitem # type: ignore
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user