mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Renaming `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a followup PR, we will introduce a new `torch.compile` option `error_on_graph_break` that has lower priority than `fullgraph` so that `fullgraph` really returns 1 graph. I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still private API (there are no internal callsites yet, and there are no significant OSS callsites yet). cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users for `set_fullgraph` Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739 Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
485 lines
16 KiB
Diff
485 lines
16 KiB
Diff
diff --git a/test/dynamo/cpython/3_13/test_float.py b/test/dynamo/cpython/3_13/test_float.py
|
|
index 97f951f1299..da82bd190c3 100644
|
|
--- a/test/dynamo/cpython/3_13/test_float.py
|
|
+++ b/test/dynamo/cpython/3_13/test_float.py
|
|
@@ -1,3 +1,57 @@
|
|
+# ======= BEGIN Dynamo patch =======
|
|
+# Owner(s): ["module: dynamo"]
|
|
+
|
|
+# ruff: noqa
|
|
+# flake8: noqa
|
|
+
|
|
+# Test copied from
|
|
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_float.py
|
|
+
|
|
+import sys
|
|
+import torch
|
|
+import torch._dynamo.test_case
|
|
+import unittest
|
|
+from torch._dynamo.test_case import CPythonTestCase
|
|
+from torch.testing._internal.common_utils import run_tests
|
|
+
|
|
+__TestCase = CPythonTestCase
|
|
+
|
|
+
|
|
+# redirect import statements
|
|
+import sys
|
|
+import importlib.abc
|
|
+
|
|
+redirect_imports = (
|
|
+ "test.mapping_tests",
|
|
+ "test.typinganndata",
|
|
+ "test.test_grammar",
|
|
+ "test.test_math",
|
|
+ "test.test_iter",
|
|
+ "test.typinganndata.ann_module",
|
|
+)
|
|
+
|
|
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
|
|
+ def find_spec(self, fullname, path, target=None):
|
|
+ # Check if the import is the problematic one
|
|
+ if fullname in redirect_imports:
|
|
+ try:
|
|
+ # Attempt to import the standalone module
|
|
+ name = fullname.removeprefix("test.")
|
|
+ r = importlib.import_module(name)
|
|
+ # Redirect the module in sys.modules
|
|
+ sys.modules[fullname] = r
|
|
+ # Return a module spec from the found module
|
|
+ return importlib.util.find_spec(name)
|
|
+ except ImportError:
|
|
+ return None
|
|
+ return None
|
|
+
|
|
+# Add the custom finder to sys.meta_path
|
|
+sys.meta_path.insert(0, RedirectImportFinder())
|
|
+
|
|
+
|
|
+# ======= END DYNAMO PATCH =======
|
|
+
|
|
import fractions
|
|
import operator
|
|
import os
|
|
@@ -8,11 +62,84 @@ import time
|
|
import unittest
|
|
|
|
from test import support
|
|
-from test.support.testcase import FloatsAreIdenticalMixin
|
|
-from test.support.numbers import (
|
|
- VALID_UNDERSCORE_LITERALS,
|
|
- INVALID_UNDERSCORE_LITERALS,
|
|
-)
|
|
+
|
|
+VALID_UNDERSCORE_LITERALS = [
|
|
+ '0_0_0',
|
|
+ '4_2',
|
|
+ '1_0000_0000',
|
|
+ '0b1001_0100',
|
|
+ '0xffff_ffff',
|
|
+ '0o5_7_7',
|
|
+ '1_00_00.5',
|
|
+ '1_00_00.5e5',
|
|
+ '1_00_00e5_1',
|
|
+ '1e1_0',
|
|
+ '.1_4',
|
|
+ '.1_4e1',
|
|
+ '0b_0',
|
|
+ '0x_f',
|
|
+ '0o_5',
|
|
+ '1_00_00j',
|
|
+ '1_00_00.5j',
|
|
+ '1_00_00e5_1j',
|
|
+ '.1_4j',
|
|
+ '(1_2.5+3_3j)',
|
|
+ '(.5_6j)',
|
|
+]
|
|
+INVALID_UNDERSCORE_LITERALS = [
|
|
+ # Trailing underscores:
|
|
+ '0_',
|
|
+ '42_',
|
|
+ '1.4j_',
|
|
+ '0x_',
|
|
+ '0b1_',
|
|
+ '0xf_',
|
|
+ '0o5_',
|
|
+ '0 if 1_Else 1',
|
|
+ # Underscores in the base selector:
|
|
+ '0_b0',
|
|
+ '0_xf',
|
|
+ '0_o5',
|
|
+ # Old-style octal, still disallowed:
|
|
+ '0_7',
|
|
+ '09_99',
|
|
+ # Multiple consecutive underscores:
|
|
+ '4_______2',
|
|
+ '0.1__4',
|
|
+ '0.1__4j',
|
|
+ '0b1001__0100',
|
|
+ '0xffff__ffff',
|
|
+ '0x___',
|
|
+ '0o5__77',
|
|
+ '1e1__0',
|
|
+ '1e1__0j',
|
|
+ # Underscore right before a dot:
|
|
+ '1_.4',
|
|
+ '1_.4j',
|
|
+ # Underscore right after a dot:
|
|
+ '1._4',
|
|
+ '1._4j',
|
|
+ '._5',
|
|
+ '._5j',
|
|
+ # Underscore right after a sign:
|
|
+ '1.0e+_1',
|
|
+ '1.0e+_1j',
|
|
+ # Underscore right before j:
|
|
+ '1.4_j',
|
|
+ '1.4e5_j',
|
|
+ # Underscore right before e:
|
|
+ '1_e1',
|
|
+ '1.4_e1',
|
|
+ '1.4_e1j',
|
|
+ # Underscore right after e:
|
|
+ '1e_1',
|
|
+ '1.4e_1',
|
|
+ '1.4e_1j',
|
|
+ # Complex cases with parens:
|
|
+ '(1+1.5_j_)',
|
|
+ '(1+1.5_j)',
|
|
+]
|
|
+
|
|
from math import isinf, isnan, copysign, ldexp
|
|
import math
|
|
|
|
@@ -35,7 +162,7 @@ class FloatSubclass(float):
|
|
class OtherFloatSubclass(float):
|
|
pass
|
|
|
|
-class GeneralFloatCases(unittest.TestCase):
|
|
+class GeneralFloatCases(__TestCase):
|
|
|
|
def test_float(self):
|
|
self.assertEqual(float(3.14), 3.14)
|
|
@@ -95,9 +222,10 @@ class GeneralFloatCases(unittest.TestCase):
|
|
def test_non_numeric_input_types(self):
|
|
# Test possible non-numeric types for the argument x, including
|
|
# subclasses of the explicitly documented accepted types.
|
|
- class CustomStr(str): pass
|
|
- class CustomBytes(bytes): pass
|
|
- class CustomByteArray(bytearray): pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class CustomStr(str): pass
|
|
+ class CustomBytes(bytes): pass
|
|
+ class CustomByteArray(bytearray): pass
|
|
|
|
factories = [
|
|
bytes,
|
|
@@ -184,30 +312,31 @@ class GeneralFloatCases(unittest.TestCase):
|
|
|
|
def test_floatconversion(self):
|
|
# Make sure that calls to __float__() work properly
|
|
- class Foo1(object):
|
|
- def __float__(self):
|
|
- return 42.
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Foo1(object):
|
|
+ def __float__(self):
|
|
+ return 42.
|
|
|
|
- class Foo2(float):
|
|
- def __float__(self):
|
|
- return 42.
|
|
+ class Foo2(float):
|
|
+ def __float__(self):
|
|
+ return 42.
|
|
|
|
- class Foo3(float):
|
|
- def __new__(cls, value=0.):
|
|
- return float.__new__(cls, 2*value)
|
|
+ class Foo3(float):
|
|
+ def __new__(cls, value=0.):
|
|
+ return float.__new__(cls, 2*value)
|
|
|
|
- def __float__(self):
|
|
- return self
|
|
+ def __float__(self):
|
|
+ return self
|
|
|
|
- class Foo4(float):
|
|
- def __float__(self):
|
|
- return 42
|
|
+ class Foo4(float):
|
|
+ def __float__(self):
|
|
+ return 42
|
|
|
|
- # Issue 5759: __float__ not called on str subclasses (though it is on
|
|
- # unicode subclasses).
|
|
- class FooStr(str):
|
|
- def __float__(self):
|
|
- return float(str(self)) + 1
|
|
+ # Issue 5759: __float__ not called on str subclasses (though it is on
|
|
+ # unicode subclasses).
|
|
+ class FooStr(str):
|
|
+ def __float__(self):
|
|
+ return float(str(self)) + 1
|
|
|
|
self.assertEqual(float(Foo1()), 42.)
|
|
self.assertEqual(float(Foo2()), 42.)
|
|
@@ -216,15 +345,17 @@ class GeneralFloatCases(unittest.TestCase):
|
|
self.assertRaises(TypeError, float, Foo4(42))
|
|
self.assertEqual(float(FooStr('8')), 9.)
|
|
|
|
- class Foo5:
|
|
- def __float__(self):
|
|
- return ""
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Foo5:
|
|
+ def __float__(self):
|
|
+ return ""
|
|
self.assertRaises(TypeError, time.sleep, Foo5())
|
|
|
|
- # Issue #24731
|
|
- class F:
|
|
- def __float__(self):
|
|
- return OtherFloatSubclass(42.)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ # Issue #24731
|
|
+ class F:
|
|
+ def __float__(self):
|
|
+ return OtherFloatSubclass(42.)
|
|
with self.assertWarns(DeprecationWarning):
|
|
self.assertEqual(float(F()), 42.)
|
|
with self.assertWarns(DeprecationWarning):
|
|
@@ -234,18 +365,20 @@ class GeneralFloatCases(unittest.TestCase):
|
|
with self.assertWarns(DeprecationWarning):
|
|
self.assertIs(type(FloatSubclass(F())), FloatSubclass)
|
|
|
|
- class MyIndex:
|
|
- def __init__(self, value):
|
|
- self.value = value
|
|
- def __index__(self):
|
|
- return self.value
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MyIndex:
|
|
+ def __init__(self, value):
|
|
+ self.value = value
|
|
+ def __index__(self):
|
|
+ return self.value
|
|
|
|
self.assertEqual(float(MyIndex(42)), 42.0)
|
|
self.assertRaises(OverflowError, float, MyIndex(2**2000))
|
|
|
|
- class MyInt:
|
|
- def __int__(self):
|
|
- return 42
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MyInt:
|
|
+ def __int__(self):
|
|
+ return 42
|
|
|
|
self.assertRaises(TypeError, float, MyInt())
|
|
|
|
@@ -254,27 +387,30 @@ class GeneralFloatCases(unittest.TestCase):
|
|
float(x='3.14')
|
|
|
|
def test_keywords_in_subclass(self):
|
|
- class subclass(float):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass(float):
|
|
+ pass
|
|
u = subclass(2.5)
|
|
self.assertIs(type(u), subclass)
|
|
self.assertEqual(float(u), 2.5)
|
|
with self.assertRaises(TypeError):
|
|
subclass(x=0)
|
|
|
|
- class subclass_with_init(float):
|
|
- def __init__(self, arg, newarg=None):
|
|
- self.newarg = newarg
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass_with_init(float):
|
|
+ def __init__(self, arg, newarg=None):
|
|
+ self.newarg = newarg
|
|
u = subclass_with_init(2.5, newarg=3)
|
|
self.assertIs(type(u), subclass_with_init)
|
|
self.assertEqual(float(u), 2.5)
|
|
self.assertEqual(u.newarg, 3)
|
|
|
|
- class subclass_with_new(float):
|
|
- def __new__(cls, arg, newarg=None):
|
|
- self = super().__new__(cls, arg)
|
|
- self.newarg = newarg
|
|
- return self
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass_with_new(float):
|
|
+ def __new__(cls, arg, newarg=None):
|
|
+ self = super().__new__(cls, arg)
|
|
+ self.newarg = newarg
|
|
+ return self
|
|
u = subclass_with_new(2.5, newarg=3)
|
|
self.assertIs(type(u), subclass_with_new)
|
|
self.assertEqual(float(u), 2.5)
|
|
@@ -610,17 +746,18 @@ class GeneralFloatCases(unittest.TestCase):
|
|
def test_hash_nan(self):
|
|
value = float('nan')
|
|
self.assertEqual(hash(value), object.__hash__(value))
|
|
- class H:
|
|
- def __hash__(self):
|
|
- return 42
|
|
- class F(float, H):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class H:
|
|
+ def __hash__(self):
|
|
+ return 42
|
|
+ class F(float, H):
|
|
+ pass
|
|
value = F('nan')
|
|
self.assertEqual(hash(value), object.__hash__(value))
|
|
|
|
|
|
@unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__")
|
|
-class FormatFunctionsTestCase(unittest.TestCase):
|
|
+class FormatFunctionsTestCase(__TestCase):
|
|
def test_getformat(self):
|
|
self.assertIn(float.__getformat__('double'),
|
|
['unknown', 'IEEE, big-endian', 'IEEE, little-endian'])
|
|
@@ -645,7 +782,7 @@ LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN))
|
|
# is accident (today).
|
|
# let's also try to guarantee that -0.0 and 0.0 don't get confused.
|
|
|
|
-class IEEEFormatTestCase(unittest.TestCase):
|
|
+class IEEEFormatTestCase(__TestCase):
|
|
|
|
@support.requires_IEEE_754
|
|
def test_double_specials_do_unpack(self):
|
|
@@ -670,7 +807,7 @@ class IEEEFormatTestCase(unittest.TestCase):
|
|
self.assertEqual(struct.pack("<f", 3.40282356e38), struct.pack("<f", FLT_MAX))
|
|
self.assertEqual(struct.pack("<f", -3.40282356e38), struct.pack("<f", -FLT_MAX))
|
|
|
|
-class FormatTestCase(unittest.TestCase):
|
|
+class FormatTestCase(__TestCase):
|
|
|
|
def test_format(self):
|
|
# these should be rewritten to use both format(x, spec) and
|
|
@@ -767,7 +904,7 @@ class FormatTestCase(unittest.TestCase):
|
|
self.assertEqual(format(-123.34, '00.10e'), '-1.2334000000e+02')
|
|
self.assertEqual(format(-123.34, '00.10g'), '-123.34')
|
|
|
|
-class ReprTestCase(unittest.TestCase):
|
|
+class ReprTestCase(__TestCase):
|
|
def test_repr(self):
|
|
with open(os.path.join(os.path.split(__file__)[0],
|
|
'mathdata',
|
|
@@ -832,7 +969,29 @@ class ReprTestCase(unittest.TestCase):
|
|
self.assertEqual(repr(float(negs)), str(float(negs)))
|
|
|
|
@support.requires_IEEE_754
|
|
-class RoundTestCase(unittest.TestCase, FloatsAreIdenticalMixin):
|
|
+class RoundTestCase(__TestCase):
|
|
+ def assertFloatsAreIdentical(self, x, y):
|
|
+ """assert that floats x and y are identical, in the sense that:
|
|
+ (1) both x and y are nans, or
|
|
+ (2) both x and y are infinities, with the same sign, or
|
|
+ (3) both x and y are zeros, with the same sign, or
|
|
+ (4) x and y are both finite and nonzero, and x == y
|
|
+
|
|
+ """
|
|
+ msg = 'floats {!r} and {!r} are not identical'
|
|
+
|
|
+ if isnan(x) or isnan(y):
|
|
+ if isnan(x) and isnan(y):
|
|
+ return
|
|
+ elif x == y:
|
|
+ if x != 0.0:
|
|
+ return
|
|
+ # both zero; check that signs match
|
|
+ elif copysign(1.0, x) == copysign(1.0, y):
|
|
+ return
|
|
+ else:
|
|
+ msg += ': zeros have different signs'
|
|
+ self.fail(msg.format(x, y))
|
|
|
|
def test_inf_nan(self):
|
|
self.assertRaises(OverflowError, round, INF)
|
|
@@ -955,7 +1114,7 @@ class RoundTestCase(unittest.TestCase, FloatsAreIdenticalMixin):
|
|
|
|
# Beginning with Python 2.6 float has cross platform compatible
|
|
# ways to create and represent inf and nan
|
|
-class InfNanTest(unittest.TestCase):
|
|
+class InfNanTest(__TestCase):
|
|
def test_inf_from_str(self):
|
|
self.assertTrue(isinf(float("inf")))
|
|
self.assertTrue(isinf(float("+inf")))
|
|
@@ -1056,12 +1215,35 @@ class InfNanTest(unittest.TestCase):
|
|
|
|
fromHex = float.fromhex
|
|
toHex = float.hex
|
|
-class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase):
|
|
+class HexFloatTestCase(__TestCase):
|
|
MAX = fromHex('0x.fffffffffffff8p+1024') # max normal
|
|
MIN = fromHex('0x1p-1022') # min normal
|
|
TINY = fromHex('0x0.0000000000001p-1022') # min subnormal
|
|
EPS = fromHex('0x0.0000000000001p0') # diff between 1.0 and next float up
|
|
|
|
+ def assertFloatsAreIdentical(self, x, y):
|
|
+ """assert that floats x and y are identical, in the sense that:
|
|
+ (1) both x and y are nans, or
|
|
+ (2) both x and y are infinities, with the same sign, or
|
|
+ (3) both x and y are zeros, with the same sign, or
|
|
+ (4) x and y are both finite and nonzero, and x == y
|
|
+
|
|
+ """
|
|
+ msg = 'floats {!r} and {!r} are not identical'
|
|
+
|
|
+ if isnan(x) or isnan(y):
|
|
+ if isnan(x) and isnan(y):
|
|
+ return
|
|
+ elif x == y:
|
|
+ if x != 0.0:
|
|
+ return
|
|
+ # both zero; check that signs match
|
|
+ elif copysign(1.0, x) == copysign(1.0, y):
|
|
+ return
|
|
+ else:
|
|
+ msg += ': zeros have different signs'
|
|
+ self.fail(msg.format(x, y))
|
|
+
|
|
def identical(self, x, y):
|
|
self.assertFloatsAreIdentical(x, y)
|
|
|
|
@@ -1482,17 +1664,19 @@ class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase):
|
|
self.identical(x, fromHex(toHex(x)))
|
|
|
|
def test_subclass(self):
|
|
- class F(float):
|
|
- def __new__(cls, value):
|
|
- return float.__new__(cls, value + 1)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class F(float):
|
|
+ def __new__(cls, value):
|
|
+ return float.__new__(cls, value + 1)
|
|
|
|
f = F.fromhex((1.5).hex())
|
|
self.assertIs(type(f), F)
|
|
self.assertEqual(f, 2.5)
|
|
|
|
- class F2(float):
|
|
- def __init__(self, value):
|
|
- self.foo = 'bar'
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class F2(float):
|
|
+ def __init__(self, value):
|
|
+ self.foo = 'bar'
|
|
|
|
f = F2.fromhex((1.5).hex())
|
|
self.assertIs(type(f), F2)
|
|
@@ -1500,5 +1684,5 @@ class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase):
|
|
self.assertEqual(getattr(f, 'foo', 'none'), 'bar')
|
|
|
|
|
|
-if __name__ == '__main__':
|
|
- unittest.main()
|
|
+if __name__ == "__main__":
|
|
+ run_tests()
|