mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Renaming `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option `error_on_graph_break` that has lower priority than `fullgraph`, so that `fullgraph` really returns 1 graph. I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still a private API (there are no internal call sites yet, and there are no significant OSS call sites yet). cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users of `set_fullgraph`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739 Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
839 lines
29 KiB
Diff
839 lines
29 KiB
Diff
diff --git a/test/dynamo/cpython/3_13/test_set.py b/test/dynamo/cpython/3_13/test_set.py
|
|
index d9102eb98a5..c8ee5ca451f 100644
|
|
--- a/test/dynamo/cpython/3_13/test_set.py
|
|
+++ b/test/dynamo/cpython/3_13/test_set.py
|
|
@@ -1,3 +1,56 @@
|
|
+# ======= BEGIN Dynamo patch =======
|
|
+# Owner(s): ["module: dynamo"]
|
|
+
|
|
+# ruff: noqa
|
|
+# flake8: noqa
|
|
+
|
|
+# Test copied from
|
|
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_set.py
|
|
+
|
|
+import sys
|
|
+import torch
|
|
+import torch._dynamo.test_case
|
|
+import unittest
|
|
+from torch._dynamo.test_case import CPythonTestCase
|
|
+from torch.testing._internal.common_utils import run_tests
|
|
+
|
|
+__TestCase = CPythonTestCase
|
|
+
|
|
+# redirect import statements
|
|
+import sys
|
|
+import importlib.abc
|
|
+
|
|
+redirect_imports = (
|
|
+ "test.mapping_tests",
|
|
+ "test.typinganndata",
|
|
+ "test.test_grammar",
|
|
+ "test.test_math",
|
|
+ "test.test_iter",
|
|
+ "test.typinganndata.ann_module",
|
|
+)
|
|
+
|
|
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
|
|
+ def find_spec(self, fullname, path, target=None):
|
|
+ # Check if the import is the problematic one
|
|
+ if fullname in redirect_imports:
|
|
+ try:
|
|
+ # Attempt to import the standalone module
|
|
+ name = fullname.removeprefix("test.")
|
|
+ r = importlib.import_module(name)
|
|
+ # Redirect the module in sys.modules
|
|
+ sys.modules[fullname] = r
|
|
+ # Return a module spec from the found module
|
|
+ return importlib.util.find_spec(name)
|
|
+ except ImportError:
|
|
+ return None
|
|
+ return None
|
|
+
|
|
+# Add the custom finder to sys.meta_path
|
|
+sys.meta_path.insert(0, RedirectImportFinder())
|
|
+
|
|
+
|
|
+# ======= END DYNAMO PATCH =======
|
|
+
|
|
import unittest
|
|
from test import support
|
|
from test.support import warnings_helper
|
|
@@ -38,7 +91,7 @@ class HashCountingInt(int):
|
|
self.hash_count += 1
|
|
return int.__hash__(self)
|
|
|
|
-class TestJointOps:
|
|
+class _TestJointOps:
|
|
# Tests common to both set and frozenset
|
|
|
|
def setUp(self):
|
|
@@ -47,6 +100,7 @@ class TestJointOps:
|
|
self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
|
self.s = self.thetype(word)
|
|
self.d = dict.fromkeys(word)
|
|
+ super().setUp()
|
|
|
|
def test_new_or_init(self):
|
|
self.assertRaises(TypeError, self.thetype, [], 2)
|
|
@@ -261,13 +315,14 @@ class TestJointOps:
|
|
self.assertEqual(self.thetype(it), data - self.thetype((drop,)))
|
|
|
|
def test_deepcopy(self):
|
|
- class Tracer:
|
|
- def __init__(self, value):
|
|
- self.value = value
|
|
- def __hash__(self):
|
|
- return self.value
|
|
- def __deepcopy__(self, memo=None):
|
|
- return Tracer(self.value + 1)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Tracer:
|
|
+ def __init__(self, value):
|
|
+ self.value = value
|
|
+ def __hash__(self):
|
|
+ return self.value
|
|
+ def __deepcopy__(self, memo=None):
|
|
+ return Tracer(self.value + 1)
|
|
t = Tracer(10)
|
|
s = self.thetype([t])
|
|
dup = copy.deepcopy(s)
|
|
@@ -279,8 +334,9 @@ class TestJointOps:
|
|
|
|
def test_gc(self):
|
|
# Create a nest of cycles to exercise overall ref count check
|
|
- class A:
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class A:
|
|
+ pass
|
|
s = set(A() for i in range(1000))
|
|
for elem in s:
|
|
elem.cycle = s
|
|
@@ -289,9 +345,10 @@ class TestJointOps:
|
|
|
|
def test_subclass_with_custom_hash(self):
|
|
# Bug #1257731
|
|
- class H(self.thetype):
|
|
- def __hash__(self):
|
|
- return int(id(self) & 0x7fffffff)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class H(self.thetype):
|
|
+ def __hash__(self):
|
|
+ return int(id(self) & 0x7fffffff)
|
|
s=H()
|
|
f=set()
|
|
f.add(s)
|
|
@@ -342,8 +399,9 @@ class TestJointOps:
|
|
|
|
def test_container_iterator(self):
|
|
# Bug #3680: tp_traverse was not implemented for set iterator object
|
|
- class C(object):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class C(object):
|
|
+ pass
|
|
obj = C()
|
|
ref = weakref.ref(obj)
|
|
container = set([obj, 1])
|
|
@@ -355,7 +413,7 @@ class TestJointOps:
|
|
def test_free_after_iterating(self):
|
|
support.check_free_after_iterating(self, iter, self.thetype)
|
|
|
|
-class TestSet(TestJointOps, unittest.TestCase):
|
|
+class TestSet(_TestJointOps, __TestCase):
|
|
thetype = set
|
|
basetype = set
|
|
|
|
@@ -600,19 +658,20 @@ class TestSet(TestJointOps, unittest.TestCase):
|
|
self.assertRaises(ReferenceError, str, p)
|
|
|
|
def test_rich_compare(self):
|
|
- class TestRichSetCompare:
|
|
- def __gt__(self, some_set):
|
|
- self.gt_called = True
|
|
- return False
|
|
- def __lt__(self, some_set):
|
|
- self.lt_called = True
|
|
- return False
|
|
- def __ge__(self, some_set):
|
|
- self.ge_called = True
|
|
- return False
|
|
- def __le__(self, some_set):
|
|
- self.le_called = True
|
|
- return False
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class TestRichSetCompare:
|
|
+ def __gt__(self, some_set):
|
|
+ self.gt_called = True
|
|
+ return False
|
|
+ def __lt__(self, some_set):
|
|
+ self.lt_called = True
|
|
+ return False
|
|
+ def __ge__(self, some_set):
|
|
+ self.ge_called = True
|
|
+ return False
|
|
+ def __le__(self, some_set):
|
|
+ self.le_called = True
|
|
+ return False
|
|
|
|
# This first tries the builtin rich set comparison, which doesn't know
|
|
# how to handle the custom object. Upon returning NotImplemented, the
|
|
@@ -644,28 +703,31 @@ class TestSetSubclass(TestSet):
|
|
basetype = set
|
|
|
|
def test_keywords_in_subclass(self):
|
|
- class subclass(set):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass(set):
|
|
+ pass
|
|
u = subclass([1, 2])
|
|
self.assertIs(type(u), subclass)
|
|
self.assertEqual(set(u), {1, 2})
|
|
with self.assertRaises(TypeError):
|
|
subclass(sequence=())
|
|
|
|
- class subclass_with_init(set):
|
|
- def __init__(self, arg, newarg=None):
|
|
- super().__init__(arg)
|
|
- self.newarg = newarg
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass_with_init(set):
|
|
+ def __init__(self, arg, newarg=None):
|
|
+ super().__init__(arg)
|
|
+ self.newarg = newarg
|
|
u = subclass_with_init([1, 2], newarg=3)
|
|
self.assertIs(type(u), subclass_with_init)
|
|
self.assertEqual(set(u), {1, 2})
|
|
self.assertEqual(u.newarg, 3)
|
|
|
|
- class subclass_with_new(set):
|
|
- def __new__(cls, arg, newarg=None):
|
|
- self = super().__new__(cls, arg)
|
|
- self.newarg = newarg
|
|
- return self
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass_with_new(set):
|
|
+ def __new__(cls, arg, newarg=None):
|
|
+ self = super().__new__(cls, arg)
|
|
+ self.newarg = newarg
|
|
+ return self
|
|
u = subclass_with_new([1, 2])
|
|
self.assertIs(type(u), subclass_with_new)
|
|
self.assertEqual(set(u), {1, 2})
|
|
@@ -675,7 +737,7 @@ class TestSetSubclass(TestSet):
|
|
subclass_with_new([1, 2], newarg=3)
|
|
|
|
|
|
-class TestFrozenSet(TestJointOps, unittest.TestCase):
|
|
+class TestFrozenSet(_TestJointOps, __TestCase):
|
|
thetype = frozenset
|
|
basetype = frozenset
|
|
|
|
@@ -756,27 +818,30 @@ class TestFrozenSetSubclass(TestFrozenSet):
|
|
basetype = frozenset
|
|
|
|
def test_keywords_in_subclass(self):
|
|
- class subclass(frozenset):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass(frozenset):
|
|
+ pass
|
|
u = subclass([1, 2])
|
|
self.assertIs(type(u), subclass)
|
|
self.assertEqual(set(u), {1, 2})
|
|
with self.assertRaises(TypeError):
|
|
subclass(sequence=())
|
|
|
|
- class subclass_with_init(frozenset):
|
|
- def __init__(self, arg, newarg=None):
|
|
- self.newarg = newarg
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass_with_init(frozenset):
|
|
+ def __init__(self, arg, newarg=None):
|
|
+ self.newarg = newarg
|
|
u = subclass_with_init([1, 2], newarg=3)
|
|
self.assertIs(type(u), subclass_with_init)
|
|
self.assertEqual(set(u), {1, 2})
|
|
self.assertEqual(u.newarg, 3)
|
|
|
|
- class subclass_with_new(frozenset):
|
|
- def __new__(cls, arg, newarg=None):
|
|
- self = super().__new__(cls, arg)
|
|
- self.newarg = newarg
|
|
- return self
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class subclass_with_new(frozenset):
|
|
+ def __new__(cls, arg, newarg=None):
|
|
+ self = super().__new__(cls, arg)
|
|
+ self.newarg = newarg
|
|
+ return self
|
|
u = subclass_with_new([1, 2], newarg=3)
|
|
self.assertIs(type(u), subclass_with_new)
|
|
self.assertEqual(set(u), {1, 2})
|
|
@@ -811,10 +876,17 @@ class TestFrozenSetSubclass(TestFrozenSet):
|
|
class SetSubclassWithSlots(set):
|
|
__slots__ = ('x', 'y', '__dict__')
|
|
|
|
-class TestSetSubclassWithSlots(unittest.TestCase):
|
|
+class TestSetSubclassWithSlots(__TestCase):
|
|
thetype = SetSubclassWithSlots
|
|
- setUp = TestJointOps.setUp
|
|
- test_pickling = TestJointOps.test_pickling
|
|
+ test_pickling = _TestJointOps.test_pickling
|
|
+
|
|
+ def setUp(self):
|
|
+ self.word = word = 'simsalabim'
|
|
+ self.otherword = 'madagascar'
|
|
+ self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
|
+ self.s = self.thetype(word)
|
|
+ self.d = dict.fromkeys(word)
|
|
+ super().setUp()
|
|
|
|
class FrozenSetSubclassWithSlots(frozenset):
|
|
__slots__ = ('x', 'y', '__dict__')
|
|
@@ -828,7 +900,7 @@ empty_set = set()
|
|
|
|
#==============================================================================
|
|
|
|
-class TestBasicOps:
|
|
+class _TestBasicOps:
|
|
|
|
def test_repr(self):
|
|
if self.repr is not None:
|
|
@@ -934,7 +1006,7 @@ class TestBasicOps:
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
|
|
+class TestBasicOpsEmpty(_TestBasicOps, __TestCase):
|
|
def setUp(self):
|
|
self.case = "empty set"
|
|
self.values = []
|
|
@@ -942,10 +1014,11 @@ class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
|
|
self.dup = set(self.values)
|
|
self.length = 0
|
|
self.repr = "set()"
|
|
+ super().setUp()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
|
|
+class TestBasicOpsSingleton(_TestBasicOps, __TestCase):
|
|
def setUp(self):
|
|
self.case = "unit set (number)"
|
|
self.values = [3]
|
|
@@ -953,6 +1026,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
|
|
self.dup = set(self.values)
|
|
self.length = 1
|
|
self.repr = "{3}"
|
|
+ super().setUp()
|
|
|
|
def test_in(self):
|
|
self.assertIn(3, self.set)
|
|
@@ -962,7 +1036,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
|
|
+class TestBasicOpsTuple(_TestBasicOps, __TestCase):
|
|
def setUp(self):
|
|
self.case = "unit set (tuple)"
|
|
self.values = [(0, "zero")]
|
|
@@ -970,6 +1044,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
|
|
self.dup = set(self.values)
|
|
self.length = 1
|
|
self.repr = "{(0, 'zero')}"
|
|
+ super().setUp()
|
|
|
|
def test_in(self):
|
|
self.assertIn((0, "zero"), self.set)
|
|
@@ -979,7 +1054,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
|
|
+class TestBasicOpsTriple(_TestBasicOps, __TestCase):
|
|
def setUp(self):
|
|
self.case = "triple set"
|
|
self.values = [0, "zero", operator.add]
|
|
@@ -987,36 +1062,39 @@ class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
|
|
self.dup = set(self.values)
|
|
self.length = 3
|
|
self.repr = None
|
|
+ super().setUp()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestBasicOpsString(TestBasicOps, unittest.TestCase):
|
|
+class TestBasicOpsString(_TestBasicOps, __TestCase):
|
|
def setUp(self):
|
|
self.case = "string set"
|
|
self.values = ["a", "b", "c"]
|
|
self.set = set(self.values)
|
|
self.dup = set(self.values)
|
|
self.length = 3
|
|
+ super().setUp()
|
|
|
|
def test_repr(self):
|
|
self.check_repr_against_values()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestBasicOpsBytes(TestBasicOps, unittest.TestCase):
|
|
+class TestBasicOpsBytes(_TestBasicOps, __TestCase):
|
|
def setUp(self):
|
|
self.case = "bytes set"
|
|
self.values = [b"a", b"b", b"c"]
|
|
self.set = set(self.values)
|
|
self.dup = set(self.values)
|
|
self.length = 3
|
|
+ super().setUp()
|
|
|
|
def test_repr(self):
|
|
self.check_repr_against_values()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
|
|
+class TestBasicOpsMixedStringBytes(_TestBasicOps, __TestCase):
|
|
def setUp(self):
|
|
self.enterContext(warnings_helper.check_warnings())
|
|
warnings.simplefilter('ignore', BytesWarning)
|
|
@@ -1025,6 +1103,7 @@ class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
|
|
self.set = set(self.values)
|
|
self.dup = set(self.values)
|
|
self.length = 4
|
|
+ super().setUp()
|
|
|
|
def test_repr(self):
|
|
self.check_repr_against_values()
|
|
@@ -1038,7 +1117,7 @@ def baditer():
|
|
def gooditer():
|
|
yield True
|
|
|
|
-class TestExceptionPropagation(unittest.TestCase):
|
|
+class TestExceptionPropagation(__TestCase):
|
|
"""SF 628246: Set constructor should not trap iterator TypeErrors"""
|
|
|
|
def test_instanceWithException(self):
|
|
@@ -1065,7 +1144,7 @@ class TestExceptionPropagation(unittest.TestCase):
|
|
|
|
#==============================================================================
|
|
|
|
-class TestSetOfSets(unittest.TestCase):
|
|
+class TestSetOfSets(__TestCase):
|
|
def test_constructor(self):
|
|
inner = frozenset([1])
|
|
outer = set([inner])
|
|
@@ -1078,9 +1157,10 @@ class TestSetOfSets(unittest.TestCase):
|
|
|
|
#==============================================================================
|
|
|
|
-class TestBinaryOps(unittest.TestCase):
|
|
+class TestBinaryOps(__TestCase):
|
|
def setUp(self):
|
|
self.set = set((2, 4, 6))
|
|
+ super().setUp()
|
|
|
|
def test_eq(self): # SF bug 643115
|
|
self.assertEqual(self.set, set({2:1,4:3,6:5}))
|
|
@@ -1151,9 +1231,10 @@ class TestBinaryOps(unittest.TestCase):
|
|
|
|
#==============================================================================
|
|
|
|
-class TestUpdateOps(unittest.TestCase):
|
|
+class TestUpdateOps(__TestCase):
|
|
def setUp(self):
|
|
self.set = set((2, 4, 6))
|
|
+ super().setUp()
|
|
|
|
def test_union_subset(self):
|
|
self.set |= set([2])
|
|
@@ -1237,10 +1318,11 @@ class TestUpdateOps(unittest.TestCase):
|
|
|
|
#==============================================================================
|
|
|
|
-class TestMutate(unittest.TestCase):
|
|
+class TestMutate(__TestCase):
|
|
def setUp(self):
|
|
self.values = ["a", "b", "c"]
|
|
self.set = set(self.values)
|
|
+ super().setUp()
|
|
|
|
def test_add_present(self):
|
|
self.set.add("c")
|
|
@@ -1311,7 +1393,7 @@ class TestMutate(unittest.TestCase):
|
|
|
|
#==============================================================================
|
|
|
|
-class TestSubsets:
|
|
+class _TestSubsets:
|
|
|
|
case2method = {"<=": "issubset",
|
|
">=": "issuperset",
|
|
@@ -1334,22 +1416,22 @@ class TestSubsets:
|
|
result = eval("x" + case + "y", locals())
|
|
self.assertEqual(result, expected)
|
|
# Test the "friendly" method-name spelling, if one exists.
|
|
- if case in TestSubsets.case2method:
|
|
- method = getattr(x, TestSubsets.case2method[case])
|
|
+ if case in _TestSubsets.case2method:
|
|
+ method = getattr(x, _TestSubsets.case2method[case])
|
|
result = method(y)
|
|
self.assertEqual(result, expected)
|
|
|
|
# Now do the same for the operands reversed.
|
|
- rcase = TestSubsets.reverse[case]
|
|
+ rcase = _TestSubsets.reverse[case]
|
|
result = eval("y" + rcase + "x", locals())
|
|
self.assertEqual(result, expected)
|
|
- if rcase in TestSubsets.case2method:
|
|
- method = getattr(y, TestSubsets.case2method[rcase])
|
|
+ if rcase in _TestSubsets.case2method:
|
|
+ method = getattr(y, _TestSubsets.case2method[rcase])
|
|
result = method(x)
|
|
self.assertEqual(result, expected)
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
|
|
+class TestSubsetEqualEmpty(_TestSubsets, __TestCase):
|
|
left = set()
|
|
right = set()
|
|
name = "both empty"
|
|
@@ -1357,7 +1439,7 @@ class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
|
|
+class TestSubsetEqualNonEmpty(_TestSubsets, __TestCase):
|
|
left = set([1, 2])
|
|
right = set([1, 2])
|
|
name = "equal pair"
|
|
@@ -1365,7 +1447,7 @@ class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
|
|
+class TestSubsetEmptyNonEmpty(_TestSubsets, __TestCase):
|
|
left = set()
|
|
right = set([1, 2])
|
|
name = "one empty, one non-empty"
|
|
@@ -1373,7 +1455,7 @@ class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestSubsetPartial(TestSubsets, unittest.TestCase):
|
|
+class TestSubsetPartial(_TestSubsets, __TestCase):
|
|
left = set([1])
|
|
right = set([1, 2])
|
|
name = "one a non-empty proper subset of other"
|
|
@@ -1381,7 +1463,7 @@ class TestSubsetPartial(TestSubsets, unittest.TestCase):
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
|
|
+class TestSubsetNonOverlap(_TestSubsets, __TestCase):
|
|
left = set([1])
|
|
right = set([2])
|
|
name = "neither empty, neither contains"
|
|
@@ -1389,7 +1471,7 @@ class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
|
|
|
|
#==============================================================================
|
|
|
|
-class TestOnlySetsInBinaryOps:
|
|
+class _TestOnlySetsInBinaryOps:
|
|
|
|
def test_eq_ne(self):
|
|
# Unlike the others, this is testing that == and != *are* allowed.
|
|
@@ -1505,47 +1587,52 @@ class TestOnlySetsInBinaryOps:
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestOnlySetsNumeric(TestOnlySetsInBinaryOps, unittest.TestCase):
|
|
+class TestOnlySetsNumeric(_TestOnlySetsInBinaryOps, __TestCase):
|
|
def setUp(self):
|
|
self.set = set((1, 2, 3))
|
|
self.other = 19
|
|
self.otherIsIterable = False
|
|
+ super().setUp()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestOnlySetsDict(TestOnlySetsInBinaryOps, unittest.TestCase):
|
|
+class TestOnlySetsDict(_TestOnlySetsInBinaryOps, __TestCase):
|
|
def setUp(self):
|
|
self.set = set((1, 2, 3))
|
|
self.other = {1:2, 3:4}
|
|
self.otherIsIterable = True
|
|
+ super().setUp()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestOnlySetsOperator(TestOnlySetsInBinaryOps, unittest.TestCase):
|
|
+class TestOnlySetsOperator(_TestOnlySetsInBinaryOps, __TestCase):
|
|
def setUp(self):
|
|
self.set = set((1, 2, 3))
|
|
self.other = operator.add
|
|
self.otherIsIterable = False
|
|
+ super().setUp()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestOnlySetsTuple(TestOnlySetsInBinaryOps, unittest.TestCase):
|
|
+class TestOnlySetsTuple(_TestOnlySetsInBinaryOps, __TestCase):
|
|
def setUp(self):
|
|
self.set = set((1, 2, 3))
|
|
self.other = (2, 4, 6)
|
|
self.otherIsIterable = True
|
|
+ super().setUp()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestOnlySetsString(TestOnlySetsInBinaryOps, unittest.TestCase):
|
|
+class TestOnlySetsString(_TestOnlySetsInBinaryOps, __TestCase):
|
|
def setUp(self):
|
|
self.set = set((1, 2, 3))
|
|
self.other = 'abc'
|
|
self.otherIsIterable = True
|
|
+ super().setUp()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
|
|
+class TestOnlySetsGenerator(_TestOnlySetsInBinaryOps, __TestCase):
|
|
def setUp(self):
|
|
def gen():
|
|
for i in range(0, 10, 2):
|
|
@@ -1553,10 +1640,11 @@ class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
|
|
self.set = set((1, 2, 3))
|
|
self.other = gen()
|
|
self.otherIsIterable = True
|
|
+ super().setUp()
|
|
|
|
#==============================================================================
|
|
|
|
-class TestCopying:
|
|
+class _TestCopying:
|
|
|
|
def test_copy(self):
|
|
dup = self.set.copy()
|
|
@@ -1577,40 +1665,46 @@ class TestCopying:
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestCopyingEmpty(TestCopying, unittest.TestCase):
|
|
+class TestCopyingEmpty(_TestCopying, __TestCase):
|
|
def setUp(self):
|
|
self.set = set()
|
|
+ super().setUp()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestCopyingSingleton(TestCopying, unittest.TestCase):
|
|
+class TestCopyingSingleton(_TestCopying, __TestCase):
|
|
def setUp(self):
|
|
self.set = set(["hello"])
|
|
+ super().setUp()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestCopyingTriple(TestCopying, unittest.TestCase):
|
|
+class TestCopyingTriple(_TestCopying, __TestCase):
|
|
def setUp(self):
|
|
self.set = set(["zero", 0, None])
|
|
+ super().setUp()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestCopyingTuple(TestCopying, unittest.TestCase):
|
|
+class TestCopyingTuple(_TestCopying, __TestCase):
|
|
def setUp(self):
|
|
self.set = set([(1, 2)])
|
|
+ super().setUp()
|
|
|
|
#------------------------------------------------------------------------------
|
|
|
|
-class TestCopyingNested(TestCopying, unittest.TestCase):
|
|
+class TestCopyingNested(_TestCopying, __TestCase):
|
|
def setUp(self):
|
|
self.set = set([((1, 2), (3, 4))])
|
|
+ super().setUp()
|
|
|
|
#==============================================================================
|
|
|
|
-class TestIdentities(unittest.TestCase):
|
|
+class TestIdentities(__TestCase):
|
|
def setUp(self):
|
|
self.a = set('abracadabra')
|
|
self.b = set('alacazam')
|
|
+ super().setUp()
|
|
|
|
def test_binopsVsSubsets(self):
|
|
a, b = self.a, self.b
|
|
@@ -1727,7 +1821,7 @@ def L(seqn):
|
|
'Test multiple tiers of iterators'
|
|
return chain(map(lambda x:x, R(Ig(G(seqn)))))
|
|
|
|
-class TestVariousIteratorArgs(unittest.TestCase):
|
|
+class TestVariousIteratorArgs(__TestCase):
|
|
|
|
def test_constructor(self):
|
|
for cons in (set, frozenset):
|
|
@@ -1785,7 +1879,7 @@ class bad_dict_clear:
|
|
def __hash__(self):
|
|
return 0
|
|
|
|
-class TestWeirdBugs(unittest.TestCase):
|
|
+class TestWeirdBugs(__TestCase):
|
|
def test_8420_set_merge(self):
|
|
# This used to segfault
|
|
global be_bad, set2, dict2
|
|
@@ -1813,12 +1907,13 @@ class TestWeirdBugs(unittest.TestCase):
|
|
list(si)
|
|
|
|
def test_merge_and_mutate(self):
|
|
- class X:
|
|
- def __hash__(self):
|
|
- return hash(0)
|
|
- def __eq__(self, o):
|
|
- other.clear()
|
|
- return False
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class X:
|
|
+ def __hash__(self):
|
|
+ return hash(0)
|
|
+ def __eq__(self, o):
|
|
+ other.clear()
|
|
+ return False
|
|
|
|
other = set()
|
|
other = {X() for i in range(10)}
|
|
@@ -1826,24 +1921,25 @@ class TestWeirdBugs(unittest.TestCase):
|
|
s.update(other)
|
|
|
|
|
|
-class TestOperationsMutating:
|
|
+class _TestOperationsMutating:
|
|
"""Regression test for bpo-46615"""
|
|
|
|
constructor1 = None
|
|
constructor2 = None
|
|
|
|
def make_sets_of_bad_objects(self):
|
|
- class Bad:
|
|
- def __eq__(self, other):
|
|
- if not enabled:
|
|
- return False
|
|
- if randrange(20) == 0:
|
|
- set1.clear()
|
|
- if randrange(20) == 0:
|
|
- set2.clear()
|
|
- return bool(randrange(2))
|
|
- def __hash__(self):
|
|
- return randrange(2)
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Bad:
|
|
+ def __eq__(self, other):
|
|
+ if not enabled:
|
|
+ return False
|
|
+ if randrange(20) == 0:
|
|
+ set1.clear()
|
|
+ if randrange(20) == 0:
|
|
+ set2.clear()
|
|
+ return bool(randrange(2))
|
|
+ def __hash__(self):
|
|
+ return randrange(2)
|
|
# Don't behave poorly during construction.
|
|
enabled = False
|
|
set1 = self.constructor1(Bad() for _ in range(randrange(50)))
|
|
@@ -1862,7 +1958,7 @@ class TestOperationsMutating:
|
|
self.assertIn("changed size during iteration", str(e))
|
|
|
|
|
|
-class TestBinaryOpsMutating(TestOperationsMutating):
|
|
+class _TestBinaryOpsMutating(_TestOperationsMutating):
|
|
|
|
def test_eq_with_mutation(self):
|
|
self.check_set_op_does_not_crash(lambda a, b: a == b)
|
|
@@ -1933,24 +2029,24 @@ class TestBinaryOpsMutating(TestOperationsMutating):
|
|
self.check_set_op_does_not_crash(f3)
|
|
|
|
|
|
-class TestBinaryOpsMutating_Set_Set(TestBinaryOpsMutating, unittest.TestCase):
|
|
+class TestBinaryOpsMutating_Set_Set(_TestBinaryOpsMutating, __TestCase):
|
|
constructor1 = set
|
|
constructor2 = set
|
|
|
|
-class TestBinaryOpsMutating_Subclass_Subclass(TestBinaryOpsMutating, unittest.TestCase):
|
|
+class TestBinaryOpsMutating_Subclass_Subclass(_TestBinaryOpsMutating, __TestCase):
|
|
constructor1 = SetSubclass
|
|
constructor2 = SetSubclass
|
|
|
|
-class TestBinaryOpsMutating_Set_Subclass(TestBinaryOpsMutating, unittest.TestCase):
|
|
+class TestBinaryOpsMutating_Set_Subclass(_TestBinaryOpsMutating, __TestCase):
|
|
constructor1 = set
|
|
constructor2 = SetSubclass
|
|
|
|
-class TestBinaryOpsMutating_Subclass_Set(TestBinaryOpsMutating, unittest.TestCase):
|
|
+class TestBinaryOpsMutating_Subclass_Set(_TestBinaryOpsMutating, __TestCase):
|
|
constructor1 = SetSubclass
|
|
constructor2 = set
|
|
|
|
|
|
-class TestMethodsMutating(TestOperationsMutating):
|
|
+class _TestMethodsMutating(_TestOperationsMutating):
|
|
|
|
def test_issubset_with_mutation(self):
|
|
self.check_set_op_does_not_crash(set.issubset)
|
|
@@ -1986,27 +2082,27 @@ class TestMethodsMutating(TestOperationsMutating):
|
|
self.check_set_op_does_not_crash(set.update)
|
|
|
|
|
|
-class TestMethodsMutating_Set_Set(TestMethodsMutating, unittest.TestCase):
|
|
+class TestMethodsMutating_Set_Set(_TestMethodsMutating, __TestCase):
|
|
constructor1 = set
|
|
constructor2 = set
|
|
|
|
-class TestMethodsMutating_Subclass_Subclass(TestMethodsMutating, unittest.TestCase):
|
|
+class TestMethodsMutating_Subclass_Subclass(_TestMethodsMutating, __TestCase):
|
|
constructor1 = SetSubclass
|
|
constructor2 = SetSubclass
|
|
|
|
-class TestMethodsMutating_Set_Subclass(TestMethodsMutating, unittest.TestCase):
|
|
+class TestMethodsMutating_Set_Subclass(_TestMethodsMutating, __TestCase):
|
|
constructor1 = set
|
|
constructor2 = SetSubclass
|
|
|
|
-class TestMethodsMutating_Subclass_Set(TestMethodsMutating, unittest.TestCase):
|
|
+class TestMethodsMutating_Subclass_Set(_TestMethodsMutating, __TestCase):
|
|
constructor1 = SetSubclass
|
|
constructor2 = set
|
|
|
|
-class TestMethodsMutating_Set_Dict(TestMethodsMutating, unittest.TestCase):
|
|
+class TestMethodsMutating_Set_Dict(_TestMethodsMutating, __TestCase):
|
|
constructor1 = set
|
|
constructor2 = dict.fromkeys
|
|
|
|
-class TestMethodsMutating_Set_List(TestMethodsMutating, unittest.TestCase):
|
|
+class TestMethodsMutating_Set_List(_TestMethodsMutating, __TestCase):
|
|
constructor1 = set
|
|
constructor2 = list
|
|
|
|
@@ -2068,7 +2164,7 @@ def faces(G):
|
|
return f
|
|
|
|
|
|
-class TestGraphs(unittest.TestCase):
|
|
+class TestGraphs(__TestCase):
|
|
|
|
def test_cube(self):
|
|
|
|
@@ -2118,4 +2214,4 @@ class TestGraphs(unittest.TestCase):
|
|
#==============================================================================
|
|
|
|
if __name__ == "__main__":
|
|
- unittest.main()
|
|
+ run_tests()
|