mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Remove black_listed_operators for opset 12 as we now support these ops. Pull Request resolved: https://github.com/pytorch/pytorch/pull/39414 Reviewed By: hl475 Differential Revision: D21915584 Pulled By: houseroad fbshipit-source-id: 37ec7bdd2b5a845484535054026d6613d0921b7a
102 lines
3.7 KiB
Python
102 lines
3.7 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
|
import torch
|
|
import torch.onnx.symbolic_helper as sym_help
|
|
from torch.onnx.symbolic_helper import parse_args, _parse_arg
|
|
|
|
|
|
# EDITING THIS FILE? READ THIS FIRST!
|
|
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
|
|
|
# This file exports ONNX ops for opset 12
|
|
|
|
@parse_args('s', 'v')
def einsum(g, equation, tensor_list):
    """Export ``einsum`` as a single ONNX ``Einsum`` node.

    Args:
        g: the graph being built.
        equation: the einsum subscript string (parsed to a Python ``str`` by
            ``parse_args('s', ...)``), e.g. ``"ij,jk->ik"``.
        tensor_list: a graph value holding the list of operand tensors; it is
            unpacked so that every operand becomes a separate node input.

    Returns:
        The output value of the emitted ``Einsum`` node.
    """
    operands = sym_help._unpack_list(tensor_list)
    einsum_out = g.op("Einsum", *operands, equation_s=equation)
    return einsum_out
|
|
|
|
|
|
@parse_args('v', 'f', 'i')
def dropout(g, input, p, train):
    """Export ``dropout`` as an opset-12 ``Dropout`` node.

    Whether a ``Dropout`` node is emitted at all is decided by the exporter's
    global training mode (``sym_help._training_mode``), not by ``train``.
    ``sym_help.assert_training_mode`` is called first with ``train``.
    NOTE(review): presumably it reports a mismatch between ``train`` and the
    export mode; confirm in symbolic_helper.

    Args:
        g: the graph being built.
        input: the tensor dropout is applied to.
        p: dropout probability, parsed to a Python float.
        train: the op's own training flag, parsed to an int.

    Returns:
        ``input`` itself when exporting in inference mode. Otherwise, the
        first output of a ``Dropout`` node built with ``ratio=p`` and
        ``training_mode=True``.
    """
    sym_help.assert_training_mode(train, "dropout")

    # Inference-mode export: dropout is the identity, so no node is emitted.
    if not sym_help._training_mode:
        return input

    # Opset 12 takes the ratio and the training flag as graph inputs
    # (constants here) rather than as attributes.
    ratio = g.op("Constant", value_t=torch.tensor(p))
    training_flag = g.op("Constant", value_t=torch.tensor(True))
    dropped, _mask = g.op("Dropout", input, ratio, training_flag, outputs=2)
    return dropped
|
|
|
|
|
|
def nll_loss(g, self, target, weight, reduction, ignore_index):
    """Export ``nll_loss`` as an ONNX ``NegativeLogLikelihoodLoss`` node.

    Args:
        g: the graph being built.
        self: the loss input (presumably log-probabilities, as for
            ``torch.nn.NLLLoss``).
        target: the target class indices.
        weight: optional per-class weights; a value whose node
            ``mustBeNone()`` means "no weights".
        reduction: reduction code (0 = 'none', 1 = 'mean', 2 = 'sum'),
            constant-folded to an int via ``_maybe_get_const``.
        ignore_index: target value to ignore, constant-folded to an int.

    Returns:
        The output value of the emitted ``NegativeLogLikelihoodLoss`` node.
        The reduction is performed by the node itself.
    """
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = sym_help._maybe_get_const(reduction, 'i')
    reduction_vals = ['none', 'mean', 'sum']
    reduction = reduction_vals[reduction]

    # When ignore_index is not specified, it arrives as
    # onnx::Constant[value={-100}] (PyTorch's default).
    # ONNX NegativeLogLikelihoodLoss defines ignore_index as optional with NO
    # default: if the attribute is omitted, no target value is ignored. The
    # attribute must therefore always be emitted, including for -100.
    # Otherwise targets equal to -100 would be treated as real class indices.
    ignore_index = sym_help._maybe_get_const(ignore_index, 'i')

    # The weight input is optional; only pass it when the caller supplied one.
    inputs = [self, target]
    if not weight.node().mustBeNone():
        inputs.append(weight)

    return g.op("NegativeLogLikelihoodLoss", *inputs,
                reduction_s=reduction, ignore_index_i=ignore_index)
