from __future__ import absolute_import, division, print_function, unicode_literals import torch from torch._C import ListType import warnings import torch.onnx # This import monkey-patches graph manipulation methods on Graph, used for the # ONNX symbolics import torch.onnx.utils from functools import wraps # Note [Edit Symbolic Files] # EDITING THIS FILE AND SYMBOLIC_OPSET FILES? READ THIS FIRST! # # - These files is ONLY for ATen operators (e.g., operators that show up in the # trace as aten::blah). If you need to special case a primitive operator, # look at _run_symbolic_function # - Parameter ordering does NOT necessarily match what is in VariableType.cpp; # tensors are always first, then non-tensor arguments. # - Parameter names must *exactly* match the names in VariableType.cpp, because # dispatch is done with keyword arguments. # - Looking for inplace ops? They're detected by the trailing underscore, and # transparently dispatched to their non inplace versions in # 'run_symbolic_function'. See Note [Export inplace] # # ---------------------------------------------------------------------------------- # A note on Tensor types # ---------------------------------------------------------------------------------- # # In general, we should avoid depending on the type of Tensor Values contained # within the trace graph. However, this is sometimes unavoidable (due to ONNX # spec requirements, etc). If you are implementing a symbolic and need Tensor # type information, note that there are several levels of Tensor types, defined # in aten/src/ATen/core/jit_type.h: # # TensorType - This is a Tensor, but we don't know anything about its # properties (e.g. scalar type, # dims, shapes). # Appears as `Tensor` in graph print-outs. # ProfiledTensorType <: TensorType - Denotes a Tensor for which we know the # concrete sizes in addition to the information # contained in TensorTyper. This adds a sizes() # method which can be used to retrieve the # concrete sizes. 
# @deprecated
# DimensionedTensorType <: TensorType - Denotes a Tensor for which we know the scalar
#                             type and number of dimensions, but not the concrete
#                             shapes. For example, appears as 'Float(*, *)' in
#                             graph print-outs. Useful accessor methods include
#                             dim() and scalarType()
# @deprecated
# CompleteTensorType <: DimensionedTensorType - Denotes a Tensor for which we know the
#                                               concrete sizes in addition to the information
#                                               contained in DimensionedTensorType. This adds a
#                                               sizes() method which can be used to retrieve the
#                                               concrete sizes.
#
# In general, we should prefer to rely on the least specific information possible.
# For example, not relying on tensor properties at all is better than relying
# on the number of dimensions (DimensionedTensorType) which is better than relying on
# concrete shapes (CompleteTensorType). Doing so will make the export symbolics
# more robust to different graphs.

# ---------------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------------

# Save some builtins as locals, because we'll shadow them below
_sum = sum


def _parse_arg(value, desc):
    """Interpret a symbolic-function argument according to a type descriptor.

    Descriptors:
      'none', 'v' -- return ``value`` unchanged.
      'i', 'f', 'b' -- int / float / bool taken from an ``onnx::Constant``.
      't' -- the raw ``value`` attribute of an ``onnx::Constant``.
      'is' -- list of ints, from an ``onnx::Constant`` or from a
              ``prim::ListConstruct`` whose inputs are all ``onnx::Constant``.

    Arguments that are not ``torch._C.Value`` objects are returned unchanged.

    Raises:
        RuntimeError: if ``desc`` is not supported for the producing node,
            if an 'is' list has a non-constant element, or if the value is
            produced by any other node kind.
    """
    if desc == 'none':
        return value
    if desc == 'v' or not _is_value(value):
        return value

    node = value.node()
    if node.kind() == 'onnx::Constant':
        tval = node['value']
        if desc == 'i':
            return int(tval)
        elif desc == 'f':
            return float(tval)
        elif desc == 'b':
            return bool(tval)
        elif desc == 't':
            return tval
        elif desc == 'is':
            return [int(v) for v in tval]
        else:
            raise RuntimeError("ONNX symbolic doesn't know to interpret Constant node")
    elif node.kind() == 'prim::ListConstruct':
        if desc == 'is':
            # ONNX attributes must be static, so every list element has to be a constant.
            for v in node.inputs():
                if v.node().kind() != 'onnx::Constant':
                    raise RuntimeError("Failed to export an ONNX attribute, "
                                       "since it's not constant, please try to make "
                                       "things (e.g., kernel size) static if possible")
            return [int(v.node()['value']) for v in node.inputs()]
        else:
            raise RuntimeError("ONNX symbolic doesn't know to interpret ListConstruct node")

    raise RuntimeError("Unexpected node type: {}".format(node.kind()))


def _maybe_get_const(value, desc):
    """Parse ``value`` with ``desc`` if an ``onnx::Constant`` produces it; else return it as-is."""
    if _is_value(value) and value.node().kind() == 'onnx::Constant':
        return _parse_arg(value, desc)
    return value


def _maybe_get_scalar(value):
    """Return the constant 0-dim tensor behind ``value`` if there is one, else ``value``."""
    value_t = _maybe_get_const(value, 't')
    if isinstance(value_t, torch.Tensor) and value_t.shape == ():
        return value_t
    return value


