Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
Rename `set_fullgraph` to `error_on_graph_break` for now; there are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph` truly returns exactly one graph. I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I hope that won't be necessary, since it is still a private API (there are no internal call sites yet, and no significant OSS call sites yet). cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users of `set_fullgraph`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
Diff: 191 lines, 6.5 KiB
diff --git a/test/dynamo/cpython/3_13/test_cmath.py b/test/dynamo/cpython/3_13/test_cmath.py
|
|
index a96a5780b31..d00dfca8a17 100644
|
|
--- a/test/dynamo/cpython/3_13/test_cmath.py
|
|
+++ b/test/dynamo/cpython/3_13/test_cmath.py
|
|
@@ -1,5 +1,58 @@
|
|
+# ======= BEGIN Dynamo patch =======
|
|
+# Owner(s): ["module: dynamo"]
|
|
+
|
|
+# ruff: noqa
|
|
+# flake8: noqa
|
|
+
|
|
+# Test copied from
|
|
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_cmath.py
|
|
+
|
|
+import sys
|
|
+import torch
|
|
+import torch._dynamo.test_case
|
|
+import unittest
|
|
+from torch._dynamo.test_case import CPythonTestCase
|
|
+from torch.testing._internal.common_utils import run_tests
|
|
+
|
|
+__TestCase = CPythonTestCase
|
|
+
|
|
+
|
|
+# redirect import statements
|
|
+import sys
|
|
+import importlib.abc
|
|
+
|
|
+redirect_imports = (
|
|
+ "test.mapping_tests",
|
|
+ "test.typinganndata",
|
|
+ "test.test_grammar",
|
|
+ "test.test_math",
|
|
+ "test.test_iter",
|
|
+ "test.typinganndata.ann_module",
|
|
+)
|
|
+
|
|
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
|
|
+ def find_spec(self, fullname, path, target=None):
|
|
+ # Check if the import is the problematic one
|
|
+ if fullname in redirect_imports:
|
|
+ try:
|
|
+ # Attempt to import the standalone module
|
|
+ name = fullname.removeprefix("test.")
|
|
+ r = importlib.import_module(name)
|
|
+ # Redirect the module in sys.modules
|
|
+ sys.modules[fullname] = r
|
|
+ # Return a module spec from the found module
|
|
+ return importlib.util.find_spec(name)
|
|
+ except ImportError:
|
|
+ return None
|
|
+ return None
|
|
+
|
|
+# Add the custom finder to sys.meta_path
|
|
+sys.meta_path.insert(0, RedirectImportFinder())
|
|
+
|
|
+
|
|
+# ======= END DYNAMO PATCH =======
|
|
+
|
|
from test.support import requires_IEEE_754, cpython_only, import_helper
|
|
-from test.support.testcase import ComplexesAreIdenticalMixin
|
|
from test.test_math import parse_testfile, test_file
|
|
import test.test_math as test_math
|
|
import unittest
|
|
@@ -50,7 +103,7 @@ complex_nans = [complex(x, y) for x, y in [
|
|
(INF, NAN)
|
|
]]
|
|
|
|
-class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase):
|
|
+class CMathTests(__TestCase):
|
|
# list of all functions in cmath
|
|
test_functions = [getattr(cmath, fname) for fname in [
|
|
'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh',
|
|
@@ -66,6 +119,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase):
|
|
def tearDown(self):
|
|
self.test_values.close()
|
|
|
|
+ def assertFloatIdentical(self, x, y):
|
|
+ """Fail unless floats x and y are identical, in the sense that:
|
|
+ (1) both x and y are nans, or
|
|
+ (2) both x and y are infinities, with the same sign, or
|
|
+ (3) both x and y are zeros, with the same sign, or
|
|
+ (4) x and y are both finite and nonzero, and x == y
|
|
+
|
|
+ """
|
|
+ msg = 'floats {!r} and {!r} are not identical'
|
|
+
|
|
+ if math.isnan(x) or math.isnan(y):
|
|
+ if math.isnan(x) and math.isnan(y):
|
|
+ return
|
|
+ elif x == y:
|
|
+ if x != 0.0:
|
|
+ return
|
|
+ # both zero; check that signs match
|
|
+ elif math.copysign(1.0, x) == math.copysign(1.0, y):
|
|
+ return
|
|
+ else:
|
|
+ msg += ': zeros have different signs'
|
|
+ self.fail(msg.format(x, y))
|
|
+
|
|
+ def assertComplexesAreIdentical(self, x, y):
|
|
+ """Fail unless complex numbers x and y have equal values and signs.
|
|
+
|
|
+ In particular, if x and y both have real (or imaginary) part
|
|
+ zero, but the zeros have different signs, this test will fail.
|
|
+
|
|
+ """
|
|
+ self.assertFloatIdentical(x.real, y.real)
|
|
+ self.assertFloatIdentical(x.imag, y.imag)
|
|
+
|
|
def rAssertAlmostEqual(self, a, b, rel_err = 2e-15, abs_err = 5e-323,
|
|
msg=None):
|
|
"""Fail if the two floating-point numbers are not almost equal.
|
|
@@ -165,38 +251,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase):
|
|
# end up being passed to the cmath functions
|
|
|
|
# usual case: new-style class implementing __complex__
|
|
- class MyComplex:
|
|
- def __init__(self, value):
|
|
- self.value = value
|
|
- def __complex__(self):
|
|
- return self.value
|
|
-
|
|
- # classes for which __complex__ raises an exception
|
|
- class SomeException(Exception):
|
|
- pass
|
|
- class MyComplexException:
|
|
- def __complex__(self):
|
|
- raise SomeException
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MyComplex:
|
|
+ def __init__(self, value):
|
|
+ self.value = value
|
|
+ def __complex__(self):
|
|
+ return self.value
|
|
+
|
|
+ # classes for which __complex__ raises an exception
|
|
+ class SomeException(Exception):
|
|
+ pass
|
|
+ class MyComplexException:
|
|
+ def __complex__(self):
|
|
+ raise SomeException
|
|
|
|
- # some classes not providing __float__ or __complex__
|
|
- class NeitherComplexNorFloat(object):
|
|
- pass
|
|
- class Index:
|
|
- def __int__(self): return 2
|
|
- def __index__(self): return 2
|
|
- class MyInt:
|
|
- def __int__(self): return 2
|
|
-
|
|
- # other possible combinations of __float__ and __complex__
|
|
- # that should work
|
|
- class FloatAndComplex:
|
|
- def __float__(self):
|
|
- return flt_arg
|
|
- def __complex__(self):
|
|
- return cx_arg
|
|
- class JustFloat:
|
|
- def __float__(self):
|
|
- return flt_arg
|
|
+ # some classes not providing __float__ or __complex__
|
|
+ class NeitherComplexNorFloat(object):
|
|
+ pass
|
|
+ class Index:
|
|
+ def __int__(self): return 2
|
|
+ def __index__(self): return 2
|
|
+ class MyInt:
|
|
+ def __int__(self): return 2
|
|
+
|
|
+ # other possible combinations of __float__ and __complex__
|
|
+ # that should work
|
|
+ class FloatAndComplex:
|
|
+ def __float__(self):
|
|
+ return flt_arg
|
|
+ def __complex__(self):
|
|
+ return cx_arg
|
|
+ class JustFloat:
|
|
+ def __float__(self):
|
|
+ return flt_arg
|
|
|
|
for f in self.test_functions:
|
|
# usual usage
|
|
@@ -590,4 +677,4 @@ class IsCloseTests(test_math.IsCloseTests):
|
|
|
|
|
|
if __name__ == "__main__":
|
|
- unittest.main()
|
|
+ run_tests()
|