mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[dynamo] Fix dict.get with no default (#115048)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115048 Approved by: https://github.com/eellison, https://github.com/oulgen ghstack dependencies: #114830, #115047
This commit is contained in:
parent
f6b6fad136
commit
fe690f430a
|
|
@ -118,7 +118,7 @@ hf_T5,pass,0
|
|||
|
||||
|
||||
|
||||
hf_T5_generate,pass,19
|
||||
hf_T5_generate,pass,18
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
|
@ -118,7 +118,7 @@ hf_T5,pass,0
|
|||
|
||||
|
||||
|
||||
hf_T5_generate,pass,19
|
||||
hf_T5_generate,pass,18
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
|
@ -118,7 +118,7 @@ hf_T5,pass,0
|
|||
|
||||
|
||||
|
||||
hf_T5_generate,pass,19
|
||||
hf_T5_generate,pass,18
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
|
@ -689,6 +689,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
@make_test
|
||||
def test_dict_ops(a, b):
|
||||
tmp = {"a": a + 1, "b": b + 2}
|
||||
assert tmp.get("zzz") is None
|
||||
v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4)
|
||||
tmp.update({"d": 3})
|
||||
tmp["c"] = v + tmp["d"]
|
||||
|
|
|
|||
|
|
@ -147,11 +147,20 @@ class ConstDictVariable(VariableTracker):
|
|||
elif (
|
||||
name in ("pop", "get")
|
||||
and len(args) == 2
|
||||
and not kwargs
|
||||
and ConstDictVariable.is_valid_key(args[0])
|
||||
and ConstDictVariable.get_key(args[0]) not in self.items
|
||||
):
|
||||
# missing item, return the default value
|
||||
return args[1]
|
||||
elif (
|
||||
name == "get"
|
||||
and len(args) == 1
|
||||
and not kwargs
|
||||
and ConstDictVariable.is_valid_key(args[0])
|
||||
and ConstDictVariable.get_key(args[0]) not in self.items
|
||||
):
|
||||
return ConstantVariable(None)
|
||||
elif (
|
||||
name == "pop"
|
||||
and args
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user