# Mirror of https://github.com/zebrajr/pytorch.git
# (synced 2025-12-07 12:21:27 +01:00; 28 lines, 662 B, Python)
r"""This file provides a location for operators that help export models
via ONNX. For example, ``shape_as_tensor`` and ``reshape_from_tensor_shape``
make operations that depend on dynamic sizes traceable.
"""
|
|
|
|
import torch
|
|
import torch.onnx
|
|
import torch.onnx.utils
|
|
|
|
|
|
def _shape_as_tensor(g, input):
|
|
return g.op('Shape', input)
|
|
|
|
|
|
@torch.onnx.symbolic_override(_shape_as_tensor)
def shape_as_tensor(x):
    """Return the dimensions of ``x`` as a 1-D ``torch.LongTensor``.

    The decorator registers ``_shape_as_tensor`` as this function's ONNX
    symbolic. The intent (per the module docstring) is that the shape stays
    a dynamic graph value under export instead of being baked in.

    Args:
        x: tensor whose shape is read.

    Returns:
        A ``torch.LongTensor`` holding one entry per dimension of ``x``.
    """
    # Hand the constructor a plain tuple, not the torch.Size itself.
    # The legacy ``torch.LongTensor`` constructor treats a torch.Size
    # argument as the size of a new tensor. A tuple of ints is read as data.
    dims = tuple(x.shape)
    return torch.LongTensor(dims)
|
|
|
|
|
|
def _reshape_from_tensor_shape(g, input, shape):
|
|
return g.op('Reshape', input, shape)
|
|
|
|
|
|
@torch.onnx.symbolic_override(_reshape_from_tensor_shape)
def reshape_from_tensor_shape(x, shape):
    """Reshape ``x`` to the dimensions stored in the tensor ``shape``.

    The decorator registers ``_reshape_from_tensor_shape`` as this
    function's ONNX symbolic, so the target shape can remain a graph value
    under export.

    Args:
        x: tensor to reshape.
        shape: tensor of target dimensions. Presumably this is usually the
            output of ``shape_as_tensor``; TODO confirm against callers.

    Returns:
        ``x.reshape`` applied to the Python list obtained from ``shape``.
    """
    target_dims = shape.tolist()
    return x.reshape(target_dims)
|