mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo, 3.14] fix misc. bugs to get most dynamo unittests passing locally in 3.14 (#164631)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164631 Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
This commit is contained in:
parent
ea698e8bfc
commit
f452edd782
|
|
@ -487,6 +487,7 @@ def fn():
|
||||||
self.assertIn("JUMP", i1.opname)
|
self.assertIn("JUMP", i1.opname)
|
||||||
self.assertIs(i1.target, insts[-1])
|
self.assertIs(i1.target, insts[-1])
|
||||||
|
|
||||||
|
@unittest.skipIf(sys.version_info >= (3, 14), "3.14+ removed RETURN_CONST")
|
||||||
@skipIfNotPy312
|
@skipIfNotPy312
|
||||||
def test_bytecode_from_template_noreturn_const(self):
|
def test_bytecode_from_template_noreturn_const(self):
|
||||||
# Test 3.12+ RETURN_CONST
|
# Test 3.12+ RETURN_CONST
|
||||||
|
|
@ -535,7 +536,9 @@ def fn():
|
||||||
def test_extended_args_starts_line(self):
|
def test_extended_args_starts_line(self):
|
||||||
# NOTE: need to LOAD_CONST i before LOAD_FAST x
|
# NOTE: need to LOAD_CONST i before LOAD_FAST x
|
||||||
# in order to get an EXTENDED_ARG with starts_line set
|
# in order to get an EXTENDED_ARG with starts_line set
|
||||||
lines = "\n".join(f" x = {i} + x" for i in range(300))
|
# NOTE: 3.14+ introduced LOAD_SMALL_INT, so integers need to be >= 256
|
||||||
|
# in order for LOAD_CONST to be generated
|
||||||
|
lines = "\n".join(f" x = {i + 1000} + x" for i in range(300))
|
||||||
fn_str = f"def fn(x):\n{lines}"
|
fn_str = f"def fn(x):\n{lines}"
|
||||||
locals = {}
|
locals = {}
|
||||||
exec(fn_str, {}, locals)
|
exec(fn_str, {}, locals)
|
||||||
|
|
|
||||||
|
|
@ -390,6 +390,7 @@ class TestDynamoTimed(TestCase):
|
||||||
# directly inspect the dict it prints instead:
|
# directly inspect the dict it prints instead:
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
pprint.pformat(utils.compilation_time_metrics),
|
pprint.pformat(utils.compilation_time_metrics),
|
||||||
|
(
|
||||||
"""\
|
"""\
|
||||||
{'GraphLowering.codegen': [0.0, 0.0],
|
{'GraphLowering.codegen': [0.0, 0.0],
|
||||||
'GraphLowering.compile_to_fn': [0.0, 0.0],
|
'GraphLowering.compile_to_fn': [0.0, 0.0],
|
||||||
|
|
@ -451,7 +452,8 @@ class TestDynamoTimed(TestCase):
|
||||||
'create_aot_dispatcher_function': [0.0],
|
'create_aot_dispatcher_function': [0.0],
|
||||||
'fx_codegen_and_compile': [0.0, 0.0],
|
'fx_codegen_and_compile': [0.0, 0.0],
|
||||||
'gc': [0.0],
|
'gc': [0.0],
|
||||||
'min_cut_rematerialization_partition': [0.0]}""", # noqa: B950
|
'min_cut_rematerialization_partition': [0.0]}"""
|
||||||
|
), # noqa: B950
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now validate utils.calculate_time_spent(). Formatting the return
|
# Now validate utils.calculate_time_spent(). Formatting the return
|
||||||
|
|
@ -459,6 +461,7 @@ class TestDynamoTimed(TestCase):
|
||||||
time_spent = utils.calculate_time_spent()
|
time_spent = utils.calculate_time_spent()
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
pprint.pformat(time_spent),
|
pprint.pformat(time_spent),
|
||||||
|
(
|
||||||
"""\
|
"""\
|
||||||
{'_recursive_joint_graph_passes': 0.0,
|
{'_recursive_joint_graph_passes': 0.0,
|
||||||
'_recursive_post_grad_passes': 0.0,
|
'_recursive_post_grad_passes': 0.0,
|
||||||
|
|
@ -482,7 +485,8 @@ class TestDynamoTimed(TestCase):
|
||||||
'entire_frame_compile': 0.0,
|
'entire_frame_compile': 0.0,
|
||||||
'gc': 0.0,
|
'gc': 0.0,
|
||||||
'inductor_compile': 0.0,
|
'inductor_compile': 0.0,
|
||||||
'total_wall_time': 0.0}""", # noqa: B950
|
'total_wall_time': 0.0}"""
|
||||||
|
), # noqa: B950
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now validate the CompilationMetrics logs. We expect a log for the
|
# Now validate the CompilationMetrics logs. We expect a log for the
|
||||||
|
|
@ -520,6 +524,7 @@ class TestDynamoTimed(TestCase):
|
||||||
del raw["guard_latency_us"]
|
del raw["guard_latency_us"]
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
pprint.pformat(raw),
|
pprint.pformat(raw),
|
||||||
|
(
|
||||||
"""\
|
"""\
|
||||||
{'accumulated_cache_size': 0,
|
{'accumulated_cache_size': 0,
|
||||||
'aot_autograd_cumulative_compile_time_us': 0,
|
'aot_autograd_cumulative_compile_time_us': 0,
|
||||||
|
|
@ -692,7 +697,8 @@ class TestDynamoTimed(TestCase):
|
||||||
'tensorify_float_success': None,
|
'tensorify_float_success': None,
|
||||||
'triton_compile_time_us': 0,
|
'triton_compile_time_us': 0,
|
||||||
'triton_kernel_compile_times_us': None,
|
'triton_kernel_compile_times_us': None,
|
||||||
'triton_version': None}""", # noqa: B950
|
'triton_version': None}"""
|
||||||
|
), # noqa: B950
|
||||||
)
|
)
|
||||||
|
|
||||||
# Second event is for the backward
|
# Second event is for the backward
|
||||||
|
|
@ -706,6 +712,7 @@ class TestDynamoTimed(TestCase):
|
||||||
del raw["param_count"]
|
del raw["param_count"]
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
pprint.pformat(raw),
|
pprint.pformat(raw),
|
||||||
|
(
|
||||||
"""\
|
"""\
|
||||||
{'accumulated_cache_size': None,
|
{'accumulated_cache_size': None,
|
||||||
'aot_autograd_cumulative_compile_time_us': None,
|
'aot_autograd_cumulative_compile_time_us': None,
|
||||||
|
|
@ -878,7 +885,8 @@ class TestDynamoTimed(TestCase):
|
||||||
'tensorify_float_success': None,
|
'tensorify_float_success': None,
|
||||||
'triton_compile_time_us': 0,
|
'triton_compile_time_us': 0,
|
||||||
'triton_kernel_compile_times_us': None,
|
'triton_kernel_compile_times_us': None,
|
||||||
'triton_version': None}""", # noqa: B950
|
'triton_version': None}"""
|
||||||
|
), # noqa: B950
|
||||||
)
|
)
|
||||||
|
|
||||||
@dynamo_config.patch(
|
@dynamo_config.patch(
|
||||||
|
|
@ -908,13 +916,14 @@ class TestDynamoTimed(TestCase):
|
||||||
def test_ir_count(self):
|
def test_ir_count(self):
|
||||||
# Different python versions have different potential IR counts.
|
# Different python versions have different potential IR counts.
|
||||||
version = (sys.version_info[0], sys.version_info[1])
|
version = (sys.version_info[0], sys.version_info[1])
|
||||||
self.assertIn(version, ((3, 9), (3, 10), (3, 11), (3, 12), (3, 13)))
|
self.assertIn(version, ((3, 9), (3, 10), (3, 11), (3, 12), (3, 13), (3, 14)))
|
||||||
first, second = {
|
first, second = {
|
||||||
(3, 9): (10, 6),
|
(3, 9): (10, 6),
|
||||||
(3, 10): (10, 6),
|
(3, 10): (10, 6),
|
||||||
(3, 11): (10, 6),
|
(3, 11): (10, 6),
|
||||||
(3, 12): (11, 7),
|
(3, 12): (11, 7),
|
||||||
(3, 13): (11, 7),
|
(3, 13): (11, 7),
|
||||||
|
(3, 14): (11, 7),
|
||||||
}[version]
|
}[version]
|
||||||
|
|
||||||
def test1(x):
|
def test1(x):
|
||||||
|
|
|
||||||
|
|
@ -1431,7 +1431,12 @@ class InstructionTranslatorBase(
|
||||||
# an exception table entry, so we also assume that we
|
# an exception table entry, so we also assume that we
|
||||||
# are still in the same block. It is probably safe to do
|
# are still in the same block. It is probably safe to do
|
||||||
# this in 3.11, even though we haven't encountered this case before.
|
# this in 3.11, even though we haven't encountered this case before.
|
||||||
if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"):
|
# In 3.14+, NOT_TAKEN might also not be covered by an exn table entry.
|
||||||
|
if self.block_stack and inst.opname not in (
|
||||||
|
"NOP",
|
||||||
|
"JUMP_BACKWARD",
|
||||||
|
"NOT_TAKEN",
|
||||||
|
):
|
||||||
# If we really escape from a block and the current
|
# If we really escape from a block and the current
|
||||||
# instruction is not in another block, then there
|
# instruction is not in another block, then there
|
||||||
# should be no other nested blocks that we are in.
|
# should be no other nested blocks that we are in.
|
||||||
|
|
|
||||||
|
|
@ -1402,8 +1402,9 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||||
|
|
||||||
same_spec = _make_inlined(tx, pytree.TreeSpec.__eq__)(
|
same_spec = _make_inlined(tx, pytree.TreeSpec.__eq__)(
|
||||||
true_spec.treespec, false_spec.treespec
|
true_spec.treespec, false_spec.treespec
|
||||||
)
|
).as_python_constant()
|
||||||
if not same_spec.as_python_constant():
|
# 3.14: NotImplemented cannot be converted to bool
|
||||||
|
if same_spec is not NotImplemented and not same_spec:
|
||||||
unimplemented("Expected branches to return the same pytree structure.")
|
unimplemented("Expected branches to return the same pytree structure.")
|
||||||
|
|
||||||
(
|
(
|
||||||
|
|
@ -1696,9 +1697,11 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||||
|
|
||||||
with tx.fake_mode:
|
with tx.fake_mode:
|
||||||
sub_args_fake = [
|
sub_args_fake = [
|
||||||
|
(
|
||||||
leaf.node.meta["example_value"].clone()
|
leaf.node.meta["example_value"].clone()
|
||||||
if hasattr(leaf.node.meta["example_value"], "clone")
|
if hasattr(leaf.node.meta["example_value"], "clone")
|
||||||
else leaf.node.meta["example_value"]
|
else leaf.node.meta["example_value"]
|
||||||
|
)
|
||||||
for leaf in pytree.tree_leaves(proxy_vars_inputcheck)
|
for leaf in pytree.tree_leaves(proxy_vars_inputcheck)
|
||||||
]
|
]
|
||||||
pre_dispatch = False
|
pre_dispatch = False
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
import ast
|
import ast
|
||||||
import inspect
|
import inspect
|
||||||
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
@ -74,6 +75,15 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
|
||||||
init_ast = ast.parse(textwrap.dedent(source_lines))
|
init_ast = ast.parse(textwrap.dedent(source_lines))
|
||||||
|
|
||||||
# Get items annotated in the class body
|
# Get items annotated in the class body
|
||||||
|
if sys.version_info >= (3, 14):
|
||||||
|
import annotationlib
|
||||||
|
|
||||||
|
self.class_level_annotations = list(
|
||||||
|
annotationlib.get_annotations(
|
||||||
|
nn_module, format=annotationlib.Format.FORWARDREF
|
||||||
|
).keys()
|
||||||
|
)
|
||||||
|
else:
|
||||||
self.class_level_annotations = list(nn_module.__annotations__.keys())
|
self.class_level_annotations = list(nn_module.__annotations__.keys())
|
||||||
|
|
||||||
# Flag for later
|
# Flag for later
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user