[complex32] support printing the tensor

Reference: #74537
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76614
Approved by: https://github.com/anjali411
This commit is contained in:
kshitij12345 2022-05-01 12:46:09 +00:00 committed by PyTorch MergeBot
parent fb24614011
commit e36d25fbae
2 changed files with 8 additions and 0 deletions

View File

@ -6581,6 +6581,11 @@ class TestTorch(TestCase):
self.assertEqual(x.__repr__(), str(x))
self.assertExpectedInline(str(x), '''tensor([2.3000+4.j, 7.0000+6.j])''')
# test complex half tensor
x = torch.tensor([1.25 + 4j, -7. + 6j], dtype=torch.chalf)
self.assertEqual(x.__repr__(), str(x))
self.assertExpectedInline(str(x), '''tensor([ 1.2500+4.j, -7.0000+6.j], dtype=torch.complex32)''')
# test scientific notation for complex tensors
x = torch.tensor([1e28 + 2j , -1e-28j])
self.assertEqual(x.__repr__(), str(x))

View File

@ -254,6 +254,9 @@ def _tensor_str(self, indent):
if self.dtype is torch.float16 or self.dtype is torch.bfloat16:
self = self.float()
if self.dtype is torch.complex32:
self = self.cfloat()
if self.dtype.is_complex:
# handle the conjugate bit
self = self.resolve_conj()