pytorch/test/test_cpp_extensions.py
2018-02-01 16:19:03 -08:00

24 lines
614 B
Python

import torch
import torch_test_cpp_extensions as cpp_extension
import common
class TestCppExtension(common.TestCase):
    """Tests for the compiled ``torch_test_cpp_extensions`` C++ extension.

    Each test compares the extension's result against an equivalent
    computation written with plain ``torch`` operations.
    """

    def test_extension_function(self):
        """``sigmoid_add`` should match ``a.sigmoid() + b.sigmoid()``."""
        lhs, rhs = torch.randn(4, 4), torch.randn(4, 4)
        # Run the C++ implementation first, then build the Python reference.
        actual = cpp_extension.sigmoid_add(lhs, rhs)
        reference = lhs.sigmoid() + rhs.sigmoid()
        self.assertEqual(actual, reference)

    def test_extension_module(self):
        """``MatrixMultiplier.forward`` should match ``get().mm(input)``."""
        multiplier = cpp_extension.MatrixMultiplier(4, 8)
        operand = torch.rand(8, 4)
        # NOTE(review): get() presumably returns the module's stored matrix
        # (built from the (4, 8) constructor args) -- confirm in the C++ source.
        expected = multiplier.get().mm(operand)
        self.assertEqual(expected, multiplier.forward(operand))
if __name__ == '__main__':
    # Only when this file is executed directly: hand control to the shared
    # test runner provided by the project's ``common`` module.
    common.run_tests()