[dict] Implement dict.__eq__ and dict.__ne__ (#154942)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154942
Approved by: https://github.com/zou3519
ghstack dependencies: #154003, #154793, #154794
This commit is contained in:
Guilherme Leobas 2025-07-08 15:27:55 -03:00 committed by PyTorch MergeBot
parent ba8d19ec02
commit ba71eb496b
5 changed files with 78 additions and 8 deletions

View File

@ -1086,6 +1086,56 @@ class DictGuardTests(LoggingTestCase):
record = self.getRecord(records, "d") record = self.getRecord(records, "d")
self.assertIn( self.assertIn(
"""d[3] == 4""", """d[3] == 4""",
munge_exc(record),
)
@make_logging_test(recompiles=True)
def test_cmp_eq(self, records):
@torch.compile(backend="eager", fullgraph=True)
def fn(x, d1, d2):
if d1 == d2:
return x.sin()
return x.cos()
x = torch.tensor(1.0)
d1 = self.thetype({1: 2, 3: 4})
d2 = self.thetype({1: 2, 5: 6})
y = fn(x, d1, d2)
# sanity check
self.assertEqual(len(records), 0)
self.assertEqual(y, x.cos())
y = fn(x, d1, d1)
self.assertEqual(len(records), 1)
self.assertEqual(y, x.sin())
record = self.getRecord(records, "d2")
self.assertIn(
"""list(dict.keys(d2))""",
munge_exc(record.getMessage()),
)
@make_logging_test(recompiles=True)
def test_cmp_ne(self, records):
@torch.compile(backend="eager", fullgraph=True)
def fn(x, d1, d2):
if d1 == d2:
return x.sin()
return x.cos()
x = torch.tensor(1.0)
d1 = self.thetype({1: 2, 3: 4})
d2 = self.thetype({1: 2, 5: 6})
y = fn(x, d1, d2)
# sanity check
self.assertEqual(len(records), 0)
self.assertEqual(y, x.cos())
y = fn(x, d1, d1)
self.assertEqual(len(records), 1)
self.assertEqual(y, x.sin())
record = self.getRecord(records, "d2")
self.assertIn(
"""list(dict.keys(d2))""",
munge_exc(record.getMessage()), munge_exc(record.getMessage()),
) )
@ -1363,14 +1413,6 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase):
class DictSubclassMethodsTests(DictMethodsTests): class DictSubclassMethodsTests(DictMethodsTests):
thetype = SimpleDict thetype = SimpleDict
@unittest.expectedFailure
def test_cmp_eq(self):
return super().test_cmp_eq()
@unittest.expectedFailure
def test_cmp_ne(self):
return super().test_cmp_ne()
if __name__ == "__main__": if __name__ == "__main__":
from torch._dynamo.test_case import run_tests from torch._dynamo.test_case import run_tests

View File

@ -9,6 +9,7 @@ Python polyfills for common builtins.
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import types import types
from collections import OrderedDict
from collections.abc import Hashable, Iterable, MutableMapping, Sequence from collections.abc import Hashable, Iterable, MutableMapping, Sequence
from itertools import repeat as _repeat from itertools import repeat as _repeat
from typing import Any, Callable, TYPE_CHECKING from typing import Any, Callable, TYPE_CHECKING
@ -101,6 +102,20 @@ def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequenc
return op(len(left), len(right)) return op(len(left), len(right))
def dict___eq__(d, other):
if (len(d) != len(other)) or (d.keys() != other.keys()):
return False
if all(isinstance(a, OrderedDict) for a in (d, other)):
return list(d.items()) == list(other.items())
for k, v in d.items():
if v != other[k]:
return False
return True
def set_symmetric_difference(set1, set2): def set_symmetric_difference(set1, set2):
symmetric_difference_set = set() symmetric_difference_set = set()
for x in set1: for x in set1:

View File

@ -606,6 +606,19 @@ class ConstDictVariable(VariableTracker):
self.items.pop(key) self.items.pop(key)
self.items[key] = val self.items[key] = val
return ConstantVariable.create(None) return ConstantVariable.create(None)
elif name == "__eq__" and istype(
self, ConstDictVariable
): # don't let Set use this function
if len(args) != 1:
raise_args_mismatch(tx, name)
return variables.UserFunctionVariable(polyfills.dict___eq__).call_function(
tx, [self, args[0]], {}
)
elif name == "__ne__":
return ConstantVariable.create(
not self.call_method(tx, "__eq__", args, kwargs).value
)
elif name == "__or__": elif name == "__or__":
assert len(args) == 1 assert len(args) == 1
# Dicts can only be unioned with other dicts or subclasses of dicts. # Dicts can only be unioned with other dicts or subclasses of dicts.