pytorch/test/dynamo/cpython/3_13/test_complex.diff
William Wen 8678d831c4 [dynamo] rename set_fullgraph to error_on_graph_break (#161739)
Rename `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph` really does return exactly one graph.

I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary, since it is still a private API (there are no internal call sites yet, and no significant OSS call sites either).

cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users of `set_fullgraph`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
2025-09-04 01:15:06 +00:00

329 lines
10 KiB
Diff

diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py
index 6ff1a8ab29d..1572433c5ae 100644
--- a/test/dynamo/cpython/3_13/test_complex.py
+++ b/test/dynamo/cpython/3_13/test_complex.py
@@ -1,16 +1,147 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_complex.py
+
+import sys
+import torch
+import torch._dynamo.test_case
import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import (
+ run_tests,
+ slowTest,
+ xfailIfTorchDynamo,
+)
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
import sys
-from test import support
-from test.support.testcase import ComplexesAreIdenticalMixin
-from test.support.numbers import (
- VALID_UNDERSCORE_LITERALS,
- INVALID_UNDERSCORE_LITERALS,
+import importlib.abc
+
+redirect_imports = (
+ "test.mapping_tests",
+ "test.typinganndata",
+ "test.test_grammar",
+ "test.test_math",
+ "test.test_iter",
+ "test.typinganndata.ann_module",
)
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
+ if fullname in redirect_imports:
+ try:
+ # Attempt to import the standalone module
+ name = fullname.removeprefix("test.")
+ r = importlib.import_module(name)
+ # Redirect the module in sys.modules
+ sys.modules[fullname] = r
+ # Return a module spec from the found module
+ return importlib.util.find_spec(name)
+ except ImportError:
+ return None
+ return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
+import unittest
+import sys
+from test import support
+from test.support.testcase import ComplexesAreIdenticalMixin
from random import random
from math import isnan, copysign
+import math
import operator
+VALID_UNDERSCORE_LITERALS = [
+ '0_0_0',
+ '4_2',
+ '1_0000_0000',
+ '0b1001_0100',
+ '0xffff_ffff',
+ '0o5_7_7',
+ '1_00_00.5',
+ '1_00_00.5e5',
+ '1_00_00e5_1',
+ '1e1_0',
+ '.1_4',
+ '.1_4e1',
+ '0b_0',
+ '0x_f',
+ '0o_5',
+ '1_00_00j',
+ '1_00_00.5j',
+ '1_00_00e5_1j',
+ '.1_4j',
+ '(1_2.5+3_3j)',
+ '(.5_6j)',
+]
+INVALID_UNDERSCORE_LITERALS = [
+ # Trailing underscores:
+ '0_',
+ '42_',
+ '1.4j_',
+ '0x_',
+ '0b1_',
+ '0xf_',
+ '0o5_',
+ '0 if 1_Else 1',
+ # Underscores in the base selector:
+ '0_b0',
+ '0_xf',
+ '0_o5',
+ # Old-style octal, still disallowed:
+ '0_7',
+ '09_99',
+ # Multiple consecutive underscores:
+ '4_______2',
+ '0.1__4',
+ '0.1__4j',
+ '0b1001__0100',
+ '0xffff__ffff',
+ '0x___',
+ '0o5__77',
+ '1e1__0',
+ '1e1__0j',
+ # Underscore right before a dot:
+ '1_.4',
+ '1_.4j',
+ # Underscore right after a dot:
+ '1._4',
+ '1._4j',
+ '._5',
+ '._5j',
+ # Underscore right after a sign:
+ '1.0e+_1',
+ '1.0e+_1j',
+ # Underscore right before j:
+ '1.4_j',
+ '1.4e5_j',
+ # Underscore right before e:
+ '1_e1',
+ '1.4_e1',
+ '1.4_e1j',
+ # Underscore right after e:
+ '1e_1',
+ '1.4e_1',
+ '1.4e_1j',
+ # Complex cases with parens:
+ '(1+1.5_j_)',
+ '(1+1.5_j)',
+]
+
INF = float("inf")
NAN = float("nan")
DBL_MAX = sys.float_info.max
@@ -45,7 +176,40 @@ class WithComplex:
def __complex__(self):
return self.value
-class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
+class ComplexTest(__TestCase):
+
+ def assertFloatIdentical(self, x, y):
+ """Fail unless floats x and y are identical, in the sense that:
+ (1) both x and y are nans, or
+ (2) both x and y are infinities, with the same sign, or
+ (3) both x and y are zeros, with the same sign, or
+ (4) x and y are both finite and nonzero, and x == y
+
+ """
+ msg = 'floats {!r} and {!r} are not identical'
+
+ if math.isnan(x) or math.isnan(y):
+ if math.isnan(x) and math.isnan(y):
+ return
+ elif x == y:
+ if x != 0.0:
+ return
+ # both zero; check that signs match
+ elif math.copysign(1.0, x) == math.copysign(1.0, y):
+ return
+ else:
+ msg += ': zeros have different signs'
+ self.fail(msg.format(x, y))
+
+ def assertComplexesAreIdentical(self, x, y):
+ """Fail unless complex numbers x and y have equal values and signs.
+
+ In particular, if x and y both have real (or imaginary) part
+ zero, but the zeros have different signs, this test will fail.
+
+ """
+ self.assertFloatIdentical(x.real, y.real)
+ self.assertFloatIdentical(x.imag, y.imag)
def assertAlmostEqual(self, a, b):
if isinstance(a, complex):
@@ -74,6 +238,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
# check that relative difference < eps
self.assertTrue(abs((x-y)/y) < eps)
+ def assertFloatsAreIdentical(self, x, y):
+ """assert that floats x and y are identical, in the sense that:
+ (1) both x and y are nans, or
+ (2) both x and y are infinities, with the same sign, or
+ (3) both x and y are zeros, with the same sign, or
+ (4) x and y are both finite and nonzero, and x == y
+
+ """
+ msg = 'floats {!r} and {!r} are not identical'
+
+ if isnan(x) or isnan(y):
+ if isnan(x) and isnan(y):
+ return
+ elif x == y:
+ if x != 0.0:
+ return
+ # both zero; check that signs match
+ elif copysign(1.0, x) == copysign(1.0, y):
+ return
+ else:
+ msg += ': zeros have different signs'
+ self.fail(msg.format(x, y))
+
def assertClose(self, x, y, eps=1e-9):
"""Return true iff complexes x and y "are close"."""
self.assertCloseAbs(x.real, y.real, eps)
@@ -93,6 +280,7 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
q = z.__truediv__(y)
self.assertClose(q, x)
+ @slowTest
def test_truediv(self):
simple_real = [float(i) for i in range(-5, 6)]
simple_complex = [complex(x, y) for x in simple_real for y in simple_real]
@@ -338,7 +526,10 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
def test_boolcontext(self):
for i in range(100):
- self.assertTrue(complex(random() + 1e-6, random() + 1e-6))
+ with torch._dynamo.error_on_graph_break(False):
+ r1 = random()
+ r2 = random()
+ self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6))
self.assertTrue(not complex(0.0, 0.0))
self.assertTrue(1j)
@@ -431,12 +622,13 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
self.assertRaises(TypeError, complex, WithComplex(1), object())
self.assertRaises(TypeError, complex, WithComplex(None), object())
- class EvilExc(Exception):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class EvilExc(Exception):
+ pass
- class evilcomplex:
- def __complex__(self):
- raise EvilExc
+ class evilcomplex:
+ def __complex__(self):
+ raise EvilExc
self.assertRaises(EvilExc, complex, evilcomplex())
@@ -460,31 +652,33 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
self.assertRaises(TypeError, complex, WithIndex(None), 1.5)
self.assertRaises(TypeError, complex, 1.5, WithIndex(None))
- class MyInt:
- def __int__(self):
- return 42
+ with torch._dynamo.error_on_graph_break(False):
+ class MyInt:
+ def __int__(self):
+ return 42
self.assertRaises(TypeError, complex, MyInt())
self.assertRaises(TypeError, complex, MyInt(), 1.5)
self.assertRaises(TypeError, complex, 1.5, MyInt())
- class complex0(complex):
- """Test usage of __complex__() when inheriting from 'complex'"""
- def __complex__(self):
- return 42j
-
- class complex1(complex):
- """Test usage of __complex__() with a __new__() method"""
- def __new__(self, value=0j):
- return complex.__new__(self, 2*value)
- def __complex__(self):
- return self
-
- class complex2(complex):
- """Make sure that __complex__() calls fail if anything other than a
- complex is returned"""
- def __complex__(self):
- return None
+ with torch._dynamo.error_on_graph_break(False):
+ class complex0(complex):
+ """Test usage of __complex__() when inheriting from 'complex'"""
+ def __complex__(self):
+ return 42j
+
+ class complex1(complex):
+ """Test usage of __complex__() with a __new__() method"""
+ def __new__(self, value=0j):
+ return complex.__new__(self, 2*value)
+ def __complex__(self):
+ return self
+
+ class complex2(complex):
+ """Make sure that __complex__() calls fail if anything other than a
+ complex is returned"""
+ def __complex__(self):
+ return None
check(complex(complex0(1j)), 0.0, 42.0)
with self.assertWarns(DeprecationWarning):
@@ -855,4 +1049,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
if __name__ == "__main__":
- unittest.main()
+ run_tests()