mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[ONNX] Add symbolic for torch.addcmul (#72126)
* Add addcmul op * Remove required_grad Pull Request resolved: https://github.com/pytorch/pytorch/pull/73101
This commit is contained in:
parent
28bf2f80cf
commit
cc2aad2ef2
|
|
@ -6308,6 +6308,16 @@ class TestONNXRuntime(unittest.TestCase):
|
||||||
x = torch.randn(4, 2, 3, requires_grad=True)
|
x = torch.randn(4, 2, 3, requires_grad=True)
|
||||||
self.run_test(InplaceAddModel(), x)
|
self.run_test(InplaceAddModel(), x)
|
||||||
|
|
||||||
|
def test_addcmul(self):
|
||||||
|
class AddcmulModel(torch.nn.Module):
|
||||||
|
def forward(self, x, t1, t2):
|
||||||
|
return torch.addcmul(x, t1, t2), torch.addcmul(x, t1, t2, value=2.2)
|
||||||
|
|
||||||
|
x = torch.randn(1, 3)
|
||||||
|
t1 = torch.randn(3, 1)
|
||||||
|
t2 = torch.randn(1, 3)
|
||||||
|
self.run_test(AddcmulModel(), (x, t1, t2))
|
||||||
|
|
||||||
def test_rsqrt(self):
|
def test_rsqrt(self):
|
||||||
class RsqrtModel(torch.nn.Module):
|
class RsqrtModel(torch.nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
||||||
|
|
@ -115,6 +115,12 @@ def div(g, self, other, *args):
|
||||||
return _div_rounding_mode(g, self, other, *args)
|
return _div_rounding_mode(g, self, other, *args)
|
||||||
|
|
||||||
|
|
||||||
|
@parse_args("v", "v", "v", "f")
|
||||||
|
def addcmul(g, self, tensor1, tensor2, value=1.0):
|
||||||
|
value_tens = g.op("Constant", value_t=torch.tensor([value]))
|
||||||
|
return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens))
|
||||||
|
|
||||||
|
|
||||||
@parse_args("v", "v", "s")
|
@parse_args("v", "v", "s")
|
||||||
def _div_rounding_mode(g, self, other, rounding_mode):
|
def _div_rounding_mode(g, self, other, rounding_mode):
|
||||||
if rounding_mode is None:
|
if rounding_mode is None:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user