|
|
|
|
|
|
def nll_loss2d(g, self, target, weight, reduction, ignore_index):
    """Export ``nll_loss2d`` by delegating to ``nll_loss``.

    All arguments are forwarded unchanged. The same
    ``NegativeLogLikelihoodLoss`` lowering is used for the spatial variant.
    """
    return nll_loss(g, self, target, weight,
                    reduction=reduction, ignore_index=ignore_index)
|
|
|
|
|
|
def celu(g, self, alpha):
    """Export ``celu`` as an ONNX ``Celu`` node.

    A ``Double`` input is cast to ``Float`` before the ``Celu`` node and the
    result is cast back to ``Double``. Any other input type goes to ``Celu``
    directly.

    Args:
        g: the graph being built.
        self: the input tensor.
        alpha: the CELU ``alpha``, constant-folded to a Python float when
            possible.

    Returns:
        The ``Celu`` output, cast back to ``Double`` if the input was
        ``Double``.
    """
    alpha = sym_help._maybe_get_const(alpha, 'f')
    was_double = self.type().scalarType() == 'Double'

    if was_double:
        # NOTE(review): presumably because ONNX Celu-12 only accepts float32;
        # round-trip through Float so the graph keeps the Double interface.
        self = g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx['Float'])

    result = g.op("Celu", self, alpha_f=alpha)

    if was_double:
        result = g.op("Cast", result, to_i=sym_help.cast_pytorch_to_onnx['Double'])
    return result
|
|
|
|
|
|
def argmax(g, input, dim, keepdim):
    """Export ``argmax`` as an ONNX ``ArgMax`` node.

    If ``dim`` is given, ``dim`` and ``keepdim`` are parsed to ints and passed
    as the ``axis``/``keepdims`` attributes. If ``dim`` is None, the input is
    first reshaped to 1-D and reduced over axis 0 without keeping dims.

    ``select_last_index`` is always set, so the last index among tied maxima
    is reported. NOTE(review): confirm this matches eager ``torch.argmax``
    tie-breaking for the PyTorch version being targeted.
    """
    if not sym_help._is_none(dim):
        axis = _parse_arg(dim, 'i')
        keep = _parse_arg(keepdim, 'i')
        return g.op('ArgMax', input, axis_i=axis, keepdims_i=keep, select_last_index_i=True)

    # Function-scoped import (presumably to avoid a circular import between
    # the opset modules — TODO confirm).
    from torch.onnx.symbolic_opset9 import reshape
    flat = reshape(g, input, (-1,))
    return g.op('ArgMax', flat, axis_i=0, keepdims_i=False, select_last_index_i=True)
|
|
|
|
|
|
def argmin(g, input, dim, keepdim):
    """Export ``argmin`` as an ONNX ``ArgMin`` node.

    If ``dim`` is given, ``dim`` and ``keepdim`` are parsed to ints and passed
    as the ``axis``/``keepdims`` attributes. If ``dim`` is None, the input is
    first reshaped to 1-D and reduced over axis 0 without keeping dims.

    ``select_last_index`` is always set, so the last index among tied minima
    is reported. NOTE(review): confirm this matches eager ``torch.argmin``
    tie-breaking for the PyTorch version being targeted.
    """
    if not sym_help._is_none(dim):
        axis = _parse_arg(dim, 'i')
        keep = _parse_arg(keepdim, 'i')
        return g.op('ArgMin', input, axis_i=axis, keepdims_i=keep, select_last_index_i=True)

    # Function-scoped import (presumably to avoid a circular import between
    # the opset modules — TODO confirm).
    from torch.onnx.symbolic_opset9 import reshape
    flat = reshape(g, input, (-1,))
    return g.op('ArgMin', flat, axis_i=0, keepdims_i=False, select_last_index_i=True)
|
|
|
|
|
|
def pow(g, self, exponent):
    """Export ``pow`` as an ONNX ``Pow`` node with base and exponent as inputs."""
    power = g.op("Pow", self, exponent)
    return power
|
|
|
|
def ge(g, input, other):
    """Export ``ge`` (``input >= other``) as an ONNX ``GreaterOrEqual`` node."""
    comparison = g.op('GreaterOrEqual', input, other)
    return comparison
|
|
|
|
def le(g, input, other):
    """Export ``le`` (``input <= other``) as an ONNX ``LessOrEqual`` node."""
    comparison = g.op('LessOrEqual', input, other)
    return comparison
|