mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[dynamo] Fix MATCH_KEYS for dict pattern matching (#165956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165956 Approved by: https://github.com/guilhermeleobas, https://github.com/cyyever
This commit is contained in:
parent
715449ca76
commit
550e3e6efb
|
|
@ -2302,30 +2302,27 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
return augment(x)
|
||||
|
||||
# # This is to test the new syntax for pattern matching
|
||||
# # ("match ... case ...") added on python 3.10.
|
||||
# # Uncomment these test cases if you run on 3.10+
|
||||
# @make_test
|
||||
# def test_match_sequence(a):
|
||||
# point = (5, 8)
|
||||
# match point:
|
||||
# case (0, 0):
|
||||
# return a
|
||||
# case (0, y):
|
||||
# return a - y
|
||||
# case (x, 0):
|
||||
# return a + x
|
||||
# case (x, y):
|
||||
# return a + x - y
|
||||
@make_test
|
||||
def test_match_sequence(a):
|
||||
point = (5, 8)
|
||||
match point:
|
||||
case (0, 0):
|
||||
return a
|
||||
case (0, y):
|
||||
return a - y
|
||||
case (x, 0):
|
||||
return a + x
|
||||
case (x, y):
|
||||
return a + x - y
|
||||
|
||||
# @make_test
|
||||
# def test_match_mapping_and_match_keys(x):
|
||||
# param = {"a": 0.5}
|
||||
# match param:
|
||||
# case {"a": param}:
|
||||
# return x * param
|
||||
# case {"b": param}:
|
||||
# return x / param
|
||||
@make_test
|
||||
def test_match_mapping_and_match_keys(x):
|
||||
param = {"a": 0.5}
|
||||
match param:
|
||||
case {"a": param}:
|
||||
return x * param
|
||||
case {"b": param}:
|
||||
return x / param
|
||||
|
||||
def test_math_radians(self):
|
||||
def func(x, a):
|
||||
|
|
|
|||
|
|
@ -3584,11 +3584,13 @@ class InstructionTranslatorBase(
|
|||
|
||||
def MATCH_KEYS(self, inst: Instruction) -> None:
|
||||
tos = self.stack[-1]
|
||||
assert isinstance(tos, TupleVariable)
|
||||
keys = tos.unpack_var_sequence(self)
|
||||
tos1 = self.stack[-2]
|
||||
assert isinstance(tos1, ConstDictVariable)
|
||||
|
||||
if all(k in tos1 for k in tos): # type: ignore[attr-defined]
|
||||
self.push(TupleVariable([tos1.getitem_const(self, k) for k in tos])) # type: ignore[attr-defined,arg-type]
|
||||
if all(k in tos1 for k in keys): # type: ignore[attr-defined]
|
||||
self.push(TupleVariable([tos1.getitem_const(self, k) for k in keys])) # type: ignore[attr-defined,arg-type]
|
||||
if sys.version_info < (3, 11):
|
||||
self.push(ConstantVariable.create(True))
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user