mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[complex32] support printing the tensor
Reference: #74537 Pull Request resolved: https://github.com/pytorch/pytorch/pull/76614 Approved by: https://github.com/anjali411
This commit is contained in:
parent
fb24614011
commit
e36d25fbae
|
|
@ -6581,6 +6581,11 @@ class TestTorch(TestCase):
|
|||
self.assertEqual(x.__repr__(), str(x))
|
||||
self.assertExpectedInline(str(x), '''tensor([2.3000+4.j, 7.0000+6.j])''')
|
||||
|
||||
# test complex half tensor
|
||||
x = torch.tensor([1.25 + 4j, -7. + 6j], dtype=torch.chalf)
|
||||
self.assertEqual(x.__repr__(), str(x))
|
||||
self.assertExpectedInline(str(x), '''tensor([ 1.2500+4.j, -7.0000+6.j], dtype=torch.complex32)''')
|
||||
|
||||
# test scientific notation for complex tensors
|
||||
x = torch.tensor([1e28 + 2j , -1e-28j])
|
||||
self.assertEqual(x.__repr__(), str(x))
|
||||
|
|
|
|||
|
|
@ -254,6 +254,9 @@ def _tensor_str(self, indent):
|
|||
if self.dtype is torch.float16 or self.dtype is torch.bfloat16:
|
||||
self = self.float()
|
||||
|
||||
if self.dtype is torch.complex32:
|
||||
self = self.cfloat()
|
||||
|
||||
if self.dtype.is_complex:
|
||||
# handle the conjugate bit
|
||||
self = self.resolve_conj()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user