[dynamo] Fix dict.get with no default (#115048)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115048
Approved by: https://github.com/eellison, https://github.com/oulgen
ghstack dependencies: #114830, #115047
This commit is contained in:
Jason Ansel 2023-12-04 08:46:30 -08:00 committed by PyTorch MergeBot
parent f6b6fad136
commit fe690f430a
5 changed files with 13 additions and 3 deletions

View File

@ -118,7 +118,7 @@ hf_T5,pass,0
hf_T5_generate,pass,19
hf_T5_generate,pass,18

1 name accuracy graph_breaks
118
119
120
121
122
123
124

View File

@ -118,7 +118,7 @@ hf_T5,pass,0
hf_T5_generate,pass,19
hf_T5_generate,pass,18

1 name accuracy graph_breaks
118
119
120
121
122
123
124

View File

@ -118,7 +118,7 @@ hf_T5,pass,0
hf_T5_generate,pass,19
hf_T5_generate,pass,18

1 name accuracy graph_breaks
118
119
120
121
122
123
124

View File

@ -689,6 +689,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
@make_test
def test_dict_ops(a, b):
tmp = {"a": a + 1, "b": b + 2}
assert tmp.get("zzz") is None
v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4)
tmp.update({"d": 3})
tmp["c"] = v + tmp["d"]

View File

@ -147,11 +147,20 @@ class ConstDictVariable(VariableTracker):
elif (
name in ("pop", "get")
and len(args) == 2
and not kwargs
and ConstDictVariable.is_valid_key(args[0])
and ConstDictVariable.get_key(args[0]) not in self.items
):
# missing item, return the default value
return args[1]
elif (
name == "get"
and len(args) == 1
and not kwargs
and ConstDictVariable.is_valid_key(args[0])
and ConstDictVariable.get_key(args[0]) not in self.items
):
return ConstantVariable(None)
elif (
name == "pop"
and args