mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Allow creating SugaredValue for a complex valued IValue and deserialization logic for "infj" and "nanj" global constants (#54328)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54328 Test Plan: Imported from OSS Reviewed By: nikithamalgifb Differential Revision: D27369134 Pulled By: anjali411 fbshipit-source-id: aec26750a6fc8917ee15306684b743d13a91570c
This commit is contained in:
parent
f4dfa02c03
commit
1bccd48465
|
|
@ -40,6 +40,7 @@ class TestComplex(JitTestCase):
|
|||
self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
|
||||
self.c = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j}
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, b: int):
|
||||
return b + 2j
|
||||
|
||||
|
|
@ -47,6 +48,7 @@ class TestComplex(JitTestCase):
|
|||
self.assertEqual(loaded.a, 3 + 5j)
|
||||
self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4])
|
||||
self.assertEqual(loaded.c, {2 + 3j : 2 - 3j, -4.3 - 2j: 3j})
|
||||
self.assertEqual(loaded(2), 2 + 2j)
|
||||
|
||||
def test_complex_parse(self):
|
||||
def fn(a: int, b: torch.Tensor, dim: int):
|
||||
|
|
@ -59,18 +61,18 @@ class TestComplex(JitTestCase):
|
|||
|
||||
self.checkScript(fn, (t1, t2, 2))
|
||||
|
||||
def test_complex_math_ops(self):
|
||||
def test_complex_constants_and_ops(self):
|
||||
vals = ([0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
|
||||
+ [10.0 ** i for i in range(2)] + [-(10.0 ** i) for i in range(2)])
|
||||
complex_vals = tuple((x + y * 1j) for x, y in product(vals, vals))
|
||||
|
||||
def checkMath(func_name):
|
||||
funcs_template = dedent('''
|
||||
funcs_template = dedent('''
|
||||
def func(a: complex):
|
||||
return cmath.{func}(a)
|
||||
return cmath.{func_or_const}(a)
|
||||
''')
|
||||
|
||||
funcs_str = funcs_template.format(func=func_name)
|
||||
def checkCmath(func_name, funcs_template=funcs_template):
|
||||
funcs_str = funcs_template.format(func_or_const=func_name)
|
||||
scope = {}
|
||||
execWrapper(funcs_str, globals(), scope)
|
||||
cu = torch.jit.CompilationUnit(funcs_str)
|
||||
|
|
@ -102,10 +104,37 @@ class TestComplex(JitTestCase):
|
|||
|
||||
# --- Unary ops ---
|
||||
for op in unary_ops:
|
||||
checkMath(op)
|
||||
checkCmath(op)
|
||||
|
||||
def fn(x: complex):
|
||||
return abs(x)
|
||||
|
||||
for val in complex_vals:
|
||||
self.checkScript(fn, (val, ))
|
||||
|
||||
func_constants_template = dedent('''
|
||||
def func():
|
||||
return cmath.{func_or_const}
|
||||
''')
|
||||
float_consts = ['pi', 'e', 'tau', 'inf', 'nan']
|
||||
complex_consts = ['infj', 'nanj']
|
||||
for x in (float_consts + complex_consts):
|
||||
checkCmath(x, funcs_template=func_constants_template)
|
||||
|
||||
|
||||
def test_infj_nanj_pickle(self):
|
||||
class ComplexModule(torch.jit.ScriptModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.a = 3 + 5j
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, infj: int, nanj: int):
|
||||
if infj == 2:
|
||||
return infj + cmath.infj
|
||||
else:
|
||||
return nanj + cmath.nanj
|
||||
|
||||
loaded = self.getExportImportCopy(ComplexModule())
|
||||
self.assertEqual(loaded(2, 3), 2 + cmath.infj)
|
||||
self.assertEqual(loaded(3, 4), 4 + cmath.nanj)
|
||||
|
|
|
|||
|
|
@ -1061,6 +1061,10 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
|||
return toSimple(g.insertConstant(py::cast<int64_t>(obj), loc));
|
||||
} else if (py::isinstance<py::float_>(obj)) {
|
||||
return toSimple(g.insertConstant(py::cast<double>(obj), loc));
|
||||
} else if (PyComplex_CheckExact(obj.ptr())) {
|
||||
auto c_obj = py::cast<std::complex<double>>(obj.ptr());
|
||||
return toSimple(
|
||||
g.insertConstant(static_cast<c10::complex<double>>(c_obj), loc));
|
||||
} else if (py::isinstance<py::str>(obj)) {
|
||||
return toSimple(g.insertConstant(py::cast<std::string>(obj), loc));
|
||||
} else if (obj.is(py::none())) {
|
||||
|
|
|
|||
|
|
@ -222,6 +222,16 @@ struct SourceImporterImpl : public Resolver,
|
|||
return std::make_shared<SimpleValue>(
|
||||
graph->insertConstant(std::numeric_limits<double>::quiet_NaN(), loc));
|
||||
}
|
||||
if (name == "infj") {
|
||||
return std::make_shared<SimpleValue>(graph->insertConstant(
|
||||
c10::complex<double>(0, std::numeric_limits<double>::infinity()),
|
||||
loc));
|
||||
}
|
||||
if (name == "nanj") {
|
||||
return std::make_shared<SimpleValue>(graph->insertConstant(
|
||||
c10::complex<double>(0, std::numeric_limits<double>::quiet_NaN()),
|
||||
loc));
|
||||
}
|
||||
if (name == "__torch__") {
|
||||
return std::make_shared<ClassNamespaceValue>(
|
||||
c10::QualifiedName(name), shared_from_this());
|
||||
|
|
|
|||
|
|
@ -45,6 +45,8 @@ const static std::unordered_set<std::string> reserved_names = {
|
|||
"getattr",
|
||||
"inf",
|
||||
"nan",
|
||||
"infj",
|
||||
"nanj",
|
||||
"ops",
|
||||
"__torch__",
|
||||
// the python keywords
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user