mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50911 Need to replace dtype of export created scalars from float to double. (In torch implicit conversion logic, python numbers are double) Test case skipped in CI due to that current CI job env does not have CUDA support. Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D26050889 Pulled By: SplitInfinity fbshipit-source-id: 1fdde23a68d4793e6b9a82840acc213e5c3aa760
This commit is contained in:
parent
70dcfe2991
commit
68034197e8
|
|
@ -58,6 +58,7 @@ pytest "${args[@]}" \
|
|||
--ignore "$top_dir/test/onnx/test_utility_funs.py" \
|
||||
--ignore "$top_dir/test/onnx/test_pytorch_onnx_caffe2.py" \
|
||||
--ignore "$top_dir/test/onnx/test_pytorch_onnx_shape_inference.py" \
|
||||
--ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py" \
|
||||
"${test_paths[@]}"
|
||||
|
||||
# onnxruntime only support py3
|
||||
|
|
|
|||
31
test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
Normal file
31
test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import unittest
|
||||
import onnxruntime # noqa
|
||||
import torch
|
||||
|
||||
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
|
||||
from test_pytorch_common import skipIfNoCuda
|
||||
|
||||
from test_pytorch_onnx_onnxruntime import TestONNXRuntime
|
||||
|
||||
class TestONNXRuntime_cuda(unittest.TestCase):
|
||||
from torch.onnx.symbolic_helper import _export_onnx_opset_version
|
||||
opset_version = _export_onnx_opset_version
|
||||
keep_initializers_as_inputs = True
|
||||
use_new_jit_passes = True
|
||||
onnx_shape_inference = True
|
||||
|
||||
@skipIfUnsupportedMinOpsetVersion(9)
|
||||
@skipIfNoCuda
|
||||
def test_gelu_fp16(self):
|
||||
class GeluModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.gelu(x)
|
||||
|
||||
x = torch.randn(2, 4, 5, 6, requires_grad=True, dtype=torch.float16, device=torch.device('cuda'))
|
||||
self.run_test(GeluModel(), x, rtol=1e-3, atol=1e-5)
|
||||
|
||||
TestONNXRuntime_cuda.setUp = TestONNXRuntime.setUp
|
||||
TestONNXRuntime_cuda.run_test = TestONNXRuntime.run_test
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(TestONNXRuntime_cuda())
|
||||
|
|
@ -2702,10 +2702,9 @@ def remainder(g, input, other):
|
|||
|
||||
def gelu(g, self):
|
||||
_sqrt2 = 1.4142135623730951
|
||||
erf = g.op('Erf', g.op('Div', self, torch.tensor(_sqrt2)))
|
||||
erf_plusone = add(g, erf, g.op('Constant', value_t=torch.tensor(1, dtype=torch.float)))
|
||||
return mul(g, mul(g, self, erf_plusone), g.op('Constant', value_t=torch.tensor(0.5, dtype=torch.float)))
|
||||
|
||||
erf = g.op('Erf', g.op('Div', self, torch.tensor(_sqrt2, dtype=torch.double)))
|
||||
erf_plusone = add(g, erf, g.op('Constant', value_t=torch.tensor(1, dtype=torch.double)))
|
||||
return mul(g, mul(g, self, erf_plusone), g.op('Constant', value_t=torch.tensor(0.5, dtype=torch.double)))
|
||||
|
||||
@parse_args('v', 'i', 'v', 'v', 'f', 'i')
|
||||
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user