[ONNX] Support gelu for fp16 export (#50487) (#50911)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50911

Need to replace the dtype of scalars created during export from float to double. (In torch's implicit type promotion logic, Python numbers are treated as doubles.)
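
As a quick illustration of that promotion rule (a minimal eager-mode sketch, not part of this commit): a Python number participates in promotion as a "weak" double and adopts the tensor's dtype, while a real float32 tensor wins promotion and would silently upcast an fp16 input.

import torch

half = torch.ones(2, dtype=torch.float16)

# Python scalars promote as "weak" doubles: the fp16 dtype is kept.
print(torch.result_type(half, 1.0))            # torch.float16

# A real (dimensioned) float32 tensor wins promotion and upcasts the result.
print(torch.result_type(half, torch.ones(2)))  # torch.float32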

The test case is skipped in CI because the current CI job environment does not have CUDA support.

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D26050889

Pulled By: SplitInfinity

fbshipit-source-id: 1fdde23a68d4793e6b9a82840acc213e5c3aa760
Author: BowenBao, 2021-01-27 17:41:50 -08:00, committed by Facebook GitHub Bot
Parent: 70dcfe2991
Commit: 68034197e8
3 changed files with 35 additions and 4 deletions

scripts/onnx/test.sh

@@ -58,6 +58,7 @@ pytest "${args[@]}" \
   --ignore "$top_dir/test/onnx/test_utility_funs.py" \
   --ignore "$top_dir/test/onnx/test_pytorch_onnx_caffe2.py" \
   --ignore "$top_dir/test/onnx/test_pytorch_onnx_shape_inference.py" \
+  --ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py" \
   "${test_paths[@]}"
 # onnxruntime only supports py3

test/onnx/test_pytorch_onnx_onnxruntime_cuda.py

@@ -0,0 +1,31 @@
+import unittest
+
+import onnxruntime  # noqa
+import torch
+
+from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
+from test_pytorch_common import skipIfNoCuda
+from test_pytorch_onnx_onnxruntime import TestONNXRuntime
+
+
+class TestONNXRuntime_cuda(unittest.TestCase):
+    from torch.onnx.symbolic_helper import _export_onnx_opset_version
+    opset_version = _export_onnx_opset_version
+    keep_initializers_as_inputs = True
+    use_new_jit_passes = True
+    onnx_shape_inference = True
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    @skipIfNoCuda
+    def test_gelu_fp16(self):
+        class GeluModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.nn.functional.gelu(x)
+
+        x = torch.randn(2, 4, 5, 6, requires_grad=True, dtype=torch.float16, device=torch.device('cuda'))
+        self.run_test(GeluModel(), x, rtol=1e-3, atol=1e-5)
+
+
+TestONNXRuntime_cuda.setUp = TestONNXRuntime.setUp
+TestONNXRuntime_cuda.run_test = TestONNXRuntime.run_test
+
+if __name__ == '__main__':
+    unittest.main()
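
For manual verification outside the test harness, this is roughly what run_test automates for this case (a sketch, assuming a CUDA build of PyTorch and an onnxruntime-gpu install):

import io
import torch

class GeluModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x)

if torch.cuda.is_available():
    x = torch.randn(2, 4, 5, 6, dtype=torch.float16, device='cuda')
    f = io.BytesIO()
    # With the fix, the exported graph stays fp16 throughout instead of
    # upcasting around the float constants in the gelu decomposition.
    torch.onnx.export(GeluModel(), x, f, opset_version=9)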

torch/onnx/symbolic_opset9.py

@@ -2702,10 +2702,9 @@ def remainder(g, input, other):
 def gelu(g, self):
     _sqrt2 = 1.4142135623730951
-    erf = g.op('Erf', g.op('Div', self, torch.tensor(_sqrt2)))
-    erf_plusone = add(g, erf, g.op('Constant', value_t=torch.tensor(1, dtype=torch.float)))
-    return mul(g, mul(g, self, erf_plusone), g.op('Constant', value_t=torch.tensor(0.5, dtype=torch.float)))
+    erf = g.op('Erf', g.op('Div', self, torch.tensor(_sqrt2, dtype=torch.double)))
+    erf_plusone = add(g, erf, g.op('Constant', value_t=torch.tensor(1, dtype=torch.double)))
+    return mul(g, mul(g, self, erf_plusone), g.op('Constant', value_t=torch.tensor(0.5, dtype=torch.double)))

 @parse_args('v', 'i', 'v', 'v', 'f', 'i')
 def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
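
For reference, the symbolic above implements the exact-erf form gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))); a quick eager-mode check of that identity (illustrative only, not part of the commit):

import torch

x = torch.randn(8)
# Same decomposition as the ONNX symbolic, evaluated eagerly.
manual = 0.5 * x * (1.0 + torch.erf(x / 1.4142135623730951))
print(torch.allclose(manual, torch.nn.functional.gelu(x)))  # True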