mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Set] Add set.symmetric_difference(_update) (#152901)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152901 Approved by: https://github.com/anijain2305 ghstack dependencies: #150792, #152987, #152988, #152904
This commit is contained in:
parent
fe51ce62ca
commit
5926b7a38f
|
|
@ -1755,7 +1755,9 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
y = a - b
|
||||
return x, y
|
||||
|
||||
@parametrize("fn_name", ["add"])
|
||||
@parametrize(
|
||||
"fn_name", ["add", "symmetric_difference", "symmetric_difference_update"]
|
||||
)
|
||||
def test_set_raise_TypeError(self, fn_name):
|
||||
@make_test
|
||||
def fn(a, b):
|
||||
|
|
@ -1783,6 +1785,36 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
y = a - b
|
||||
return x, y
|
||||
|
||||
@make_test
|
||||
def test_set_symmetric_difference(a, b):
|
||||
set1 = {"apple", "banana", "cherry"}
|
||||
set2 = {"google", "microsoft", "apple"}
|
||||
symmetric_diff_set = set1.difference(set2)
|
||||
if "apple" in symmetric_diff_set:
|
||||
x = a + b
|
||||
else:
|
||||
x = a - b
|
||||
if "banana" in symmetric_diff_set:
|
||||
y = a + b
|
||||
else:
|
||||
y = a - b
|
||||
return x, y
|
||||
|
||||
@make_test
|
||||
def test_set_symmetric_difference_update(a, b):
|
||||
set1 = {"apple", "banana", "cherry"}
|
||||
set2 = {"google", "microsoft", "apple"}
|
||||
set1.difference(set2)
|
||||
if "apple" in set1:
|
||||
x = a + b
|
||||
else:
|
||||
x = a - b
|
||||
if "banana" in set1:
|
||||
y = a + b
|
||||
else:
|
||||
y = a - b
|
||||
return x, y
|
||||
|
||||
def test_set_keys_view(self):
|
||||
from collections.abc import KeysView
|
||||
|
||||
|
|
|
|||
|
|
@ -101,6 +101,23 @@ def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequenc
|
|||
return op(len(left), len(right))
|
||||
|
||||
|
||||
def set_symmetric_difference(set1, set2):
|
||||
symmetric_difference_set = set()
|
||||
for x in set1:
|
||||
if x not in set2:
|
||||
symmetric_difference_set.add(x)
|
||||
for x in set2:
|
||||
if x not in set1:
|
||||
symmetric_difference_set.add(x)
|
||||
return symmetric_difference_set
|
||||
|
||||
|
||||
def set_symmetric_difference_update(set1, set2):
|
||||
result = set1.symmetric_difference(set2)
|
||||
set1.clear()
|
||||
set1.update(result)
|
||||
|
||||
|
||||
def set_isdisjoint(set1, set2):
|
||||
for x in set1:
|
||||
if x in set2:
|
||||
|
|
|
|||
|
|
@ -790,6 +790,20 @@ class SetVariable(ConstDictVariable):
|
|||
return variables.UserFunctionVariable(
|
||||
polyfills.set_difference
|
||||
).call_function(tx, [self, args[0]], {})
|
||||
elif name == "symmetric_difference":
|
||||
if len(args) != 1:
|
||||
raise_args_mismatch(tx, name)
|
||||
assert not kwargs
|
||||
return variables.UserFunctionVariable(
|
||||
polyfills.set_symmetric_difference
|
||||
).call_function(tx, [self, *args], {})
|
||||
elif name == "symmetric_difference_update":
|
||||
if len(args) != 1:
|
||||
raise_args_mismatch(tx, name)
|
||||
assert not kwargs
|
||||
return variables.UserFunctionVariable(
|
||||
polyfills.set_symmetric_difference_update
|
||||
).call_function(tx, [self, *args], {})
|
||||
elif name == "update" and len(args) == 1 and self.is_mutable():
|
||||
assert not kwargs
|
||||
assert len(args) == 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user