mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Fix and simplify hanlding of Set.update method (#141286)
The old implementation of `SetVariable.call_method("update", ...)` was
incorrectly becacuse it wouldn't handle iterable inputs. This patches
removes the input type restriction altogether, and implements the method
as a polyfill (like how most of the other set methods are handled).
Fixes #141283.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141286
Approved by: https://github.com/anijain2305
This commit is contained in:
parent
5d7c3701e4
commit
583484b726
|
|
@ -4854,6 +4854,19 @@ utils_device.CURRENT_DEVICE == None""".split(
|
||||||
self.assertEqual(opt_fn(x), x)
|
self.assertEqual(opt_fn(x), x)
|
||||||
self.assertEqual(cnts.op_count, 1)
|
self.assertEqual(cnts.op_count, 1)
|
||||||
|
|
||||||
|
def test_set_update(self):
|
||||||
|
@torch.compile(backend="eager", fullgraph=True)
|
||||||
|
def run(x, int_set, int_list):
|
||||||
|
int_set.update(map(int, int_list))
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
int_set = set()
|
||||||
|
int_list = [1, 2, 1]
|
||||||
|
res = run(torch.ones(1), int_set, int_list)
|
||||||
|
self.assertTrue(same(res, torch.ones(1) + 1))
|
||||||
|
self.assertEqual(int_set, set([1, 2]))
|
||||||
|
self.assertEqual(int_list, [1, 2, 1])
|
||||||
|
|
||||||
def test_frozenset_torch_func_contains(self):
|
def test_frozenset_torch_func_contains(self):
|
||||||
funcs = frozenset([torch.add])
|
funcs = frozenset([torch.add])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -101,12 +101,17 @@ def set_intersection(set1, set2):
|
||||||
|
|
||||||
def set_union(set1, set2):
|
def set_union(set1, set2):
|
||||||
union_set = set1.copy()
|
union_set = set1.copy()
|
||||||
for x in set2:
|
set_update(union_set, set2)
|
||||||
if x not in union_set:
|
|
||||||
union_set.add(x)
|
|
||||||
return union_set
|
return union_set
|
||||||
|
|
||||||
|
|
||||||
|
def set_update(set1, set2):
|
||||||
|
for x in set2:
|
||||||
|
if x not in set1:
|
||||||
|
set1.add(x)
|
||||||
|
return set1
|
||||||
|
|
||||||
|
|
||||||
def set_difference(set1, set2):
|
def set_difference(set1, set2):
|
||||||
difference_set = set()
|
difference_set = set()
|
||||||
for x in set1:
|
for x in set1:
|
||||||
|
|
|
||||||
|
|
@ -497,8 +497,6 @@ class SetVariable(ConstDictVariable):
|
||||||
args: List[VariableTracker],
|
args: List[VariableTracker],
|
||||||
kwargs: Dict[str, VariableTracker],
|
kwargs: Dict[str, VariableTracker],
|
||||||
) -> "VariableTracker":
|
) -> "VariableTracker":
|
||||||
from . import ListVariable, TupleVariable
|
|
||||||
|
|
||||||
# We foward the calls to the dictionary model
|
# We foward the calls to the dictionary model
|
||||||
if name == "add":
|
if name == "add":
|
||||||
assert not kwargs
|
assert not kwargs
|
||||||
|
|
@ -536,24 +534,12 @@ class SetVariable(ConstDictVariable):
|
||||||
return variables.UserFunctionVariable(
|
return variables.UserFunctionVariable(
|
||||||
polyfills.set_difference
|
polyfills.set_difference
|
||||||
).call_function(tx, [self, args[0]], {})
|
).call_function(tx, [self, args[0]], {})
|
||||||
elif (
|
elif name == "update" and len(args) == 1 and self.is_mutable():
|
||||||
name == "update"
|
assert not kwargs
|
||||||
and len(args) == 1
|
assert len(args) == 1
|
||||||
and isinstance(
|
return variables.UserFunctionVariable(polyfills.set_update).call_function(
|
||||||
args[0],
|
tx, [self, args[0]], {}
|
||||||
(
|
|
||||||
SetVariable,
|
|
||||||
ListVariable,
|
|
||||||
TupleVariable,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
and self.is_mutable()
|
|
||||||
):
|
|
||||||
if isinstance(args[0], (ListVariable, TupleVariable)):
|
|
||||||
arg = SetVariable(args[0].unpack_var_sequence(tx))
|
|
||||||
else:
|
|
||||||
arg = args[0]
|
|
||||||
return super().call_method(tx, "update", (arg,), kwargs)
|
|
||||||
elif name == "remove":
|
elif name == "remove":
|
||||||
assert not kwargs
|
assert not kwargs
|
||||||
assert len(args) == 1
|
assert len(args) == 1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user