[dynamo, 3.14] fix misc. bugs to get most dynamo unittests passing locally in 3.14 (#164631)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164631
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
This commit is contained in:
William Wen 2025-10-27 15:56:11 -07:00 committed by PyTorch MergeBot
parent ea698e8bfc
commit f452edd782
5 changed files with 55 additions and 25 deletions

View File

@ -487,6 +487,7 @@ def fn():
self.assertIn("JUMP", i1.opname)
self.assertIs(i1.target, insts[-1])
@unittest.skipIf(sys.version_info >= (3, 14), "3.14+ removed RETURN_CONST")
@skipIfNotPy312
def test_bytecode_from_template_noreturn_const(self):
# Test 3.12+ RETURN_CONST
@ -535,7 +536,9 @@ def fn():
def test_extended_args_starts_line(self):
# NOTE: need to LOAD_CONST i before LOAD_FAST x
# in order to get an EXTENDED_ARG with starts_line set
lines = "\n".join(f" x = {i} + x" for i in range(300))
# NOTE: 3.14+ introduced LOAD_SMALL_INT, so integers need to be >= 256
# in order for LOAD_CONST to be generated
lines = "\n".join(f" x = {i + 1000} + x" for i in range(300))
fn_str = f"def fn(x):\n{lines}"
locals = {}
exec(fn_str, {}, locals)

View File

@ -390,6 +390,7 @@ class TestDynamoTimed(TestCase):
# directly inspect the dict it prints instead:
self.assertExpectedInline(
pprint.pformat(utils.compilation_time_metrics),
(
"""\
{'GraphLowering.codegen': [0.0, 0.0],
'GraphLowering.compile_to_fn': [0.0, 0.0],
@ -451,7 +452,8 @@ class TestDynamoTimed(TestCase):
'create_aot_dispatcher_function': [0.0],
'fx_codegen_and_compile': [0.0, 0.0],
'gc': [0.0],
'min_cut_rematerialization_partition': [0.0]}""", # noqa: B950
'min_cut_rematerialization_partition': [0.0]}"""
), # noqa: B950
)
# Now validate utils.calculate_time_spent(). Formatting the return
@ -459,6 +461,7 @@ class TestDynamoTimed(TestCase):
time_spent = utils.calculate_time_spent()
self.assertExpectedInline(
pprint.pformat(time_spent),
(
"""\
{'_recursive_joint_graph_passes': 0.0,
'_recursive_post_grad_passes': 0.0,
@ -482,7 +485,8 @@ class TestDynamoTimed(TestCase):
'entire_frame_compile': 0.0,
'gc': 0.0,
'inductor_compile': 0.0,
'total_wall_time': 0.0}""", # noqa: B950
'total_wall_time': 0.0}"""
), # noqa: B950
)
# Now validate the CompilationMetrics logs. We expect a log for the
@ -520,6 +524,7 @@ class TestDynamoTimed(TestCase):
del raw["guard_latency_us"]
self.assertExpectedInline(
pprint.pformat(raw),
(
"""\
{'accumulated_cache_size': 0,
'aot_autograd_cumulative_compile_time_us': 0,
@ -692,7 +697,8 @@ class TestDynamoTimed(TestCase):
'tensorify_float_success': None,
'triton_compile_time_us': 0,
'triton_kernel_compile_times_us': None,
'triton_version': None}""", # noqa: B950
'triton_version': None}"""
), # noqa: B950
)
# Second event is for the backward
@ -706,6 +712,7 @@ class TestDynamoTimed(TestCase):
del raw["param_count"]
self.assertExpectedInline(
pprint.pformat(raw),
(
"""\
{'accumulated_cache_size': None,
'aot_autograd_cumulative_compile_time_us': None,
@ -878,7 +885,8 @@ class TestDynamoTimed(TestCase):
'tensorify_float_success': None,
'triton_compile_time_us': 0,
'triton_kernel_compile_times_us': None,
'triton_version': None}""", # noqa: B950
'triton_version': None}"""
), # noqa: B950
)
@dynamo_config.patch(
@ -908,13 +916,14 @@ class TestDynamoTimed(TestCase):
def test_ir_count(self):
# Different python versions have different potential IR counts.
version = (sys.version_info[0], sys.version_info[1])
self.assertIn(version, ((3, 9), (3, 10), (3, 11), (3, 12), (3, 13)))
self.assertIn(version, ((3, 9), (3, 10), (3, 11), (3, 12), (3, 13), (3, 14)))
first, second = {
(3, 9): (10, 6),
(3, 10): (10, 6),
(3, 11): (10, 6),
(3, 12): (11, 7),
(3, 13): (11, 7),
(3, 14): (11, 7),
}[version]
def test1(x):

View File

@ -1431,7 +1431,12 @@ class InstructionTranslatorBase(
# an exception table entry, so we also assume that we
# are still in the same block. It is probably safe to do
# this in 3.11, even though we haven't encountered this case before.
if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"):
# In 3.14+, NOT_TAKEN might also not be covered by an exn table entry.
if self.block_stack and inst.opname not in (
"NOP",
"JUMP_BACKWARD",
"NOT_TAKEN",
):
# If we really escape from a block and the current
# instruction is not in another block, then there
# should be no other nested blocks that we are in.

View File

@ -1402,8 +1402,9 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
same_spec = _make_inlined(tx, pytree.TreeSpec.__eq__)(
true_spec.treespec, false_spec.treespec
)
if not same_spec.as_python_constant():
).as_python_constant()
# 3.14: NotImplemented cannot be converted to bool
if same_spec is not NotImplemented and not same_spec:
unimplemented("Expected branches to return the same pytree structure.")
(
@ -1696,9 +1697,11 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
with tx.fake_mode:
sub_args_fake = [
(
leaf.node.meta["example_value"].clone()
if hasattr(leaf.node.meta["example_value"], "clone")
else leaf.node.meta["example_value"]
)
for leaf in pytree.tree_leaves(proxy_vars_inputcheck)
]
pre_dispatch = False

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import ast
import inspect
import sys
import textwrap
import warnings
@ -74,6 +75,15 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
init_ast = ast.parse(textwrap.dedent(source_lines))
# Get items annotated in the class body
if sys.version_info >= (3, 14):
import annotationlib
self.class_level_annotations = list(
annotationlib.get_annotations(
nn_module, format=annotationlib.Format.FORWARDREF
).keys()
)
else:
self.class_level_annotations = list(nn_module.__annotations__.keys())
# Flag for later