mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Add symbolic for torch.addcmul (#72126)
* Add addcmul op * Remove required_grad Pull Request resolved: https://github.com/pytorch/pytorch/pull/73101
This commit is contained in:
parent
28bf2f80cf
commit
cc2aad2ef2
|
|
@ -6308,6 +6308,16 @@ class TestONNXRuntime(unittest.TestCase):
|
|||
x = torch.randn(4, 2, 3, requires_grad=True)
|
||||
self.run_test(InplaceAddModel(), x)
|
||||
|
||||
def test_addcmul(self):
|
||||
class AddcmulModel(torch.nn.Module):
|
||||
def forward(self, x, t1, t2):
|
||||
return torch.addcmul(x, t1, t2), torch.addcmul(x, t1, t2, value=2.2)
|
||||
|
||||
x = torch.randn(1, 3)
|
||||
t1 = torch.randn(3, 1)
|
||||
t2 = torch.randn(1, 3)
|
||||
self.run_test(AddcmulModel(), (x, t1, t2))
|
||||
|
||||
def test_rsqrt(self):
|
||||
class RsqrtModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
|
|
|||
|
|
@ -115,6 +115,12 @@ def div(g, self, other, *args):
|
|||
return _div_rounding_mode(g, self, other, *args)
|
||||
|
||||
|
||||
@parse_args("v", "v", "v", "f")
|
||||
def addcmul(g, self, tensor1, tensor2, value=1.0):
|
||||
value_tens = g.op("Constant", value_t=torch.tensor([value]))
|
||||
return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens))
|
||||
|
||||
|
||||
@parse_args("v", "v", "s")
|
||||
def _div_rounding_mode(g, self, other, rounding_mode):
|
||||
if rounding_mode is None:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user