diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index 34e2e70a2a8..e532b136461 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -118,7 +118,7 @@ hf_T5,pass,0 -hf_T5_generate,pass,19 +hf_T5_generate,pass,18 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index 34e2e70a2a8..e532b136461 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -118,7 +118,7 @@ hf_T5,pass,0 -hf_T5_generate,pass,19 +hf_T5_generate,pass,18 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index afd1e3a99a2..5f61b5b4a85 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -118,7 +118,7 @@ hf_T5,pass,0 -hf_T5_generate,pass,19 +hf_T5_generate,pass,18 diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 1940351d529..f9a7001af94 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -689,6 +689,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase): @make_test def test_dict_ops(a, b): tmp = {"a": a + 1, "b": b + 2} + assert tmp.get("zzz") is None v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4) tmp.update({"d": 3}) tmp["c"] = v + tmp["d"] diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 1113314ac05..05e65013a14 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -147,11 +147,20 @@ class ConstDictVariable(VariableTracker): elif ( name in ("pop", "get") and len(args) == 2 + and not kwargs and ConstDictVariable.is_valid_key(args[0]) and ConstDictVariable.get_key(args[0]) not in self.items ): # missing item, return the default value return args[1] + elif ( + name == "get" + and len(args) == 1 + and not kwargs + and ConstDictVariable.is_valid_key(args[0]) + and ConstDictVariable.get_key(args[0]) not in self.items + ): + return ConstantVariable(None) elif ( name == "pop" and args