pytorch/test/dynamo/cpython/3_13/test_set.diff
William Wen 8678d831c4 [dynamo] rename set_fullgraph to error_on_graph_break (#161739)
Rename `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph` really returns exactly one graph.

I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still a private API (there are no internal call sites yet, and no significant OSS call sites either).

 cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users for `set_fullgraph`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
2025-09-04 01:15:06 +00:00

839 lines
29 KiB
Diff

diff --git a/test/dynamo/cpython/3_13/test_set.py b/test/dynamo/cpython/3_13/test_set.py
index d9102eb98a5..c8ee5ca451f 100644
--- a/test/dynamo/cpython/3_13/test_set.py
+++ b/test/dynamo/cpython/3_13/test_set.py
@@ -1,3 +1,56 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_set.py
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+ "test.mapping_tests",
+ "test.typinganndata",
+ "test.test_grammar",
+ "test.test_math",
+ "test.test_iter",
+ "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
+ if fullname in redirect_imports:
+ try:
+ # Attempt to import the standalone module
+ name = fullname.removeprefix("test.")
+ r = importlib.import_module(name)
+ # Redirect the module in sys.modules
+ sys.modules[fullname] = r
+ # Return a module spec from the found module
+ return importlib.util.find_spec(name)
+ except ImportError:
+ return None
+ return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
import unittest
from test import support
from test.support import warnings_helper
@@ -38,7 +91,7 @@ class HashCountingInt(int):
self.hash_count += 1
return int.__hash__(self)
-class TestJointOps:
+class _TestJointOps:
# Tests common to both set and frozenset
def setUp(self):
@@ -47,6 +100,7 @@ class TestJointOps:
self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
self.s = self.thetype(word)
self.d = dict.fromkeys(word)
+ super().setUp()
def test_new_or_init(self):
self.assertRaises(TypeError, self.thetype, [], 2)
@@ -261,13 +315,14 @@ class TestJointOps:
self.assertEqual(self.thetype(it), data - self.thetype((drop,)))
def test_deepcopy(self):
- class Tracer:
- def __init__(self, value):
- self.value = value
- def __hash__(self):
- return self.value
- def __deepcopy__(self, memo=None):
- return Tracer(self.value + 1)
+ with torch._dynamo.error_on_graph_break(False):
+ class Tracer:
+ def __init__(self, value):
+ self.value = value
+ def __hash__(self):
+ return self.value
+ def __deepcopy__(self, memo=None):
+ return Tracer(self.value + 1)
t = Tracer(10)
s = self.thetype([t])
dup = copy.deepcopy(s)
@@ -279,8 +334,9 @@ class TestJointOps:
def test_gc(self):
# Create a nest of cycles to exercise overall ref count check
- class A:
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ pass
s = set(A() for i in range(1000))
for elem in s:
elem.cycle = s
@@ -289,9 +345,10 @@ class TestJointOps:
def test_subclass_with_custom_hash(self):
# Bug #1257731
- class H(self.thetype):
- def __hash__(self):
- return int(id(self) & 0x7fffffff)
+ with torch._dynamo.error_on_graph_break(False):
+ class H(self.thetype):
+ def __hash__(self):
+ return int(id(self) & 0x7fffffff)
s=H()
f=set()
f.add(s)
@@ -342,8 +399,9 @@ class TestJointOps:
def test_container_iterator(self):
# Bug #3680: tp_traverse was not implemented for set iterator object
- class C(object):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ pass
obj = C()
ref = weakref.ref(obj)
container = set([obj, 1])
@@ -355,7 +413,7 @@ class TestJointOps:
def test_free_after_iterating(self):
support.check_free_after_iterating(self, iter, self.thetype)
-class TestSet(TestJointOps, unittest.TestCase):
+class TestSet(_TestJointOps, __TestCase):
thetype = set
basetype = set
@@ -600,19 +658,20 @@ class TestSet(TestJointOps, unittest.TestCase):
self.assertRaises(ReferenceError, str, p)
def test_rich_compare(self):
- class TestRichSetCompare:
- def __gt__(self, some_set):
- self.gt_called = True
- return False
- def __lt__(self, some_set):
- self.lt_called = True
- return False
- def __ge__(self, some_set):
- self.ge_called = True
- return False
- def __le__(self, some_set):
- self.le_called = True
- return False
+ with torch._dynamo.error_on_graph_break(False):
+ class TestRichSetCompare:
+ def __gt__(self, some_set):
+ self.gt_called = True
+ return False
+ def __lt__(self, some_set):
+ self.lt_called = True
+ return False
+ def __ge__(self, some_set):
+ self.ge_called = True
+ return False
+ def __le__(self, some_set):
+ self.le_called = True
+ return False
# This first tries the builtin rich set comparison, which doesn't know
# how to handle the custom object. Upon returning NotImplemented, the
@@ -644,28 +703,31 @@ class TestSetSubclass(TestSet):
basetype = set
def test_keywords_in_subclass(self):
- class subclass(set):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(set):
+ pass
u = subclass([1, 2])
self.assertIs(type(u), subclass)
self.assertEqual(set(u), {1, 2})
with self.assertRaises(TypeError):
subclass(sequence=())
- class subclass_with_init(set):
- def __init__(self, arg, newarg=None):
- super().__init__(arg)
- self.newarg = newarg
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_init(set):
+ def __init__(self, arg, newarg=None):
+ super().__init__(arg)
+ self.newarg = newarg
u = subclass_with_init([1, 2], newarg=3)
self.assertIs(type(u), subclass_with_init)
self.assertEqual(set(u), {1, 2})
self.assertEqual(u.newarg, 3)
- class subclass_with_new(set):
- def __new__(cls, arg, newarg=None):
- self = super().__new__(cls, arg)
- self.newarg = newarg
- return self
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_new(set):
+ def __new__(cls, arg, newarg=None):
+ self = super().__new__(cls, arg)
+ self.newarg = newarg
+ return self
u = subclass_with_new([1, 2])
self.assertIs(type(u), subclass_with_new)
self.assertEqual(set(u), {1, 2})
@@ -675,7 +737,7 @@ class TestSetSubclass(TestSet):
subclass_with_new([1, 2], newarg=3)
-class TestFrozenSet(TestJointOps, unittest.TestCase):
+class TestFrozenSet(_TestJointOps, __TestCase):
thetype = frozenset
basetype = frozenset
@@ -756,27 +818,30 @@ class TestFrozenSetSubclass(TestFrozenSet):
basetype = frozenset
def test_keywords_in_subclass(self):
- class subclass(frozenset):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass(frozenset):
+ pass
u = subclass([1, 2])
self.assertIs(type(u), subclass)
self.assertEqual(set(u), {1, 2})
with self.assertRaises(TypeError):
subclass(sequence=())
- class subclass_with_init(frozenset):
- def __init__(self, arg, newarg=None):
- self.newarg = newarg
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_init(frozenset):
+ def __init__(self, arg, newarg=None):
+ self.newarg = newarg
u = subclass_with_init([1, 2], newarg=3)
self.assertIs(type(u), subclass_with_init)
self.assertEqual(set(u), {1, 2})
self.assertEqual(u.newarg, 3)
- class subclass_with_new(frozenset):
- def __new__(cls, arg, newarg=None):
- self = super().__new__(cls, arg)
- self.newarg = newarg
- return self
+ with torch._dynamo.error_on_graph_break(False):
+ class subclass_with_new(frozenset):
+ def __new__(cls, arg, newarg=None):
+ self = super().__new__(cls, arg)
+ self.newarg = newarg
+ return self
u = subclass_with_new([1, 2], newarg=3)
self.assertIs(type(u), subclass_with_new)
self.assertEqual(set(u), {1, 2})
@@ -811,10 +876,17 @@ class TestFrozenSetSubclass(TestFrozenSet):
class SetSubclassWithSlots(set):
__slots__ = ('x', 'y', '__dict__')
-class TestSetSubclassWithSlots(unittest.TestCase):
+class TestSetSubclassWithSlots(__TestCase):
thetype = SetSubclassWithSlots
- setUp = TestJointOps.setUp
- test_pickling = TestJointOps.test_pickling
+ test_pickling = _TestJointOps.test_pickling
+
+ def setUp(self):
+ self.word = word = 'simsalabim'
+ self.otherword = 'madagascar'
+ self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+ self.s = self.thetype(word)
+ self.d = dict.fromkeys(word)
+ super().setUp()
class FrozenSetSubclassWithSlots(frozenset):
__slots__ = ('x', 'y', '__dict__')
@@ -828,7 +900,7 @@ empty_set = set()
#==============================================================================
-class TestBasicOps:
+class _TestBasicOps:
def test_repr(self):
if self.repr is not None:
@@ -934,7 +1006,7 @@ class TestBasicOps:
#------------------------------------------------------------------------------
-class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
+class TestBasicOpsEmpty(_TestBasicOps, __TestCase):
def setUp(self):
self.case = "empty set"
self.values = []
@@ -942,10 +1014,11 @@ class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
self.dup = set(self.values)
self.length = 0
self.repr = "set()"
+ super().setUp()
#------------------------------------------------------------------------------
-class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
+class TestBasicOpsSingleton(_TestBasicOps, __TestCase):
def setUp(self):
self.case = "unit set (number)"
self.values = [3]
@@ -953,6 +1026,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
self.dup = set(self.values)
self.length = 1
self.repr = "{3}"
+ super().setUp()
def test_in(self):
self.assertIn(3, self.set)
@@ -962,7 +1036,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
#------------------------------------------------------------------------------
-class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
+class TestBasicOpsTuple(_TestBasicOps, __TestCase):
def setUp(self):
self.case = "unit set (tuple)"
self.values = [(0, "zero")]
@@ -970,6 +1044,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
self.dup = set(self.values)
self.length = 1
self.repr = "{(0, 'zero')}"
+ super().setUp()
def test_in(self):
self.assertIn((0, "zero"), self.set)
@@ -979,7 +1054,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
#------------------------------------------------------------------------------
-class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
+class TestBasicOpsTriple(_TestBasicOps, __TestCase):
def setUp(self):
self.case = "triple set"
self.values = [0, "zero", operator.add]
@@ -987,36 +1062,39 @@ class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
self.dup = set(self.values)
self.length = 3
self.repr = None
+ super().setUp()
#------------------------------------------------------------------------------
-class TestBasicOpsString(TestBasicOps, unittest.TestCase):
+class TestBasicOpsString(_TestBasicOps, __TestCase):
def setUp(self):
self.case = "string set"
self.values = ["a", "b", "c"]
self.set = set(self.values)
self.dup = set(self.values)
self.length = 3
+ super().setUp()
def test_repr(self):
self.check_repr_against_values()
#------------------------------------------------------------------------------
-class TestBasicOpsBytes(TestBasicOps, unittest.TestCase):
+class TestBasicOpsBytes(_TestBasicOps, __TestCase):
def setUp(self):
self.case = "bytes set"
self.values = [b"a", b"b", b"c"]
self.set = set(self.values)
self.dup = set(self.values)
self.length = 3
+ super().setUp()
def test_repr(self):
self.check_repr_against_values()
#------------------------------------------------------------------------------
-class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
+class TestBasicOpsMixedStringBytes(_TestBasicOps, __TestCase):
def setUp(self):
self.enterContext(warnings_helper.check_warnings())
warnings.simplefilter('ignore', BytesWarning)
@@ -1025,6 +1103,7 @@ class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
self.set = set(self.values)
self.dup = set(self.values)
self.length = 4
+ super().setUp()
def test_repr(self):
self.check_repr_against_values()
@@ -1038,7 +1117,7 @@ def baditer():
def gooditer():
yield True
-class TestExceptionPropagation(unittest.TestCase):
+class TestExceptionPropagation(__TestCase):
"""SF 628246: Set constructor should not trap iterator TypeErrors"""
def test_instanceWithException(self):
@@ -1065,7 +1144,7 @@ class TestExceptionPropagation(unittest.TestCase):
#==============================================================================
-class TestSetOfSets(unittest.TestCase):
+class TestSetOfSets(__TestCase):
def test_constructor(self):
inner = frozenset([1])
outer = set([inner])
@@ -1078,9 +1157,10 @@ class TestSetOfSets(unittest.TestCase):
#==============================================================================
-class TestBinaryOps(unittest.TestCase):
+class TestBinaryOps(__TestCase):
def setUp(self):
self.set = set((2, 4, 6))
+ super().setUp()
def test_eq(self): # SF bug 643115
self.assertEqual(self.set, set({2:1,4:3,6:5}))
@@ -1151,9 +1231,10 @@ class TestBinaryOps(unittest.TestCase):
#==============================================================================
-class TestUpdateOps(unittest.TestCase):
+class TestUpdateOps(__TestCase):
def setUp(self):
self.set = set((2, 4, 6))
+ super().setUp()
def test_union_subset(self):
self.set |= set([2])
@@ -1237,10 +1318,11 @@ class TestUpdateOps(unittest.TestCase):
#==============================================================================
-class TestMutate(unittest.TestCase):
+class TestMutate(__TestCase):
def setUp(self):
self.values = ["a", "b", "c"]
self.set = set(self.values)
+ super().setUp()
def test_add_present(self):
self.set.add("c")
@@ -1311,7 +1393,7 @@ class TestMutate(unittest.TestCase):
#==============================================================================
-class TestSubsets:
+class _TestSubsets:
case2method = {"<=": "issubset",
">=": "issuperset",
@@ -1334,22 +1416,22 @@ class TestSubsets:
result = eval("x" + case + "y", locals())
self.assertEqual(result, expected)
# Test the "friendly" method-name spelling, if one exists.
- if case in TestSubsets.case2method:
- method = getattr(x, TestSubsets.case2method[case])
+ if case in _TestSubsets.case2method:
+ method = getattr(x, _TestSubsets.case2method[case])
result = method(y)
self.assertEqual(result, expected)
# Now do the same for the operands reversed.
- rcase = TestSubsets.reverse[case]
+ rcase = _TestSubsets.reverse[case]
result = eval("y" + rcase + "x", locals())
self.assertEqual(result, expected)
- if rcase in TestSubsets.case2method:
- method = getattr(y, TestSubsets.case2method[rcase])
+ if rcase in _TestSubsets.case2method:
+ method = getattr(y, _TestSubsets.case2method[rcase])
result = method(x)
self.assertEqual(result, expected)
#------------------------------------------------------------------------------
-class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
+class TestSubsetEqualEmpty(_TestSubsets, __TestCase):
left = set()
right = set()
name = "both empty"
@@ -1357,7 +1439,7 @@ class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
#------------------------------------------------------------------------------
-class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
+class TestSubsetEqualNonEmpty(_TestSubsets, __TestCase):
left = set([1, 2])
right = set([1, 2])
name = "equal pair"
@@ -1365,7 +1447,7 @@ class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
#------------------------------------------------------------------------------
-class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
+class TestSubsetEmptyNonEmpty(_TestSubsets, __TestCase):
left = set()
right = set([1, 2])
name = "one empty, one non-empty"
@@ -1373,7 +1455,7 @@ class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
#------------------------------------------------------------------------------
-class TestSubsetPartial(TestSubsets, unittest.TestCase):
+class TestSubsetPartial(_TestSubsets, __TestCase):
left = set([1])
right = set([1, 2])
name = "one a non-empty proper subset of other"
@@ -1381,7 +1463,7 @@ class TestSubsetPartial(TestSubsets, unittest.TestCase):
#------------------------------------------------------------------------------
-class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
+class TestSubsetNonOverlap(_TestSubsets, __TestCase):
left = set([1])
right = set([2])
name = "neither empty, neither contains"
@@ -1389,7 +1471,7 @@ class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
#==============================================================================
-class TestOnlySetsInBinaryOps:
+class _TestOnlySetsInBinaryOps:
def test_eq_ne(self):
# Unlike the others, this is testing that == and != *are* allowed.
@@ -1505,47 +1587,52 @@ class TestOnlySetsInBinaryOps:
#------------------------------------------------------------------------------
-class TestOnlySetsNumeric(TestOnlySetsInBinaryOps, unittest.TestCase):
+class TestOnlySetsNumeric(_TestOnlySetsInBinaryOps, __TestCase):
def setUp(self):
self.set = set((1, 2, 3))
self.other = 19
self.otherIsIterable = False
+ super().setUp()
#------------------------------------------------------------------------------
-class TestOnlySetsDict(TestOnlySetsInBinaryOps, unittest.TestCase):
+class TestOnlySetsDict(_TestOnlySetsInBinaryOps, __TestCase):
def setUp(self):
self.set = set((1, 2, 3))
self.other = {1:2, 3:4}
self.otherIsIterable = True
+ super().setUp()
#------------------------------------------------------------------------------
-class TestOnlySetsOperator(TestOnlySetsInBinaryOps, unittest.TestCase):
+class TestOnlySetsOperator(_TestOnlySetsInBinaryOps, __TestCase):
def setUp(self):
self.set = set((1, 2, 3))
self.other = operator.add
self.otherIsIterable = False
+ super().setUp()
#------------------------------------------------------------------------------
-class TestOnlySetsTuple(TestOnlySetsInBinaryOps, unittest.TestCase):
+class TestOnlySetsTuple(_TestOnlySetsInBinaryOps, __TestCase):
def setUp(self):
self.set = set((1, 2, 3))
self.other = (2, 4, 6)
self.otherIsIterable = True
+ super().setUp()
#------------------------------------------------------------------------------
-class TestOnlySetsString(TestOnlySetsInBinaryOps, unittest.TestCase):
+class TestOnlySetsString(_TestOnlySetsInBinaryOps, __TestCase):
def setUp(self):
self.set = set((1, 2, 3))
self.other = 'abc'
self.otherIsIterable = True
+ super().setUp()
#------------------------------------------------------------------------------
-class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
+class TestOnlySetsGenerator(_TestOnlySetsInBinaryOps, __TestCase):
def setUp(self):
def gen():
for i in range(0, 10, 2):
@@ -1553,10 +1640,11 @@ class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
self.set = set((1, 2, 3))
self.other = gen()
self.otherIsIterable = True
+ super().setUp()
#==============================================================================
-class TestCopying:
+class _TestCopying:
def test_copy(self):
dup = self.set.copy()
@@ -1577,40 +1665,46 @@ class TestCopying:
#------------------------------------------------------------------------------
-class TestCopyingEmpty(TestCopying, unittest.TestCase):
+class TestCopyingEmpty(_TestCopying, __TestCase):
def setUp(self):
self.set = set()
+ super().setUp()
#------------------------------------------------------------------------------
-class TestCopyingSingleton(TestCopying, unittest.TestCase):
+class TestCopyingSingleton(_TestCopying, __TestCase):
def setUp(self):
self.set = set(["hello"])
+ super().setUp()
#------------------------------------------------------------------------------
-class TestCopyingTriple(TestCopying, unittest.TestCase):
+class TestCopyingTriple(_TestCopying, __TestCase):
def setUp(self):
self.set = set(["zero", 0, None])
+ super().setUp()
#------------------------------------------------------------------------------
-class TestCopyingTuple(TestCopying, unittest.TestCase):
+class TestCopyingTuple(_TestCopying, __TestCase):
def setUp(self):
self.set = set([(1, 2)])
+ super().setUp()
#------------------------------------------------------------------------------
-class TestCopyingNested(TestCopying, unittest.TestCase):
+class TestCopyingNested(_TestCopying, __TestCase):
def setUp(self):
self.set = set([((1, 2), (3, 4))])
+ super().setUp()
#==============================================================================
-class TestIdentities(unittest.TestCase):
+class TestIdentities(__TestCase):
def setUp(self):
self.a = set('abracadabra')
self.b = set('alacazam')
+ super().setUp()
def test_binopsVsSubsets(self):
a, b = self.a, self.b
@@ -1727,7 +1821,7 @@ def L(seqn):
'Test multiple tiers of iterators'
return chain(map(lambda x:x, R(Ig(G(seqn)))))
-class TestVariousIteratorArgs(unittest.TestCase):
+class TestVariousIteratorArgs(__TestCase):
def test_constructor(self):
for cons in (set, frozenset):
@@ -1785,7 +1879,7 @@ class bad_dict_clear:
def __hash__(self):
return 0
-class TestWeirdBugs(unittest.TestCase):
+class TestWeirdBugs(__TestCase):
def test_8420_set_merge(self):
# This used to segfault
global be_bad, set2, dict2
@@ -1813,12 +1907,13 @@ class TestWeirdBugs(unittest.TestCase):
list(si)
def test_merge_and_mutate(self):
- class X:
- def __hash__(self):
- return hash(0)
- def __eq__(self, o):
- other.clear()
- return False
+ with torch._dynamo.error_on_graph_break(False):
+ class X:
+ def __hash__(self):
+ return hash(0)
+ def __eq__(self, o):
+ other.clear()
+ return False
other = set()
other = {X() for i in range(10)}
@@ -1826,24 +1921,25 @@ class TestWeirdBugs(unittest.TestCase):
s.update(other)
-class TestOperationsMutating:
+class _TestOperationsMutating:
"""Regression test for bpo-46615"""
constructor1 = None
constructor2 = None
def make_sets_of_bad_objects(self):
- class Bad:
- def __eq__(self, other):
- if not enabled:
- return False
- if randrange(20) == 0:
- set1.clear()
- if randrange(20) == 0:
- set2.clear()
- return bool(randrange(2))
- def __hash__(self):
- return randrange(2)
+ with torch._dynamo.error_on_graph_break(False):
+ class Bad:
+ def __eq__(self, other):
+ if not enabled:
+ return False
+ if randrange(20) == 0:
+ set1.clear()
+ if randrange(20) == 0:
+ set2.clear()
+ return bool(randrange(2))
+ def __hash__(self):
+ return randrange(2)
# Don't behave poorly during construction.
enabled = False
set1 = self.constructor1(Bad() for _ in range(randrange(50)))
@@ -1862,7 +1958,7 @@ class TestOperationsMutating:
self.assertIn("changed size during iteration", str(e))
-class TestBinaryOpsMutating(TestOperationsMutating):
+class _TestBinaryOpsMutating(_TestOperationsMutating):
def test_eq_with_mutation(self):
self.check_set_op_does_not_crash(lambda a, b: a == b)
@@ -1933,24 +2029,24 @@ class TestBinaryOpsMutating(TestOperationsMutating):
self.check_set_op_does_not_crash(f3)
-class TestBinaryOpsMutating_Set_Set(TestBinaryOpsMutating, unittest.TestCase):
+class TestBinaryOpsMutating_Set_Set(_TestBinaryOpsMutating, __TestCase):
constructor1 = set
constructor2 = set
-class TestBinaryOpsMutating_Subclass_Subclass(TestBinaryOpsMutating, unittest.TestCase):
+class TestBinaryOpsMutating_Subclass_Subclass(_TestBinaryOpsMutating, __TestCase):
constructor1 = SetSubclass
constructor2 = SetSubclass
-class TestBinaryOpsMutating_Set_Subclass(TestBinaryOpsMutating, unittest.TestCase):
+class TestBinaryOpsMutating_Set_Subclass(_TestBinaryOpsMutating, __TestCase):
constructor1 = set
constructor2 = SetSubclass
-class TestBinaryOpsMutating_Subclass_Set(TestBinaryOpsMutating, unittest.TestCase):
+class TestBinaryOpsMutating_Subclass_Set(_TestBinaryOpsMutating, __TestCase):
constructor1 = SetSubclass
constructor2 = set
-class TestMethodsMutating(TestOperationsMutating):
+class _TestMethodsMutating(_TestOperationsMutating):
def test_issubset_with_mutation(self):
self.check_set_op_does_not_crash(set.issubset)
@@ -1986,27 +2082,27 @@ class TestMethodsMutating(TestOperationsMutating):
self.check_set_op_does_not_crash(set.update)
-class TestMethodsMutating_Set_Set(TestMethodsMutating, unittest.TestCase):
+class TestMethodsMutating_Set_Set(_TestMethodsMutating, __TestCase):
constructor1 = set
constructor2 = set
-class TestMethodsMutating_Subclass_Subclass(TestMethodsMutating, unittest.TestCase):
+class TestMethodsMutating_Subclass_Subclass(_TestMethodsMutating, __TestCase):
constructor1 = SetSubclass
constructor2 = SetSubclass
-class TestMethodsMutating_Set_Subclass(TestMethodsMutating, unittest.TestCase):
+class TestMethodsMutating_Set_Subclass(_TestMethodsMutating, __TestCase):
constructor1 = set
constructor2 = SetSubclass
-class TestMethodsMutating_Subclass_Set(TestMethodsMutating, unittest.TestCase):
+class TestMethodsMutating_Subclass_Set(_TestMethodsMutating, __TestCase):
constructor1 = SetSubclass
constructor2 = set
-class TestMethodsMutating_Set_Dict(TestMethodsMutating, unittest.TestCase):
+class TestMethodsMutating_Set_Dict(_TestMethodsMutating, __TestCase):
constructor1 = set
constructor2 = dict.fromkeys
-class TestMethodsMutating_Set_List(TestMethodsMutating, unittest.TestCase):
+class TestMethodsMutating_Set_List(_TestMethodsMutating, __TestCase):
constructor1 = set
constructor2 = list
@@ -2068,7 +2164,7 @@ def faces(G):
return f
-class TestGraphs(unittest.TestCase):
+class TestGraphs(__TestCase):
def test_cube(self):
@@ -2118,4 +2214,4 @@ class TestGraphs(unittest.TestCase):
#==============================================================================
if __name__ == "__main__":
- unittest.main()
+ run_tests()