[dynamo] Fix and simplify hanlding of Set.update method (#141286)

The old implementation of `SetVariable.call_method("update", ...)` was
incorrectly becacuse it wouldn't handle iterable inputs. This patches
removes the input type restriction altogether, and implements the method
as a polyfill (like how most of the other set methods are handled).

Fixes #141283.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141286
Approved by: https://github.com/anijain2305
This commit is contained in:
Ryan Guo 2024-11-25 11:41:15 -08:00 committed by PyTorch MergeBot
parent 5d7c3701e4
commit 583484b726
3 changed files with 26 additions and 22 deletions

View File

@ -4854,6 +4854,19 @@ utils_device.CURRENT_DEVICE == None""".split(
self.assertEqual(opt_fn(x), x) self.assertEqual(opt_fn(x), x)
self.assertEqual(cnts.op_count, 1) self.assertEqual(cnts.op_count, 1)
def test_set_update(self):
@torch.compile(backend="eager", fullgraph=True)
def run(x, int_set, int_list):
int_set.update(map(int, int_list))
return x + 1
int_set = set()
int_list = [1, 2, 1]
res = run(torch.ones(1), int_set, int_list)
self.assertTrue(same(res, torch.ones(1) + 1))
self.assertEqual(int_set, set([1, 2]))
self.assertEqual(int_list, [1, 2, 1])
def test_frozenset_torch_func_contains(self): def test_frozenset_torch_func_contains(self):
funcs = frozenset([torch.add]) funcs = frozenset([torch.add])

View File

@ -101,12 +101,17 @@ def set_intersection(set1, set2):
def set_union(set1, set2): def set_union(set1, set2):
union_set = set1.copy() union_set = set1.copy()
for x in set2: set_update(union_set, set2)
if x not in union_set:
union_set.add(x)
return union_set return union_set
def set_update(set1, set2):
for x in set2:
if x not in set1:
set1.add(x)
return set1
def set_difference(set1, set2): def set_difference(set1, set2):
difference_set = set() difference_set = set()
for x in set1: for x in set1:

View File

@ -497,8 +497,6 @@ class SetVariable(ConstDictVariable):
args: List[VariableTracker], args: List[VariableTracker],
kwargs: Dict[str, VariableTracker], kwargs: Dict[str, VariableTracker],
) -> "VariableTracker": ) -> "VariableTracker":
from . import ListVariable, TupleVariable
# We foward the calls to the dictionary model # We foward the calls to the dictionary model
if name == "add": if name == "add":
assert not kwargs assert not kwargs
@ -536,24 +534,12 @@ class SetVariable(ConstDictVariable):
return variables.UserFunctionVariable( return variables.UserFunctionVariable(
polyfills.set_difference polyfills.set_difference
).call_function(tx, [self, args[0]], {}) ).call_function(tx, [self, args[0]], {})
elif ( elif name == "update" and len(args) == 1 and self.is_mutable():
name == "update" assert not kwargs
and len(args) == 1 assert len(args) == 1
and isinstance( return variables.UserFunctionVariable(polyfills.set_update).call_function(
args[0], tx, [self, args[0]], {}
(
SetVariable,
ListVariable,
TupleVariable,
),
) )
and self.is_mutable()
):
if isinstance(args[0], (ListVariable, TupleVariable)):
arg = SetVariable(args[0].unpack_var_sequence(tx))
else:
arg = args[0]
return super().call_method(tx, "update", (arg,), kwargs)
elif name == "remove": elif name == "remove":
assert not kwargs assert not kwargs
assert len(args) == 1 assert len(args) == 1