[dynamo, 3.14] fix misc. bugs to get most dynamo unittests passing locally in 3.14 (#164631)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164631
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
This commit is contained in:
William Wen 2025-10-27 15:56:11 -07:00 committed by PyTorch MergeBot
parent ea698e8bfc
commit f452edd782
5 changed files with 55 additions and 25 deletions

View File

@ -487,6 +487,7 @@ def fn():
self.assertIn("JUMP", i1.opname) self.assertIn("JUMP", i1.opname)
self.assertIs(i1.target, insts[-1]) self.assertIs(i1.target, insts[-1])
@unittest.skipIf(sys.version_info >= (3, 14), "3.14+ removed RETURN_CONST")
@skipIfNotPy312 @skipIfNotPy312
def test_bytecode_from_template_noreturn_const(self): def test_bytecode_from_template_noreturn_const(self):
# Test 3.12+ RETURN_CONST # Test 3.12+ RETURN_CONST
@ -535,7 +536,9 @@ def fn():
def test_extended_args_starts_line(self): def test_extended_args_starts_line(self):
# NOTE: need to LOAD_CONST i before LOAD_FAST x # NOTE: need to LOAD_CONST i before LOAD_FAST x
# in order to get an EXTENDED_ARG with starts_line set # in order to get an EXTENDED_ARG with starts_line set
lines = "\n".join(f" x = {i} + x" for i in range(300)) # NOTE: 3.14+ introduced LOAD_SMALL_INT, so integers need to be >= 256
# in order for LOAD_CONST to be generated
lines = "\n".join(f" x = {i + 1000} + x" for i in range(300))
fn_str = f"def fn(x):\n{lines}" fn_str = f"def fn(x):\n{lines}"
locals = {} locals = {}
exec(fn_str, {}, locals) exec(fn_str, {}, locals)

View File

@ -390,7 +390,8 @@ class TestDynamoTimed(TestCase):
# directly inspect the dict it prints instead: # directly inspect the dict it prints instead:
self.assertExpectedInline( self.assertExpectedInline(
pprint.pformat(utils.compilation_time_metrics), pprint.pformat(utils.compilation_time_metrics),
"""\ (
"""\
{'GraphLowering.codegen': [0.0, 0.0], {'GraphLowering.codegen': [0.0, 0.0],
'GraphLowering.compile_to_fn': [0.0, 0.0], 'GraphLowering.compile_to_fn': [0.0, 0.0],
'GraphLowering.compile_to_module': [0.0, 0.0], 'GraphLowering.compile_to_module': [0.0, 0.0],
@ -420,8 +421,8 @@ class TestDynamoTimed(TestCase):
'fx_codegen_and_compile': [0.0, 0.0], 'fx_codegen_and_compile': [0.0, 0.0],
'gc': [0.0], 'gc': [0.0],
'min_cut_rematerialization_partition': [0.0]}""" 'min_cut_rematerialization_partition': [0.0]}"""
if _IS_WINDOWS if _IS_WINDOWS
else """\ else """\
{'GraphLowering.codegen': [0.0, 0.0], {'GraphLowering.codegen': [0.0, 0.0],
'GraphLowering.compile_to_fn': [0.0, 0.0], 'GraphLowering.compile_to_fn': [0.0, 0.0],
'GraphLowering.compile_to_module': [0.0, 0.0], 'GraphLowering.compile_to_module': [0.0, 0.0],
@ -451,7 +452,8 @@ class TestDynamoTimed(TestCase):
'create_aot_dispatcher_function': [0.0], 'create_aot_dispatcher_function': [0.0],
'fx_codegen_and_compile': [0.0, 0.0], 'fx_codegen_and_compile': [0.0, 0.0],
'gc': [0.0], 'gc': [0.0],
'min_cut_rematerialization_partition': [0.0]}""", # noqa: B950 'min_cut_rematerialization_partition': [0.0]}"""
), # noqa: B950
) )
# Now validate utils.calculate_time_spent(). Formatting the return # Now validate utils.calculate_time_spent(). Formatting the return
@ -459,7 +461,8 @@ class TestDynamoTimed(TestCase):
time_spent = utils.calculate_time_spent() time_spent = utils.calculate_time_spent()
self.assertExpectedInline( self.assertExpectedInline(
pprint.pformat(time_spent), pprint.pformat(time_spent),
"""\ (
"""\
{'_recursive_joint_graph_passes': 0.0, {'_recursive_joint_graph_passes': 0.0,
'_recursive_post_grad_passes': 0.0, '_recursive_post_grad_passes': 0.0,
'_recursive_pre_grad_passes': 0.0, '_recursive_pre_grad_passes': 0.0,
@ -470,8 +473,8 @@ class TestDynamoTimed(TestCase):
'gc': 0.0, 'gc': 0.0,
'inductor_compile': 0.0, 'inductor_compile': 0.0,
'total_wall_time': 0.0}""" 'total_wall_time': 0.0}"""
if _IS_WINDOWS if _IS_WINDOWS
else """\ else """\
{'_recursive_joint_graph_passes': 0.0, {'_recursive_joint_graph_passes': 0.0,
'_recursive_post_grad_passes': 0.0, '_recursive_post_grad_passes': 0.0,
'_recursive_pre_grad_passes': 0.0, '_recursive_pre_grad_passes': 0.0,
@ -482,7 +485,8 @@ class TestDynamoTimed(TestCase):
'entire_frame_compile': 0.0, 'entire_frame_compile': 0.0,
'gc': 0.0, 'gc': 0.0,
'inductor_compile': 0.0, 'inductor_compile': 0.0,
'total_wall_time': 0.0}""", # noqa: B950 'total_wall_time': 0.0}"""
), # noqa: B950
) )
# Now validate the CompilationMetrics logs. We expect a log for the # Now validate the CompilationMetrics logs. We expect a log for the
@ -520,7 +524,8 @@ class TestDynamoTimed(TestCase):
del raw["guard_latency_us"] del raw["guard_latency_us"]
self.assertExpectedInline( self.assertExpectedInline(
pprint.pformat(raw), pprint.pformat(raw),
"""\ (
"""\
{'accumulated_cache_size': 0, {'accumulated_cache_size': 0,
'aot_autograd_cumulative_compile_time_us': 0, 'aot_autograd_cumulative_compile_time_us': 0,
'backend_compile_time_s': 0.0, 'backend_compile_time_s': 0.0,
@ -606,8 +611,8 @@ class TestDynamoTimed(TestCase):
'triton_compile_time_us': None, 'triton_compile_time_us': None,
'triton_kernel_compile_times_us': None, 'triton_kernel_compile_times_us': None,
'triton_version': None}""" 'triton_version': None}"""
if _IS_WINDOWS if _IS_WINDOWS
else """\ else """\
{'accumulated_cache_size': 0, {'accumulated_cache_size': 0,
'aot_autograd_cumulative_compile_time_us': 0, 'aot_autograd_cumulative_compile_time_us': 0,
'backend_compile_time_s': 0.0, 'backend_compile_time_s': 0.0,
@ -692,7 +697,8 @@ class TestDynamoTimed(TestCase):
'tensorify_float_success': None, 'tensorify_float_success': None,
'triton_compile_time_us': 0, 'triton_compile_time_us': 0,
'triton_kernel_compile_times_us': None, 'triton_kernel_compile_times_us': None,
'triton_version': None}""", # noqa: B950 'triton_version': None}"""
), # noqa: B950
) )
# Second event is for the backward # Second event is for the backward
@ -706,7 +712,8 @@ class TestDynamoTimed(TestCase):
del raw["param_count"] del raw["param_count"]
self.assertExpectedInline( self.assertExpectedInline(
pprint.pformat(raw), pprint.pformat(raw),
"""\ (
"""\
{'accumulated_cache_size': None, {'accumulated_cache_size': None,
'aot_autograd_cumulative_compile_time_us': None, 'aot_autograd_cumulative_compile_time_us': None,
'backend_compile_time_s': None, 'backend_compile_time_s': None,
@ -792,8 +799,8 @@ class TestDynamoTimed(TestCase):
'triton_compile_time_us': None, 'triton_compile_time_us': None,
'triton_kernel_compile_times_us': None, 'triton_kernel_compile_times_us': None,
'triton_version': None}""" 'triton_version': None}"""
if _IS_WINDOWS if _IS_WINDOWS
else """\ else """\
{'accumulated_cache_size': None, {'accumulated_cache_size': None,
'aot_autograd_cumulative_compile_time_us': None, 'aot_autograd_cumulative_compile_time_us': None,
'backend_compile_time_s': None, 'backend_compile_time_s': None,
@ -878,7 +885,8 @@ class TestDynamoTimed(TestCase):
'tensorify_float_success': None, 'tensorify_float_success': None,
'triton_compile_time_us': 0, 'triton_compile_time_us': 0,
'triton_kernel_compile_times_us': None, 'triton_kernel_compile_times_us': None,
'triton_version': None}""", # noqa: B950 'triton_version': None}"""
), # noqa: B950
) )
@dynamo_config.patch( @dynamo_config.patch(
@ -908,13 +916,14 @@ class TestDynamoTimed(TestCase):
def test_ir_count(self): def test_ir_count(self):
# Different python versions have different potential IR counts. # Different python versions have different potential IR counts.
version = (sys.version_info[0], sys.version_info[1]) version = (sys.version_info[0], sys.version_info[1])
self.assertIn(version, ((3, 9), (3, 10), (3, 11), (3, 12), (3, 13))) self.assertIn(version, ((3, 9), (3, 10), (3, 11), (3, 12), (3, 13), (3, 14)))
first, second = { first, second = {
(3, 9): (10, 6), (3, 9): (10, 6),
(3, 10): (10, 6), (3, 10): (10, 6),
(3, 11): (10, 6), (3, 11): (10, 6),
(3, 12): (11, 7), (3, 12): (11, 7),
(3, 13): (11, 7), (3, 13): (11, 7),
(3, 14): (11, 7),
}[version] }[version]
def test1(x): def test1(x):

View File

@ -1431,7 +1431,12 @@ class InstructionTranslatorBase(
# an exception table entry, so we also assume that we # an exception table entry, so we also assume that we
# are still in the same block. It is probably safe to do # are still in the same block. It is probably safe to do
# this in 3.11, even though we haven't encountered this case before. # this in 3.11, even though we haven't encountered this case before.
if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"): # In 3.14+, NOT_TAKEN might also not be covered by an exn table entry.
if self.block_stack and inst.opname not in (
"NOP",
"JUMP_BACKWARD",
"NOT_TAKEN",
):
# If we really escape from a block and the current # If we really escape from a block and the current
# instruction is not in another block, then there # instruction is not in another block, then there
# should be no other nested blocks that we are in. # should be no other nested blocks that we are in.

View File

@ -1402,8 +1402,9 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
same_spec = _make_inlined(tx, pytree.TreeSpec.__eq__)( same_spec = _make_inlined(tx, pytree.TreeSpec.__eq__)(
true_spec.treespec, false_spec.treespec true_spec.treespec, false_spec.treespec
) ).as_python_constant()
if not same_spec.as_python_constant(): # 3.14: NotImplemented cannot be converted to bool
if same_spec is not NotImplemented and not same_spec:
unimplemented("Expected branches to return the same pytree structure.") unimplemented("Expected branches to return the same pytree structure.")
( (
@ -1696,9 +1697,11 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
with tx.fake_mode: with tx.fake_mode:
sub_args_fake = [ sub_args_fake = [
leaf.node.meta["example_value"].clone() (
if hasattr(leaf.node.meta["example_value"], "clone") leaf.node.meta["example_value"].clone()
else leaf.node.meta["example_value"] if hasattr(leaf.node.meta["example_value"], "clone")
else leaf.node.meta["example_value"]
)
for leaf in pytree.tree_leaves(proxy_vars_inputcheck) for leaf in pytree.tree_leaves(proxy_vars_inputcheck)
] ]
pre_dispatch = False pre_dispatch = False

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import ast import ast
import inspect import inspect
import sys
import textwrap import textwrap
import warnings import warnings
@ -74,7 +75,16 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
init_ast = ast.parse(textwrap.dedent(source_lines)) init_ast = ast.parse(textwrap.dedent(source_lines))
# Get items annotated in the class body # Get items annotated in the class body
self.class_level_annotations = list(nn_module.__annotations__.keys()) if sys.version_info >= (3, 14):
import annotationlib
self.class_level_annotations = list(
annotationlib.get_annotations(
nn_module, format=annotationlib.Format.FORWARDREF
).keys()
)
else:
self.class_level_annotations = list(nn_module.__annotations__.keys())
# Flag for later # Flag for later
self.visiting_class_level_ann = False self.visiting_class_level_ann = False