# Owner(s): ["module: unknown"] import torch import copy from torch.testing._internal.common_utils import TestCase, run_tests class TestPerOverloadAPI(TestCase): def test_basics_opoverloadpacket(self): # add is ony used as an example here. It is ok to update the test # if the semantics of add are modified in the future. add_packet = torch.ops.aten.add # class attributes self.assertEqual(add_packet.__name__, 'add') self.assertEqual(str(add_packet), 'aten.add') # callable self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5)) # correct module self.assertEqual(add_packet.__module__, add_packet.op.__module__) # caching another_add_packet = torch.ops.aten.add self.assertEqual(id(add_packet), id(another_add_packet)) # deepcopy is a no-op self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet))) # pretty print self.assertEqual(repr(add_packet), "") self.assertRaises(AttributeError, lambda: add_packet.foo) def test_basics_opoverload(self): add_packet = torch.ops.aten.add add_tensoroverload = add_packet.Tensor # class attributes self.assertEqual(str(add_tensoroverload), 'aten.add.Tensor') self.assertEqual(add_tensoroverload.__name__, 'add.Tensor') self.assertEqual(add_tensoroverload.overloadpacket, add_packet) # deepcopy is a no-op self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload))) # caching another_add_tensoroverload = torch.ops.aten.add.Tensor self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload)) # pretty print self.assertEqual(repr(add_tensoroverload), "") # callable self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5)) a = torch.tensor(2) b = torch.tensor(0) torch.ops.aten.add.out(a, a, out=b) self.assertEqual(b, torch.tensor(4)) self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b)) def test_decompose(self): x = torch.randn(2, 3) y = torch.randn(5, 3) self.assertEqual( torch.ops.aten.linear.default.decompose(x, y), torch.ops.aten.linear.default(x, y) ) if __name__ == '__main__': run_tests()