mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52881
**This PR adds:**
1. Logic to parse complex constants (complex literals of the form `bj`).
2. Logic to parse complex lists.
3. Support for complex constructors: `complex(tensor/int/float/bool, tensor/int/float/bool)`.
4. Limited operator support: `add`, `sub`, `mul`, `torch.tensor`, `torch.as_tensor`.
**Follow-up work:**
1. Add complex support for unary and other registered ops.
2. Support the complex constructor with a string as input (this is supported in Python eager mode).
3. Test all emitXYZ for all XYZ in `ir_emitter.cpp`, e.g., test loops etc. (currently only emitConst and emitValueToTensor are tested).
4. ONNX doesn't support complex tensors, so we should error out with a clear and descriptive error message.
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D27245059
Pulled By: anjali411
fbshipit-source-id: af043b5159ae99a9cc8691b5a8401503fa8d6f05
58 lines
1.7 KiB
Python
58 lines
1.7 KiB
Python
import torch
|
|
import os
|
|
import sys
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
from typing import List, Dict
|
|
|
|
# The shared helper files live in test/, i.e. the parent of the directory
# that contains this file. Resolve symlinks first, then append that directory
# to sys.path so the helpers can be imported.
_this_file = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(os.path.dirname(_this_file))
sys.path.append(pytorch_test_dir)
|
|
|
|
class TestComplex(JitTestCase):
|
|
def test_script(self):
|
|
def fn(a: complex):
|
|
return a
|
|
|
|
self.checkScript(fn, (3 + 5j,))
|
|
|
|
def test_complexlist(self):
|
|
def fn(a: List[complex], idx: int):
|
|
return a[idx]
|
|
|
|
input = [1j, 2, 3 + 4j, -5, -7j]
|
|
self.checkScript(fn, (input, 2))
|
|
|
|
def test_complexdict(self):
|
|
def fn(a: Dict[complex, complex], key: complex) -> complex:
|
|
return a[key]
|
|
|
|
input = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j}
|
|
self.checkScript(fn, (input, -4.3 - 2j))
|
|
|
|
def test_pickle(self):
|
|
class ComplexModule(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.a = 3 + 5j
|
|
self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
|
|
self.c = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j}
|
|
|
|
def forward(self, b: int):
|
|
return b + 2j
|
|
|
|
loaded = self.getExportImportCopy(ComplexModule())
|
|
self.assertEqual(loaded.a, 3 + 5j)
|
|
self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4])
|
|
self.assertEqual(loaded.c, {2 + 3j : 2 - 3j, -4.3 - 2j: 3j})
|
|
|
|
def test_complex_parse(self):
|
|
def fn(a: int, b: torch.Tensor, dim: int):
|
|
# verifies `emitValueToTensor()` 's behavior
|
|
b[dim] = 2.4 + 0.5j
|
|
return (3 * 2j) + a + 5j - 7.4j - 4
|
|
|
|
t1 = torch.tensor(1)
|
|
t2 = torch.tensor([0.4, 1.4j, 2.35])
|
|
|
|
self.checkScript(fn, (t1, t2, 2))
|