def _get_const(value, desc, arg_name):
    """Like _parse_arg, but require graph values to come from an ``onnx::Constant``.

    Args:
        value: the argument to parse.
        desc: descriptor understood by _parse_arg.
        arg_name: argument name, used only in the error message.

    Raises:
        RuntimeError: if ``value`` is a graph value not produced by ``onnx::Constant``.
    """
    if _is_value(value) and value.node().kind() != 'onnx::Constant':
        raise RuntimeError("ONNX symbolic expected a constant value of the {} argument, got `{}`".format(arg_name, value))
    return _parse_arg(value, desc)


def _unpack_list(list_value):
    """Return the inputs of the ``prim::ListConstruct`` node that produces ``list_value``."""
    list_node = list_value.node()
    assert list_node.kind() == "prim::ListConstruct"
    return list(list_node.inputs())


# Check if list_value is output from prim::ListConstruct
# This is usually called before _unpack_list to ensure the list can be unpacked.
def _is_packed_list(list_value):
    return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct"


def parse_args(*arg_descriptors):
    """Decorator that converts a symbolic's positional arguments with _parse_arg.

    The i-th argument after ``g`` is parsed with ``arg_descriptors[i]``. The
    descriptors are stored on the decorated function as ``_arg_descriptors``
    (and copied onto the wrapper by ``functools.wraps`` when that succeeds).
    """
    def decorator(fn):
        fn._arg_descriptors = arg_descriptors

        def wrapper(g, *args):
            # some args may be optional, so the length may be smaller
            assert len(arg_descriptors) >= len(args)
            args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
            return fn(g, *args)

        # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
        try:
            wrapper = wraps(fn)(wrapper)
        except Exception:
            pass
        return wrapper
    return decorator


def _scalar(x):
    """Convert a scalar tensor into a Python value."""
    assert x.numel() == 1
    return x.item()


def _is_complete_or_dimensioned_tensor_type(tensor):
    """Return True if ``tensor``'s type kind is DimensionedTensorType or CompleteTensorType."""
    return tensor.type().kind() in ("DimensionedTensorType", "CompleteTensorType")


def _if_scalar_type_as(g, self, tensor):
    """
    Convert self into the same type of tensor, as necessary.

    We only support implicit casting for scalars, so we never actually need to
    insert an ONNX cast operator here; just fix up the scalar.
    """
    if isinstance(self, torch._C.Value):
        return self

    scalar_type = tensor.type().scalarType()
    if scalar_type:
        # The scalar type name, lower-cased, is used as the conversion method
        # name (e.g. 'Float' -> self.float()).
        ty = scalar_type.lower()
        return getattr(self, ty)()

    return self


def _is_value(x):
    """Return True if ``x`` is a JIT graph value (``torch._C.Value``)."""
    return isinstance(x, torch._C.Value)


def _is_tensor_list(x):
    """Return True if ``x``'s type is a subtype of a list of tensors."""
    return x.type().isSubtypeOf(ListType.ofTensors())


def _unimplemented(op, msg):
    """Emit a warning that exporting ``op`` failed because ``msg`` is not supported."""
    warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")


def _black_list_in_opset(name):
    """Return a symbolic that always raises, naming ``name`` and the current opset version."""
    def symbolic_fn(*args, **kwargs):
        raise RuntimeError("ONNX export failed on {}, which is not implemented for opset {}. "
                           "Try exporting with other opset versions."
                           .format(name, _export_onnx_opset_version))
    return symbolic_fn


def _try_get_scalar_type(*args):
    """Return the first scalar type reported by ``args``, or None.

    An argument is skipped when its type's scalarType() raises RuntimeError or
    returns None, so a later argument can still supply the information.
    """
    for arg in args:
        try:
            scalar_type = arg.type().scalarType()
        except RuntimeError:
            continue
        if scalar_type is not None:
            return scalar_type
    return None


def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
    """Dispatch to the ``_slice`` symbolic for the current export opset.

    For opset <= 9, ``steps`` and ``dynamic_slice`` are not forwarded.
    """
    if _export_onnx_opset_version <= 9:
        from torch.onnx.symbolic_opset9 import _slice
        return _slice(g, input, axes, starts, ends)
    else:
        from torch.onnx.symbolic_opset10 import _slice
        return _slice(g, input, axes, starts, ends, steps, dynamic_slice)


# ---------------------------------------------------------------------
# ONNX operator version
# ---------------------------------------------------------------------

# READ ME BEFORE EDITING _default_onnx_opset_version:
#
# The variable below controls which ONNX operator set version we are
# targeting. THIS VARIABLE HAS SEMANTIC EFFECT! Say a breaking
# change occurred in version 8. As long as this variable < 8, you can
# export models targeting the old behavior. However, if you bump
# this variable to 8 or later, the breaking change will take into effect:
# you MUST adjust any symbolic affected by breaking changes.
# The ONNX spec publishes a *comprehensive* list of BC-breaking changes for
# every operator revision at:
#
#   https://github.com/onnx/onnx/blob/master/docs/Changelog.md
#
# Please be sure to go through and check all of our implementations here before
# increasing this number. This includes symbolic definitions NOT in this
# file, so grep for "OpName" (with quotes)
#
# Besides, opset_version can be specified in the invocation of export()
# and export_to_pretty_string(), and _export_onnx_opset_version will be set
# and the symbolic functions should check it to determine the behavior
# of the exporter.


