pytorch/torch/onnx/operators.py
James Reed 4667983f0f
Fixes for interpreter and ONNX export for translation (#7044)
Fixes for interpreter and ONNX export for translation

Address comments
2018-04-27 22:23:57 -07:00

28 lines
662 B
Python

r"""This file provides operators that help export models via ONNX.
For example, shape_as_tensor and reshape_from_tensor_shape make operations
on dynamic sizes traceable.
"""
import torch
import torch.onnx
import torch.onnx.utils
def _shape_as_tensor(g, input):
    """Symbolic (ONNX export) implementation used for ``shape_as_tensor``.

    Emits a single ONNX ``Shape`` node on the graph context ``g`` with
    ``input`` as its only operand.

    Args:
        g: graph context exposing ``op(op_type, *inputs)``.
        input: graph value whose shape is requested.

    Returns:
        Whatever ``g.op`` returns for the new ``Shape`` node.
    """
    shape_node = g.op('Shape', input)
    return shape_node
@torch.onnx.symbolic_override(_shape_as_tensor)
def shape_as_tensor(x):
    """Return the sizes of ``x`` packed into a ``torch.LongTensor``.

    In eager mode this builds the tensor directly from ``x.shape``. During
    ONNX export the decorator substitutes ``_shape_as_tensor``, which emits
    a ``Shape`` node instead.

    Args:
        x: tensor whose shape is queried.

    Returns:
        ``torch.LongTensor`` holding one entry per dimension of ``x``.
    """
    # NOTE(review): the explicit tuple() presumably ensures the legacy
    # LongTensor constructor treats the sizes as data rather than as a
    # size specification (as it would for a torch.Size) — confirm.
    dims = tuple(x.shape)
    return torch.LongTensor(dims)
def _reshape_from_tensor_shape(g, input, shape):
    """Symbolic (ONNX export) implementation used for ``reshape_from_tensor_shape``.

    Emits a single ONNX ``Reshape`` node on the graph context ``g``. The node
    takes ``input`` as the data operand and ``shape`` as the target-shape
    operand.

    Args:
        g: graph context exposing ``op(op_type, *inputs)``.
        input: graph value to reshape.
        shape: graph value holding the target shape.

    Returns:
        Whatever ``g.op`` returns for the new ``Reshape`` node.
    """
    reshaped = g.op('Reshape', input, shape)
    return reshaped
@torch.onnx.symbolic_override(_reshape_from_tensor_shape)
def reshape_from_tensor_shape(x, shape):
    """Reshape ``x`` to the sizes stored in the tensor ``shape``.

    In eager mode the sizes are read out of ``shape`` with ``tolist()`` and
    passed to ``x.reshape``. During ONNX export the decorator substitutes
    ``_reshape_from_tensor_shape``, which emits a ``Reshape`` node whose
    target shape stays a graph value.

    Args:
        x: tensor to reshape.
        shape: tensor whose values are the target sizes.

    Returns:
        The result of ``x.reshape`` with those sizes.
    """
    target_sizes = shape.tolist()
    return x.reshape(target_sizes)