From cc2aad2ef2f6da367d29c46fb0fe1d767daed9ad Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 22 Feb 2022 10:20:21 -0800 Subject: [PATCH] [ONNX] Add symbolic for torch.addcmul (#72126) * Add addcmul op * Remove required_grad Pull Request resolved: https://github.com/pytorch/pytorch/pull/73101 --- test/onnx/test_pytorch_onnx_onnxruntime.py | 10 ++++++++++ torch/onnx/symbolic_opset9.py | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 64b707c95d5..461123ce23d 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -6308,6 +6308,16 @@ class TestONNXRuntime(unittest.TestCase): x = torch.randn(4, 2, 3, requires_grad=True) self.run_test(InplaceAddModel(), x) + def test_addcmul(self): + class AddcmulModel(torch.nn.Module): + def forward(self, x, t1, t2): + return torch.addcmul(x, t1, t2), torch.addcmul(x, t1, t2, value=2.2) + + x = torch.randn(1, 3) + t1 = torch.randn(3, 1) + t2 = torch.randn(1, 3) + self.run_test(AddcmulModel(), (x, t1, t2)) + def test_rsqrt(self): class RsqrtModel(torch.nn.Module): def forward(self, x): diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 8c710269699..6c05a8c56a4 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -115,6 +115,12 @@ def div(g, self, other, *args): return _div_rounding_mode(g, self, other, *args) +@parse_args("v", "v", "v", "f") +def addcmul(g, self, tensor1, tensor2, value=1.0): + value_tens = g.op("Constant", value_t=torch.tensor([value])) + return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens)) + + @parse_args("v", "v", "s") def _div_rounding_mode(g, self, other, rounding_mode): if rounding_mode is None: