[dynamo] Fix MATCH_KEYS for dict pattern matching (#165956)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165956
Approved by: https://github.com/guilhermeleobas, https://github.com/cyyever
This commit is contained in:
Rob Timpe 2025-10-21 20:14:25 +00:00 committed by PyTorch MergeBot
parent 715449ca76
commit 550e3e6efb
2 changed files with 24 additions and 25 deletions

View File

@ -2302,30 +2302,27 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
return augment(x)
# # This is to test the new syntax for pattern matching
# # ("match ... case ...") added on python 3.10.
# # Uncomment these test cases if you run on 3.10+
# @make_test
# def test_match_sequence(a):
# point = (5, 8)
# match point:
# case (0, 0):
# return a
# case (0, y):
# return a - y
# case (x, 0):
# return a + x
# case (x, y):
# return a + x - y
@make_test
def test_match_sequence(a):
point = (5, 8)
match point:
case (0, 0):
return a
case (0, y):
return a - y
case (x, 0):
return a + x
case (x, y):
return a + x - y
# @make_test
# def test_match_mapping_and_match_keys(x):
# param = {"a": 0.5}
# match param:
# case {"a": param}:
# return x * param
# case {"b": param}:
# return x / param
@make_test
def test_match_mapping_and_match_keys(x):
param = {"a": 0.5}
match param:
case {"a": param}:
return x * param
case {"b": param}:
return x / param
def test_math_radians(self):
def func(x, a):

View File

@ -3584,11 +3584,13 @@ class InstructionTranslatorBase(
def MATCH_KEYS(self, inst: Instruction) -> None:
tos = self.stack[-1]
assert isinstance(tos, TupleVariable)
keys = tos.unpack_var_sequence(self)
tos1 = self.stack[-2]
assert isinstance(tos1, ConstDictVariable)
if all(k in tos1 for k in tos): # type: ignore[attr-defined]
self.push(TupleVariable([tos1.getitem_const(self, k) for k in tos])) # type: ignore[attr-defined,arg-type]
if all(k in tos1 for k in keys): # type: ignore[attr-defined]
self.push(TupleVariable([tos1.getitem_const(self, k) for k in keys])) # type: ignore[attr-defined,arg-type]
if sys.version_info < (3, 11):
self.push(ConstantVariable.create(True))
else: