mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo, 3.14] fix misc. bugs to get most dynamo unittests passing locally in 3.14 (#164631)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164631 Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
This commit is contained in:
parent
ea698e8bfc
commit
f452edd782
|
|
@ -487,6 +487,7 @@ def fn():
|
|||
self.assertIn("JUMP", i1.opname)
|
||||
self.assertIs(i1.target, insts[-1])
|
||||
|
||||
@unittest.skipIf(sys.version_info >= (3, 14), "3.14+ removed RETURN_CONST")
|
||||
@skipIfNotPy312
|
||||
def test_bytecode_from_template_noreturn_const(self):
|
||||
# Test 3.12+ RETURN_CONST
|
||||
|
|
@ -535,7 +536,9 @@ def fn():
|
|||
def test_extended_args_starts_line(self):
|
||||
# NOTE: need to LOAD_CONST i before LOAD_FAST x
|
||||
# in order to get an EXTENDED_ARG with starts_line set
|
||||
lines = "\n".join(f" x = {i} + x" for i in range(300))
|
||||
# NOTE: 3.14+ introduced LOAD_SMALL_INT, so integers need to be >= 256
|
||||
# in order for LOAD_CONST to be generated
|
||||
lines = "\n".join(f" x = {i + 1000} + x" for i in range(300))
|
||||
fn_str = f"def fn(x):\n{lines}"
|
||||
locals = {}
|
||||
exec(fn_str, {}, locals)
|
||||
|
|
|
|||
|
|
@ -390,7 +390,8 @@ class TestDynamoTimed(TestCase):
|
|||
# directly inspect the dict it prints instead:
|
||||
self.assertExpectedInline(
|
||||
pprint.pformat(utils.compilation_time_metrics),
|
||||
"""\
|
||||
(
|
||||
"""\
|
||||
{'GraphLowering.codegen': [0.0, 0.0],
|
||||
'GraphLowering.compile_to_fn': [0.0, 0.0],
|
||||
'GraphLowering.compile_to_module': [0.0, 0.0],
|
||||
|
|
@ -420,8 +421,8 @@ class TestDynamoTimed(TestCase):
|
|||
'fx_codegen_and_compile': [0.0, 0.0],
|
||||
'gc': [0.0],
|
||||
'min_cut_rematerialization_partition': [0.0]}"""
|
||||
if _IS_WINDOWS
|
||||
else """\
|
||||
if _IS_WINDOWS
|
||||
else """\
|
||||
{'GraphLowering.codegen': [0.0, 0.0],
|
||||
'GraphLowering.compile_to_fn': [0.0, 0.0],
|
||||
'GraphLowering.compile_to_module': [0.0, 0.0],
|
||||
|
|
@ -451,7 +452,8 @@ class TestDynamoTimed(TestCase):
|
|||
'create_aot_dispatcher_function': [0.0],
|
||||
'fx_codegen_and_compile': [0.0, 0.0],
|
||||
'gc': [0.0],
|
||||
'min_cut_rematerialization_partition': [0.0]}""", # noqa: B950
|
||||
'min_cut_rematerialization_partition': [0.0]}"""
|
||||
), # noqa: B950
|
||||
)
|
||||
|
||||
# Now validate utils.calculate_time_spent(). Formatting the return
|
||||
|
|
@ -459,7 +461,8 @@ class TestDynamoTimed(TestCase):
|
|||
time_spent = utils.calculate_time_spent()
|
||||
self.assertExpectedInline(
|
||||
pprint.pformat(time_spent),
|
||||
"""\
|
||||
(
|
||||
"""\
|
||||
{'_recursive_joint_graph_passes': 0.0,
|
||||
'_recursive_post_grad_passes': 0.0,
|
||||
'_recursive_pre_grad_passes': 0.0,
|
||||
|
|
@ -470,8 +473,8 @@ class TestDynamoTimed(TestCase):
|
|||
'gc': 0.0,
|
||||
'inductor_compile': 0.0,
|
||||
'total_wall_time': 0.0}"""
|
||||
if _IS_WINDOWS
|
||||
else """\
|
||||
if _IS_WINDOWS
|
||||
else """\
|
||||
{'_recursive_joint_graph_passes': 0.0,
|
||||
'_recursive_post_grad_passes': 0.0,
|
||||
'_recursive_pre_grad_passes': 0.0,
|
||||
|
|
@ -482,7 +485,8 @@ class TestDynamoTimed(TestCase):
|
|||
'entire_frame_compile': 0.0,
|
||||
'gc': 0.0,
|
||||
'inductor_compile': 0.0,
|
||||
'total_wall_time': 0.0}""", # noqa: B950
|
||||
'total_wall_time': 0.0}"""
|
||||
), # noqa: B950
|
||||
)
|
||||
|
||||
# Now validate the CompilationMetrics logs. We expect a log for the
|
||||
|
|
@ -520,7 +524,8 @@ class TestDynamoTimed(TestCase):
|
|||
del raw["guard_latency_us"]
|
||||
self.assertExpectedInline(
|
||||
pprint.pformat(raw),
|
||||
"""\
|
||||
(
|
||||
"""\
|
||||
{'accumulated_cache_size': 0,
|
||||
'aot_autograd_cumulative_compile_time_us': 0,
|
||||
'backend_compile_time_s': 0.0,
|
||||
|
|
@ -606,8 +611,8 @@ class TestDynamoTimed(TestCase):
|
|||
'triton_compile_time_us': None,
|
||||
'triton_kernel_compile_times_us': None,
|
||||
'triton_version': None}"""
|
||||
if _IS_WINDOWS
|
||||
else """\
|
||||
if _IS_WINDOWS
|
||||
else """\
|
||||
{'accumulated_cache_size': 0,
|
||||
'aot_autograd_cumulative_compile_time_us': 0,
|
||||
'backend_compile_time_s': 0.0,
|
||||
|
|
@ -692,7 +697,8 @@ class TestDynamoTimed(TestCase):
|
|||
'tensorify_float_success': None,
|
||||
'triton_compile_time_us': 0,
|
||||
'triton_kernel_compile_times_us': None,
|
||||
'triton_version': None}""", # noqa: B950
|
||||
'triton_version': None}"""
|
||||
), # noqa: B950
|
||||
)
|
||||
|
||||
# Second event is for the backward
|
||||
|
|
@ -706,7 +712,8 @@ class TestDynamoTimed(TestCase):
|
|||
del raw["param_count"]
|
||||
self.assertExpectedInline(
|
||||
pprint.pformat(raw),
|
||||
"""\
|
||||
(
|
||||
"""\
|
||||
{'accumulated_cache_size': None,
|
||||
'aot_autograd_cumulative_compile_time_us': None,
|
||||
'backend_compile_time_s': None,
|
||||
|
|
@ -792,8 +799,8 @@ class TestDynamoTimed(TestCase):
|
|||
'triton_compile_time_us': None,
|
||||
'triton_kernel_compile_times_us': None,
|
||||
'triton_version': None}"""
|
||||
if _IS_WINDOWS
|
||||
else """\
|
||||
if _IS_WINDOWS
|
||||
else """\
|
||||
{'accumulated_cache_size': None,
|
||||
'aot_autograd_cumulative_compile_time_us': None,
|
||||
'backend_compile_time_s': None,
|
||||
|
|
@ -878,7 +885,8 @@ class TestDynamoTimed(TestCase):
|
|||
'tensorify_float_success': None,
|
||||
'triton_compile_time_us': 0,
|
||||
'triton_kernel_compile_times_us': None,
|
||||
'triton_version': None}""", # noqa: B950
|
||||
'triton_version': None}"""
|
||||
), # noqa: B950
|
||||
)
|
||||
|
||||
@dynamo_config.patch(
|
||||
|
|
@ -908,13 +916,14 @@ class TestDynamoTimed(TestCase):
|
|||
def test_ir_count(self):
|
||||
# Different python versions have different potential IR counts.
|
||||
version = (sys.version_info[0], sys.version_info[1])
|
||||
self.assertIn(version, ((3, 9), (3, 10), (3, 11), (3, 12), (3, 13)))
|
||||
self.assertIn(version, ((3, 9), (3, 10), (3, 11), (3, 12), (3, 13), (3, 14)))
|
||||
first, second = {
|
||||
(3, 9): (10, 6),
|
||||
(3, 10): (10, 6),
|
||||
(3, 11): (10, 6),
|
||||
(3, 12): (11, 7),
|
||||
(3, 13): (11, 7),
|
||||
(3, 14): (11, 7),
|
||||
}[version]
|
||||
|
||||
def test1(x):
|
||||
|
|
|
|||
|
|
@ -1431,7 +1431,12 @@ class InstructionTranslatorBase(
|
|||
# an exception table entry, so we also assume that we
|
||||
# are still in the same block. It is probably safe to do
|
||||
# this in 3.11, even though we haven't encountered this case before.
|
||||
if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"):
|
||||
# In 3.14+, NOT_TAKEN might also not be covered by an exn table entry.
|
||||
if self.block_stack and inst.opname not in (
|
||||
"NOP",
|
||||
"JUMP_BACKWARD",
|
||||
"NOT_TAKEN",
|
||||
):
|
||||
# If we really escape from a block and the current
|
||||
# instruction is not in another block, then there
|
||||
# should be no other nested blocks that we are in.
|
||||
|
|
|
|||
|
|
@ -1402,8 +1402,9 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
|
||||
same_spec = _make_inlined(tx, pytree.TreeSpec.__eq__)(
|
||||
true_spec.treespec, false_spec.treespec
|
||||
)
|
||||
if not same_spec.as_python_constant():
|
||||
).as_python_constant()
|
||||
# 3.14: NotImplemented cannot be converted to bool
|
||||
if same_spec is not NotImplemented and not same_spec:
|
||||
unimplemented("Expected branches to return the same pytree structure.")
|
||||
|
||||
(
|
||||
|
|
@ -1696,9 +1697,11 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
|
||||
with tx.fake_mode:
|
||||
sub_args_fake = [
|
||||
leaf.node.meta["example_value"].clone()
|
||||
if hasattr(leaf.node.meta["example_value"], "clone")
|
||||
else leaf.node.meta["example_value"]
|
||||
(
|
||||
leaf.node.meta["example_value"].clone()
|
||||
if hasattr(leaf.node.meta["example_value"], "clone")
|
||||
else leaf.node.meta["example_value"]
|
||||
)
|
||||
for leaf in pytree.tree_leaves(proxy_vars_inputcheck)
|
||||
]
|
||||
pre_dispatch = False
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import ast
|
||||
import inspect
|
||||
import sys
|
||||
import textwrap
|
||||
import warnings
|
||||
|
||||
|
|
@ -74,7 +75,16 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
|
|||
init_ast = ast.parse(textwrap.dedent(source_lines))
|
||||
|
||||
# Get items annotated in the class body
|
||||
self.class_level_annotations = list(nn_module.__annotations__.keys())
|
||||
if sys.version_info >= (3, 14):
|
||||
import annotationlib
|
||||
|
||||
self.class_level_annotations = list(
|
||||
annotationlib.get_annotations(
|
||||
nn_module, format=annotationlib.Format.FORWARDREF
|
||||
).keys()
|
||||
)
|
||||
else:
|
||||
self.class_level_annotations = list(nn_module.__annotations__.keys())
|
||||
|
||||
# Flag for later
|
||||
self.visiting_class_level_ann = False
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user