mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Fix and simplify hanlding of Set.update method (#141286)
The old implementation of `SetVariable.call_method("update", ...)` was
incorrectly becacuse it wouldn't handle iterable inputs. This patches
removes the input type restriction altogether, and implements the method
as a polyfill (like how most of the other set methods are handled).
Fixes #141283.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141286
Approved by: https://github.com/anijain2305
This commit is contained in:
parent
5d7c3701e4
commit
583484b726
|
|
@ -4854,6 +4854,19 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
self.assertEqual(opt_fn(x), x)
|
||||
self.assertEqual(cnts.op_count, 1)
|
||||
|
||||
def test_set_update(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def run(x, int_set, int_list):
|
||||
int_set.update(map(int, int_list))
|
||||
return x + 1
|
||||
|
||||
int_set = set()
|
||||
int_list = [1, 2, 1]
|
||||
res = run(torch.ones(1), int_set, int_list)
|
||||
self.assertTrue(same(res, torch.ones(1) + 1))
|
||||
self.assertEqual(int_set, set([1, 2]))
|
||||
self.assertEqual(int_list, [1, 2, 1])
|
||||
|
||||
def test_frozenset_torch_func_contains(self):
|
||||
funcs = frozenset([torch.add])
|
||||
|
||||
|
|
|
|||
|
|
@ -101,12 +101,17 @@ def set_intersection(set1, set2):
|
|||
|
||||
def set_union(set1, set2):
|
||||
union_set = set1.copy()
|
||||
for x in set2:
|
||||
if x not in union_set:
|
||||
union_set.add(x)
|
||||
set_update(union_set, set2)
|
||||
return union_set
|
||||
|
||||
|
||||
def set_update(set1, set2):
|
||||
for x in set2:
|
||||
if x not in set1:
|
||||
set1.add(x)
|
||||
return set1
|
||||
|
||||
|
||||
def set_difference(set1, set2):
|
||||
difference_set = set()
|
||||
for x in set1:
|
||||
|
|
|
|||
|
|
@ -497,8 +497,6 @@ class SetVariable(ConstDictVariable):
|
|||
args: List[VariableTracker],
|
||||
kwargs: Dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
from . import ListVariable, TupleVariable
|
||||
|
||||
# We foward the calls to the dictionary model
|
||||
if name == "add":
|
||||
assert not kwargs
|
||||
|
|
@ -536,24 +534,12 @@ class SetVariable(ConstDictVariable):
|
|||
return variables.UserFunctionVariable(
|
||||
polyfills.set_difference
|
||||
).call_function(tx, [self, args[0]], {})
|
||||
elif (
|
||||
name == "update"
|
||||
and len(args) == 1
|
||||
and isinstance(
|
||||
args[0],
|
||||
(
|
||||
SetVariable,
|
||||
ListVariable,
|
||||
TupleVariable,
|
||||
),
|
||||
elif name == "update" and len(args) == 1 and self.is_mutable():
|
||||
assert not kwargs
|
||||
assert len(args) == 1
|
||||
return variables.UserFunctionVariable(polyfills.set_update).call_function(
|
||||
tx, [self, args[0]], {}
|
||||
)
|
||||
and self.is_mutable()
|
||||
):
|
||||
if isinstance(args[0], (ListVariable, TupleVariable)):
|
||||
arg = SetVariable(args[0].unpack_var_sequence(tx))
|
||||
else:
|
||||
arg = args[0]
|
||||
return super().call_method(tx, "update", (arg,), kwargs)
|
||||
elif name == "remove":
|
||||
assert not kwargs
|
||||
assert len(args) == 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user