_default_onnx_opset_version = 9
_onnx_master_opset = 10
_onnx_stable_opsets = [7, 8, 9, 10]
_export_onnx_opset_version = _default_onnx_opset_version


def _set_opset_version(opset_version):
    """Select the ONNX opset version targeted by the symbolic functions.

    Accepted values are the default version, any version in
    ``_onnx_stable_opsets``, and ``_onnx_master_opset``. The choice is stored
    in the module-level ``_export_onnx_opset_version``.

    Raises:
        ValueError: if ``opset_version`` is not an accepted version.
    """
    global _export_onnx_opset_version
    supported = (opset_version == _default_onnx_opset_version
                 or opset_version in _onnx_stable_opsets
                 or opset_version == _onnx_master_opset)
    if not supported:
        raise ValueError("Unsupported ONNX opset version: " + str(opset_version))
    _export_onnx_opset_version = opset_version


_operator_export_type = None


def _set_operator_export_type(operator_export_type):
    """Store ``operator_export_type`` in the module-level ``_operator_export_type``."""
    global _operator_export_type
    _operator_export_type = operator_export_type


# Metaprogram symbolics for each ATen native specialized cast operator.
# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
# ONNX cast node with `to` attribute 'UINT8'
#
# TODO: remove these once we support Type's in the JIT IR and we can once again
# use the unified toType operator
cast_pytorch_to_onnx = {
    'Byte': torch.onnx.TensorProtoDataType.UINT8,
    'Char': torch.onnx.TensorProtoDataType.INT8,
    'Double': torch.onnx.TensorProtoDataType.DOUBLE,
    'Float': torch.onnx.TensorProtoDataType.FLOAT,
    'Half': torch.onnx.TensorProtoDataType.FLOAT16,
    'Int': torch.onnx.TensorProtoDataType.INT32,
    'Long': torch.onnx.TensorProtoDataType.INT64,
    'Short': torch.onnx.TensorProtoDataType.INT16,
    'Bool': torch.onnx.TensorProtoDataType.BOOL,
    'ComplexFloat': torch.onnx.TensorProtoDataType.COMPLEX64,
    'ComplexDouble': torch.onnx.TensorProtoDataType.COMPLEX128,
    'Undefined': torch.onnx.TensorProtoDataType.UNDEFINED,
}

# Every value here is a key of cast_pytorch_to_onnx. The complex entries used
# to map to '', which is not a key there.
scalar_name_to_pytorch = {
    'uint8_t': 'Byte',
    'int8_t': 'Char',
    'double': 'Double',
    'float': 'Float',
    'half': 'Half',
    'int': 'Int',
    'int64_t': 'Long',
    'int16_t': 'Short',
    'bool': 'Bool',
    'complex64': 'ComplexFloat',
    'complex128': 'ComplexDouble',
}


# This indicates each scalar type's corresponding
# torch type. Related source:
# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
# The list index is the scalar type number shown in each trailing comment, and
# it must stay index-aligned with scalar_type_to_onnx below.
scalar_type_to_pytorch_type = [
    torch.uint8,        # 0
    torch.int8,         # 1
    torch.short,        # 2
    torch.int,          # 3
    torch.int64,        # 4
    torch.half,         # 5
    torch.float,        # 6
    torch.double,       # 7
    # Slot 8 was previously missing. That shifted complex64/complex128/bool one
    # position below the indices in their comments and out of alignment with
    # scalar_type_to_onnx, which holds Undefined here. torch.complex32 may not
    # exist in every build, hence the getattr fallback to None.
    getattr(torch, 'complex32', None),  # 8
    torch.complex64,    # 9
    torch.complex128,   # 10
    torch.bool,         # 11
]


def _cast_func_template(to_i, g, input, non_blocking):
    """Emit an ONNX ``Cast`` of ``input`` to ONNX type ``to_i``; ``non_blocking`` is unused."""
    return g.op("Cast", input, to_i=to_i)


scalar_type_to_onnx = [
    cast_pytorch_to_onnx["Byte"],
    cast_pytorch_to_onnx["Char"],
    cast_pytorch_to_onnx["Short"],
    cast_pytorch_to_onnx["Int"],
    cast_pytorch_to_onnx["Long"],
    cast_pytorch_to_onnx["Half"],
    cast_pytorch_to_onnx["Float"],
    cast_pytorch_to_onnx["Double"],
    cast_pytorch_to_onnx["Undefined"],
    cast_pytorch_to_onnx["ComplexFloat"],
    cast_pytorch_to_onnx["ComplexDouble"],
    cast_pytorch_to_onnx["Bool"],
]