[Set] Add set.symmetric_difference(_update) (#152901)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152901
Approved by: https://github.com/anijain2305
ghstack dependencies: #150792, #152987, #152988, #152904
This commit is contained in:
Guilherme Leobas 2025-05-15 22:47:15 -03:00 committed by PyTorch MergeBot
parent fe51ce62ca
commit 5926b7a38f
11 changed files with 64 additions and 1 deletions

View File

@ -1755,7 +1755,9 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
y = a - b
return x, y
@parametrize("fn_name", ["add"])
@parametrize(
"fn_name", ["add", "symmetric_difference", "symmetric_difference_update"]
)
def test_set_raise_TypeError(self, fn_name):
@make_test
def fn(a, b):
@ -1783,6 +1785,36 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
y = a - b
return x, y
@make_test
def test_set_symmetric_difference(a, b):
set1 = {"apple", "banana", "cherry"}
set2 = {"google", "microsoft", "apple"}
symmetric_diff_set = set1.difference(set2)
if "apple" in symmetric_diff_set:
x = a + b
else:
x = a - b
if "banana" in symmetric_diff_set:
y = a + b
else:
y = a - b
return x, y
@make_test
def test_set_symmetric_difference_update(a, b):
set1 = {"apple", "banana", "cherry"}
set2 = {"google", "microsoft", "apple"}
set1.difference(set2)
if "apple" in set1:
x = a + b
else:
x = a - b
if "banana" in set1:
y = a + b
else:
y = a - b
return x, y
def test_set_keys_view(self):
from collections.abc import KeysView

View File

@ -101,6 +101,23 @@ def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequenc
return op(len(left), len(right))
def set_symmetric_difference(set1, set2):
symmetric_difference_set = set()
for x in set1:
if x not in set2:
symmetric_difference_set.add(x)
for x in set2:
if x not in set1:
symmetric_difference_set.add(x)
return symmetric_difference_set
def set_symmetric_difference_update(set1, set2):
result = set1.symmetric_difference(set2)
set1.clear()
set1.update(result)
def set_isdisjoint(set1, set2):
for x in set1:
if x in set2:

View File

@ -790,6 +790,20 @@ class SetVariable(ConstDictVariable):
return variables.UserFunctionVariable(
polyfills.set_difference
).call_function(tx, [self, args[0]], {})
elif name == "symmetric_difference":
if len(args) != 1:
raise_args_mismatch(tx, name)
assert not kwargs
return variables.UserFunctionVariable(
polyfills.set_symmetric_difference
).call_function(tx, [self, *args], {})
elif name == "symmetric_difference_update":
if len(args) != 1:
raise_args_mismatch(tx, name)
assert not kwargs
return variables.UserFunctionVariable(
polyfills.set_symmetric_difference_update
).call_function(tx, [self, *args], {})
elif name == "update" and len(args) == 1 and self.is_mutable():
assert not kwargs
assert len(args) == 1