diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py index 6ff1a8ab29d..1572433c5ae 100644 --- a/test/dynamo/cpython/3_13/test_complex.py +++ b/test/dynamo/cpython/3_13/test_complex.py @@ -1,16 +1,147 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_complex.py + +import sys +import torch +import torch._dynamo.test_case import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import ( + run_tests, + slowTest, + xfailIfTorchDynamo, +) + +__TestCase = CPythonTestCase + + +# redirect import statements import sys -from test import support -from test.support.testcase import ComplexesAreIdenticalMixin -from test.support.numbers import ( - VALID_UNDERSCORE_LITERALS, - INVALID_UNDERSCORE_LITERALS, +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", ) +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + +import unittest +import sys +from test import support +from test.support.testcase import ComplexesAreIdenticalMixin from random import random from math import isnan, copysign +import math import operator +VALID_UNDERSCORE_LITERALS = [ + '0_0_0', + '4_2', + '1_0000_0000', + '0b1001_0100', + '0xffff_ffff', + '0o5_7_7', + '1_00_00.5', + '1_00_00.5e5', + '1_00_00e5_1', + '1e1_0', + '.1_4', + '.1_4e1', + '0b_0', + '0x_f', + '0o_5', + '1_00_00j', + '1_00_00.5j', + '1_00_00e5_1j', + '.1_4j', + '(1_2.5+3_3j)', + '(.5_6j)', +] +INVALID_UNDERSCORE_LITERALS = [ + # Trailing underscores: + '0_', + '42_', + '1.4j_', + '0x_', + '0b1_', + '0xf_', + '0o5_', + '0 if 1_Else 1', + # Underscores in the base selector: + '0_b0', + '0_xf', + '0_o5', + # Old-style octal, still disallowed: + '0_7', + '09_99', + # Multiple consecutive underscores: + '4_______2', + '0.1__4', + '0.1__4j', + '0b1001__0100', + '0xffff__ffff', + '0x___', + '0o5__77', + '1e1__0', + '1e1__0j', + # Underscore right before a dot: + '1_.4', + '1_.4j', + # Underscore right after a dot: + '1._4', + '1._4j', + '._5', + '._5j', + # Underscore right after a sign: + '1.0e+_1', + '1.0e+_1j', + # Underscore right before j: + '1.4_j', + '1.4e5_j', + # Underscore right before e: + '1_e1', + '1.4_e1', + '1.4_e1j', + # Underscore right after e: + '1e_1', + '1.4e_1', + '1.4e_1j', + # Complex cases with parens: + '(1+1.5_j_)', + '(1+1.5_j)', +] + INF = float("inf") NAN = float("nan") DBL_MAX = sys.float_info.max @@ -45,7 +176,40 @@ class WithComplex: def __complex__(self): return self.value -class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): +class ComplexTest(__TestCase): + + def assertFloatIdentical(self, x, y): + """Fail unless floats x and y are identical, in the sense that: + (1) both x and y are nans, or + (2) both x and y are infinities, with the same sign, or + (3) both x and y are zeros, with the same sign, or + (4) x and y are both finite and nonzero, and x == y + + """ + msg = 'floats {!r} and {!r} are not identical' + + if math.isnan(x) or math.isnan(y): + if math.isnan(x) and math.isnan(y): + return + elif x == y: + if x != 0.0: + return + # both zero; check that signs match + elif math.copysign(1.0, x) == math.copysign(1.0, y): + return + else: + msg += ': zeros have different signs' + self.fail(msg.format(x, y)) + + def assertComplexesAreIdentical(self, x, y): + """Fail unless complex numbers x and y have equal values and signs. + + In particular, if x and y both have real (or imaginary) part + zero, but the zeros have different signs, this test will fail. + + """ + self.assertFloatIdentical(x.real, y.real) + self.assertFloatIdentical(x.imag, y.imag) def assertAlmostEqual(self, a, b): if isinstance(a, complex): @@ -74,6 +238,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): # check that relative difference < eps self.assertTrue(abs((x-y)/y) < eps) + def assertFloatsAreIdentical(self, x, y): + """assert that floats x and y are identical, in the sense that: + (1) both x and y are nans, or + (2) both x and y are infinities, with the same sign, or + (3) both x and y are zeros, with the same sign, or + (4) x and y are both finite and nonzero, and x == y + + """ + msg = 'floats {!r} and {!r} are not identical' + + if isnan(x) or isnan(y): + if isnan(x) and isnan(y): + return + elif x == y: + if x != 0.0: + return + # both zero; check that signs match + elif copysign(1.0, x) == copysign(1.0, y): + return + else: + msg += ': zeros have different signs' + self.fail(msg.format(x, y)) + def assertClose(self, x, y, eps=1e-9): """Return true iff complexes x and y "are close".""" self.assertCloseAbs(x.real, y.real, eps) @@ -93,6 +280,7 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): q = z.__truediv__(y) self.assertClose(q, x) + @slowTest def test_truediv(self): simple_real = [float(i) for i in range(-5, 6)] simple_complex = [complex(x, y) for x in simple_real for y in simple_real] @@ -338,7 +526,10 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): def test_boolcontext(self): for i in range(100): - self.assertTrue(complex(random() + 1e-6, random() + 1e-6)) + with torch._dynamo.error_on_graph_break(False): + r1 = random() + r2 = random() + self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6)) self.assertTrue(not complex(0.0, 0.0)) self.assertTrue(1j) @@ -431,12 +622,13 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): self.assertRaises(TypeError, complex, WithComplex(1), object()) self.assertRaises(TypeError, complex, WithComplex(None), object()) - class EvilExc(Exception): - pass + with torch._dynamo.error_on_graph_break(False): + class EvilExc(Exception): + pass - class evilcomplex: - def __complex__(self): - raise EvilExc + class evilcomplex: + def __complex__(self): + raise EvilExc self.assertRaises(EvilExc, complex, evilcomplex()) @@ -460,31 +652,33 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): self.assertRaises(TypeError, complex, WithIndex(None), 1.5) self.assertRaises(TypeError, complex, 1.5, WithIndex(None)) - class MyInt: - def __int__(self): - return 42 + with torch._dynamo.error_on_graph_break(False): + class MyInt: + def __int__(self): + return 42 self.assertRaises(TypeError, complex, MyInt()) self.assertRaises(TypeError, complex, MyInt(), 1.5) self.assertRaises(TypeError, complex, 1.5, MyInt()) - class complex0(complex): - """Test usage of __complex__() when inheriting from 'complex'""" - def __complex__(self): - return 42j - - class complex1(complex): - """Test usage of __complex__() with a __new__() method""" - def __new__(self, value=0j): - return complex.__new__(self, 2*value) - def __complex__(self): - return self - - class complex2(complex): - """Make sure that __complex__() calls fail if anything other than a - complex is returned""" - def __complex__(self): - return None + with torch._dynamo.error_on_graph_break(False): + class complex0(complex): + """Test usage of __complex__() when inheriting from 'complex'""" + def __complex__(self): + return 42j + + class complex1(complex): + """Test usage of __complex__() with a __new__() method""" + def __new__(self, value=0j): + return complex.__new__(self, 2*value) + def __complex__(self): + return self + + class complex2(complex): + """Make sure that __complex__() calls fail if anything other than a + complex is returned""" + def __complex__(self): + return None check(complex(complex0(1j)), 0.0, 42.0) with self.assertWarns(DeprecationWarning): @@ -855,4 +1049,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): if __name__ == "__main__": - unittest.main() + run_tests()