[dynamo, nested graph breaks] clean up comments and codegen (#160138)

Fix comments to reflect that we no longer codegen cells to be sent to resume function as inputs - they are instead codegen'd after the unsupported instruction in order to build resume functions that are closures.

Also simplify some codegen.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160138
Approved by: https://github.com/anijain2305
ghstack dependencies: #159329, #159678, #159817
This commit is contained in:
William Wen 2025-08-27 10:30:09 -07:00 committed by PyTorch MergeBot
parent d0a242e547
commit 6562646dab
5 changed files with 247 additions and 222 deletions

View File

@ -105,6 +105,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 6)
def test_single_graph_break_repeat(self):
global f1, f2, f3
@ -129,6 +130,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 3)
self.assertEqual(cnts.op_count, 10)
def test_doubly_nested_graph_break(self):
global f1, f2, f3
@ -153,6 +155,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 3)
self.assertEqual(cnts.op_count, 7)
def test_differing_arg_nums(self):
global f1, f2, f3, f4
@ -178,6 +181,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 10)
def test_differing_locals_nums(self):
global f1, f2, f3
@ -206,6 +210,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 14)
def test_supported_ctx_manager(self):
global check, check_disabled, f1, f2, f3
@ -248,6 +253,8 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 4)
# includes set_grad_enabled ops
self.assertEqual(cnts.op_count, 14)
def test_inactive_ctx_manager(self):
global check, f1, f2, f3
@ -295,6 +302,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 3)
self.assertEqual(cnts.op_count, 7)
def test_cells(self):
def f1(x1):
@ -327,6 +335,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 13)
def test_side_effects_cells(self):
cell1, cell2, cell3, cell4 = (torch.zeros(3),) * 4
@ -364,6 +373,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 5)
def test_side_effects_globals(self):
global f1, f2, f3
@ -401,6 +411,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 6)
def test_side_effects_globals_different_module(self):
global f1, f2, _test_nested_graph_breaks_helper
@ -431,6 +442,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 7)
@unittest.expectedFailure
def test_nested_graph_break_in_loop(self):

View File

@ -508,6 +508,43 @@ def create_binary_slice(
]
def create_copy(i: int) -> list[Instruction]:
    """Build the instructions that copy the i-th stack entry onto the top.

    Python 3.11+ gets a single ``COPY`` instruction with ``arg=i``. Older
    versions get an equivalent sequence assembled from ``create_swap`` and
    ``create_dup_top``.
    """
    if sys.version_info < (3, 11):
        # Worked example for COPY 4 on the stack "0 1 2 3" (top on the right):
        #   swap(i)     -> 3 1 2 0
        #   dup_top     -> 3 1 2 0 0
        #   swap(i + 1) -> 0 1 2 0 3
        #   swap(2)     -> 0 1 2 3 0
        emulated = []
        emulated.extend(create_swap(i))
        emulated.append(create_dup_top())
        emulated.extend(create_swap(i + 1))
        emulated.extend(create_swap(2))
        return emulated
    return [create_instruction("COPY", arg=i)]
# mainly for debugging generated bytecode
def create_print_on_stack(depth: int) -> list[Instruction]:
    """Build instructions that call ``print`` on a value already on the stack.

    The sequence loads ``print`` (via ``add_push_null``), copies the target
    stack entry with ``create_copy``, calls ``print`` with that one argument,
    and pops the call's result. Intended as a debugging aid for generated
    bytecode.

    Args:
        depth: position of the value to print, as accepted by ``create_copy``,
            measured before ``print`` is loaded.
    """
    # Loading `print` places extra entries above the target value, so the copy
    # index is shifted: by 2 on 3.11+, by 1 otherwise.
    # NOTE(review): presumably the extra slot on 3.11+ is the NULL added by
    # add_push_null - confirm against that helper.
    shift = 2 if sys.version_info >= (3, 11) else 1
    load_print = create_instruction("LOAD_CONST", argval=print)
    insts = list(add_push_null(load_print))
    insts.extend(create_copy(depth + shift))
    insts.extend(create_call_function(1, False))
    insts.append(create_instruction("POP_TOP"))
    return insts
# mainly for debugging generated bytecode
def create_print_value(value: Any) -> list[Instruction]:
    """Build instructions that call ``print`` on a constant.

    The sequence loads ``print`` (via ``add_push_null``), loads ``value`` with
    ``LOAD_CONST``, calls ``print`` with that one argument, and pops the
    call's result. Intended as a debugging aid for generated bytecode.

    Args:
        value: object passed as the ``argval`` of the ``LOAD_CONST`` that
            supplies print's argument.
    """
    load_print = create_instruction("LOAD_CONST", argval=print)
    insts = list(add_push_null(load_print))
    insts.append(create_instruction("LOAD_CONST", argval=value))
    insts.extend(create_call_function(1, False))
    insts.append(create_instruction("POP_TOP"))
    return insts
def lnotab_writer(
lineno: int, byteno: int = 0
) -> tuple[list[int], Callable[[int, int], None]]:

View File

@ -356,7 +356,6 @@ class StackLocalsMetadata:
locals_names: dict[str, int] = dc_field(
default_factory=dict
) # order of locals codegen'd to the stack
cell_and_freevars: dict[str, int] = dc_field(default_factory=dict)
stack_null_idxes: list[int] = dc_field(default_factory=list)
locals_null_keys: list[str] = dc_field(default_factory=list)
stack_ctx_args: list[tuple[int, tuple[Any, ...]]] = dc_field(default_factory=list)
@ -1237,10 +1236,7 @@ class OutputGraph(OutputGraphGuardsState):
meta.num_stack = len(stack_values)
cell_and_freevars = dict.fromkeys(tx.cellvars() + tx.freevars())
meta.cell_and_freevars = {
name: i for i, name in enumerate(cell_and_freevars.keys())
}
cell_and_freevars = set(tx.cellvars() + tx.freevars())
# NB: Typically (i.e., for graph compile from RETURN_VALUE),
# symbolic_locals will be empty at this point, as prune_dead_locals
@ -1256,7 +1252,8 @@ class OutputGraph(OutputGraphGuardsState):
# This will in turn result in spurious variables showing up in the graph.
# This was very tricky to debug. For an example, dump the graph at call_user_compiler
# while running test_subgraphs.py
# Do not load unmodified locals (load them at a later time) from the top frame
# Do not include top-frame unmodified locals here - otherwise, the compiled graph may
# erroneously include them as part of the return. We manually codegen them afterward.
if (
isinstance(v.source, LocalSource)
and v.source.local_name == k
@ -1264,7 +1261,7 @@ class OutputGraph(OutputGraphGuardsState):
):
continue
# Do not load cell/free vars
if k in meta.cell_and_freevars:
if k in cell_and_freevars:
continue
# Do not load variable if it is NULL.
if sys.version_info >= (3, 12):
@ -1338,12 +1335,12 @@ class OutputGraph(OutputGraphGuardsState):
prefix_insts.append(copy.copy(inst))
# stack values and restore vars for each frame are pushed in reverse order
# i.e. last element corresponds to root frame, first element corresponds to current frame
# i.e. last element corresponds to root frame (1),
# first element corresponds to current frame (N)
all_stack_values = []
all_stack_locals_metas = []
cur_tx: Optional[InstructionTranslatorBase] = tx
while True:
assert cur_tx is not None
while cur_tx is not None:
# this should have been checked by the caller
assert all(block.can_restore() for block in cur_tx.block_stack)
@ -1352,8 +1349,11 @@ class OutputGraph(OutputGraphGuardsState):
)
all_stack_values.append(stack_values)
all_stack_locals_metas.append(meta)
if cur_tx is self.root_tx:
break
# Exit from all context manager variables to make sure global state is restored
for block in reversed(cur_tx.block_stack):
block.exit(cur_tx, is_graph_break=reason.graph_break)
cur_tx = cur_tx.parent
# "Garbage collect the heap".
@ -1371,10 +1371,6 @@ class OutputGraph(OutputGraphGuardsState):
)
self.add_output_instructions(alias_insts)
# Exit from all context manager variables to make sure global state is restored
for block in reversed(self.root_tx.block_stack):
block.exit(self.root_tx, is_graph_break=reason.graph_break)
self.cleanup_graph()
# Use nn.Module "proxies" in the constructed GraphModule so that
@ -1411,41 +1407,27 @@ class OutputGraph(OutputGraphGuardsState):
)
self.add_output_instructions(random_calls_instructions)
# FIXME: right now not dealing with cells because they're difficult to deal with
# codegen stack convention before the unsupported instruction
# NOTE: in this comment block, "cell" refers to a Python cell object - i.e. free and cell vars
# Codegen stack convention before the unsupported instruction
# NOTE: in these comment blocks, "locals" EXCLUDE free and cell vars.
# NOTE: stack and locals must be codegen'd BEFORE the unsupported instruction, since the latter
# can arbitrarily mutate the former.
# [
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# frame N locals,
# frame N-1 stack + locals,
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# ], top stack_pops values of frame N
# frame 1 stack + locals,
# ], frame N stack
# see symbolic_convert.py for
# codegen stack convention after the unsupported instruction
# before calling resume function
# NOTE: need to push result of unsupported instruction to frame N stack
# [
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
# ...,
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
# ], frame 1 stack + frame 1 non-cell locals
# NOTE: cells are loaded into continuation functions directly
# (frame 1 cells should be loaded into the continuation function directly
# as part of the closure)
# NOTE: move the top stack_pops values from frame N to the beginning of the flat list.
# This is to prevent packing NULLs into a list.
cur_num_stack = all_stack_locals_metas[0].num_stack
stack_values_flat = (
all_stack_values[0][cur_num_stack - stack_pops : cur_num_stack]
+ all_stack_values[0][: cur_num_stack - stack_pops]
+ all_stack_values[0][cur_num_stack:]
+ [val for vals in all_stack_values[1:] for val in vals]
)
# this determines the order that values are codegen'd to the stack
stack_values_flat = [val for vals in all_stack_values for val in vals]
stored_graph_output_var = False
graph_output_var = None
# call compiled fx graph and codegen everything - stack, locals, cells
# call compiled fx graph and codegen all values - stack and locals
if (
self.root_tx is tx # single frame
and stack_values_flat
@ -1527,94 +1509,87 @@ class OutputGraph(OutputGraphGuardsState):
self.run_compiler_collective()
self.add_output_instructions(output + pass2.get_instructions())
# store all stack, locals, cells for each frame
# store all stack and locals for each frame
# current state of the stack:
# *(top stack_pops values), *(remaining stack_values_flat)
# *(frame N stack), *(frame N locals),
# ...,
# *(frame 1 stack), *(frame 1 locals)
self.add_output_instructions(
[
create_instruction(
"BUILD_LIST", arg=len(stack_values_flat) - stack_pops
"BUILD_LIST",
arg=len(stack_values_flat) - all_stack_locals_metas[0].num_stack,
),
]
)
# iterate current frame to root frame
# sliding window over frame stack/locals/cells
# current state of the stack:
# *(frame N stack), [
# *(frame N locals),
# *(frame N-1 stack), *(frame N-1 locals),
# ...
# *(frame 1 stack), *(frame 1 locals),
# ]
# iterate current frame (N) to root frame (1)
# sliding window over frame stack/locals
start_idx = 0
end_idx = 0
for i, meta in enumerate(all_stack_locals_metas):
# stack, locals, cells
# account for removed stack_pops values in current frame
num_stack = meta.num_stack - stack_pops if i == 0 else meta.num_stack
counts = (
num_stack,
len(meta.locals_names),
# len(meta.cell_and_freevars),
)
self.add_output_instructions([create_dup_top()])
# values, values
for j, cnt in enumerate(counts):
end_idx += cnt
if start_idx == end_idx:
self.add_output_instructions(
[
create_instruction("BUILD_LIST", arg=0),
*create_swap(2),
]
)
# [], values
else:
self.add_output_instructions(
[
create_dup_top(),
*create_binary_slice(start_idx, end_idx),
*create_swap(2),
]
)
# values[x:y], values
# add root frame's unmodified locals here
if i == len(all_stack_locals_metas) - 1 and j == 1:
root_cg = PyCodegen(self.root_tx)
unmodified_locals_names: dict[str, int] = {}
for k, v in self.root_tx.symbolic_locals.items():
if (
isinstance(v.source, LocalSource)
and v.source.local_name == k
):
root_cg.append_output(root_cg.create_load(k))
unmodified_locals_names[k] = len(meta.locals_names) + len(
unmodified_locals_names
)
self.add_output_instructions(
root_cg.get_instructions()
+ [
create_instruction(
"BUILD_LIST", arg=len(unmodified_locals_names)
),
# arg=2 because we already swapped the locals list back
create_instruction("LIST_EXTEND", arg=2),
]
)
meta.locals_names.update(unmodified_locals_names)
start_idx += cnt
# do not pack frame N's stack into the value list
n_vals = len(meta.locals_names)
if i != 0:
n_vals += meta.num_stack
if n_vals == 0:
self.add_output_instructions(
[
create_instruction("BUILD_LIST", arg=0),
*create_swap(2),
]
)
# [], stack_values_flat
else:
end_idx += n_vals
self.add_output_instructions(
[
create_dup_top(),
*create_binary_slice(start_idx, end_idx),
*create_swap(2),
]
)
start_idx += n_vals
# stack_values_flat[x:y], stack_values_flat
# pack stack, locals, cells together
# values, stack, locals, cells, values
self.add_output_instructions(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_TUPLE", arg=2),
*create_swap(2),
]
)
# (stack, locals, cells), values
# add root frame's unmodified locals here
if i == len(all_stack_locals_metas) - 1:
root_cg = PyCodegen(self.root_tx)
unmodified_locals_names: dict[str, int] = {}
for k, v in self.root_tx.symbolic_locals.items():
if isinstance(v.source, LocalSource) and v.source.local_name == k:
root_cg.append_output(root_cg.create_load(k))
unmodified_locals_names[k] = len(meta.locals_names) + len(
unmodified_locals_names
)
self.add_output_instructions(
root_cg.get_instructions()
+ [
create_instruction(
"BUILD_LIST", arg=len(unmodified_locals_names)
),
# arg=2 because we already swapped the locals list back
create_instruction("LIST_EXTEND", arg=2),
]
)
meta.locals_names.update(unmodified_locals_names)
# *(frame N stack), metas[0] stack + locals, ..., metas[i] stack + locals, stack_values_flat
# current state of the stack:
# *(top stack_pops values),
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# *(frame N stack)
# frame N locals,
# frame N-1 stack, frame N-1 locals,
# ...
# frame 1 stack, frame 1 locals,
# stack_values_flat
#
@ -1622,16 +1597,17 @@ class OutputGraph(OutputGraphGuardsState):
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(all_stack_locals_metas)),
*create_rot_n(stack_pops + 1),
*create_rot_n(all_stack_locals_metas[0].num_stack + 1),
]
)
# final state of the stack before running the unsupported bytecode:
# [
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# [frame N locals],
# [frame N-1 stack + locals],
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# ], *(top stack_pops values of frame N)
# [frame 1 stack + locals],
# ], *(frame N stack)
if graph_output_var and stored_graph_output_var:
self.add_output_instructions(

View File

@ -25,7 +25,6 @@ from .bytecode_transformation import (
add_push_null,
bytecode_from_template,
create_call_function,
create_dup_top,
create_instruction,
create_jump_absolute,
create_load_const,
@ -491,15 +490,8 @@ class ContinueExecutionCache:
# create [
# __nested_resume_fns,
# __nested_frame_values,
# *__nested_frame_values[-1][0],
# *__nested_frame_values[-1][1]],
# *__nested_frame_values[-1],
# ]
create_dup_top(),
create_instruction("LOAD_CONST", argval=0),
create_instruction("BINARY_SUBSCR"),
create_instruction("LIST_EXTEND", arg=2),
create_instruction("LOAD_CONST", argval=1),
create_instruction("BINARY_SUBSCR"),
create_instruction("LIST_EXTEND", arg=1),
# del __nested_frame_values[-1]
create_instruction("LOAD_FAST", argval="__nested_frame_values"),

View File

@ -72,10 +72,13 @@ from .bytecode_analysis import (
)
from .bytecode_transformation import (
cleaned_instructions,
create_binary_slice,
create_call_function,
create_copy,
create_dup_top,
create_instruction,
create_jump_absolute,
create_rot_n,
create_swap,
get_code_keys,
Instruction,
@ -671,14 +674,12 @@ def generic_jump(
self.pop()
if_next = self.create_call_resume_at(
self.next_instruction, 0, all_stack_locals_metadata
self.next_instruction, all_stack_locals_metadata
)
if push:
self.push(value)
assert inst.target is not None
if_jump = self.create_call_resume_at(
inst.target, int(push), all_stack_locals_metadata
)
if_jump = self.create_call_resume_at(inst.target, all_stack_locals_metadata)
if sys.version_info >= (3, 13):
# 3.13 requires stack[-1] to be bool type
@ -1011,7 +1012,7 @@ def break_graph_if_unsupported(
self.push(UnknownVariable())
self.output.add_output_instructions(
self.create_call_resume_at(
self.next_instruction, push, all_stack_locals_metadata
self.next_instruction, all_stack_locals_metadata
)
)
@ -1426,17 +1427,16 @@ class InstructionTranslatorBase(
# load locals from frame values
# current frame state
# [
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# frame N locals,
# frame N-1 stack + locals,
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# frame 1 stack + locals,
# ],
cg = PyCodegen(self)
self.output.add_output_instructions(
[
cg.create_load_const(-1),
cg.create_binary_subscr(),
cg.create_load_const(1),
cg.create_binary_subscr(),
]
)
for local, idx in all_stack_locals_metadata[-1].locals_names.items():
@ -2467,9 +2467,7 @@ class InstructionTranslatorBase(
self.output.add_output_instructions([copy.copy(inst)])
self.popn(2)
self.output.add_output_instructions(
self.create_call_resume_at(
self.next_instruction, 0, all_stack_locals_metadata
)
self.create_call_resume_at(self.next_instruction, all_stack_locals_metadata)
)
def DELETE_ATTR(self, inst: Instruction) -> None:
@ -2481,7 +2479,7 @@ class InstructionTranslatorBase(
)
def create_call_resume_at(
self, inst: Instruction, push: int, all_stack_locals_metadata: Any
self, inst: Instruction, all_stack_locals_metadata: Any
) -> list[Instruction]:
self.instruction_pointer = None
@ -2494,38 +2492,35 @@ class InstructionTranslatorBase(
# current frame state
# [
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# frame N locals,
# frame N-1 stack + locals,
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# ], `push` values from running the unsupported instruction
# frame 1 stack + locals
# ], frame N stack (post-instruction)
# move the `push` stack values to the frame N stack
# move frame N stack to the frame values list
current_num_stack = len(self.stack) - len(
all_stack_locals_metadata[0].stack_null_idxes
)
all_stack_locals_metadata[0].num_stack = current_num_stack
cg.extend_output(
[
create_instruction("BUILD_LIST", arg=push),
# frames_list, push_values_list
*create_swap(2),
create_dup_top(),
create_instruction("BUILD_LIST", arg=current_num_stack),
*create_copy(2),
# frame_values, frame N stack, frame_values
cg.create_load_const(0),
cg.create_binary_subscr(),
cg.create_load_const(0),
cg.create_binary_subscr(),
# push_values_list, frames_list, frames_list[0][0]
*create_swap(3),
# frames_list[0][0] += push_values_list
create_instruction("LIST_EXTEND", arg=2),
*create_swap(2),
# frames_list, frames_list[0][0]
create_instruction("POP_TOP"),
*create_binary_slice(0, 0, True),
# frame_values[0][0:0] = frame N stack
# frame_values left on top of stack
]
)
# current frame state
# [
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
# [frame N stack (fixed) + locals]
# ...,
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# [frame 1 stack + locals]
# ],
#
@ -2541,12 +2536,11 @@ class InstructionTranslatorBase(
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
# NOTE: if the unsupported instruction modifies the inactive context variable, it may
# result in silent incorrectness!
argnames: tuple[str, ...] = ()
for i, meta in enumerate(all_stack_locals_metadata):
for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
# Replace the stack var with the context class
ctx = cast(ContextWrappingVariable, txes[i].stack[j_orig])
# frames[i][0][j] = reconstructed_ctx
# frames[i][j] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(
@ -2554,8 +2548,6 @@ class InstructionTranslatorBase(
*create_swap(2),
cg.create_load_const(i),
cg.create_binary_subscr(),
cg.create_load_const(0),
cg.create_binary_subscr(),
cg.create_load_const(j),
create_instruction("STORE_SUBSCR"),
]
@ -2564,7 +2556,7 @@ class InstructionTranslatorBase(
for name, _ in meta.locals_ctx_args:
# Replace the local with the context class
ctx = cast(ContextWrappingVariable, txes[i].symbolic_locals[name])
# frames[i][1][meta.locals_names[name]] = reconstructed_ctx
# frames[i][meta.num_stack + meta.locals_names[name]] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(
@ -2572,9 +2564,7 @@ class InstructionTranslatorBase(
*create_swap(2),
cg.create_load_const(i),
cg.create_binary_subscr(),
cg.create_load_const(1),
cg.create_binary_subscr(),
cg.create_load_const(meta.locals_names[name]),
cg.create_load_const(meta.num_stack + meta.locals_names[name]),
create_instruction("STORE_SUBSCR"),
]
)
@ -2595,21 +2585,65 @@ class InstructionTranslatorBase(
if is_jump_absolute(resume_inst):
assert resume_inst.target
resume_inst = resume_inst.target
name = unique_id(f"__resume_at_{resume_inst.offset}")
resume_names.append(name)
resume_name = unique_id(f"__resume_at_{resume_inst.offset}")
resume_names.append(resume_name)
# More locals may have been pruned in the current frame
# after the unsupported instruction (e.g. branch).
# There should not be any pruning in the other frames since
# the current instruction is a CALL.
if cur_tx is self:
reads = livevars_analysis(cur_tx.instructions, resume_inst)
all_argnames = tuple(
k
for k in cur_tx.symbolic_locals.keys()
if k in reads and k not in cur_tx.cell_and_freevars()
)
argnames_null_set = set(meta.locals_null_keys)
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
# codegen filter for current frame's locals
# current stack state: frames
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(i),
cg.create_binary_subscr(),
create_dup_top(),
]
)
for arg in argnames:
# current stack state: frames, frames[i], *(prev locals), frames[i]
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(
meta.num_stack + meta.locals_names[arg]
),
cg.create_binary_subscr(),
*create_swap(2),
],
)
# current stack state: frames, frames[i], *(frame i live locals), frames[i]
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(argnames)),
*create_swap(2),
# frames, frames i live locals, frames[i]
*create_binary_slice(meta.num_stack, None, True),
# frames[i][num_stack:] = frame i live locals
]
)
# current stack state: frames
else:
argnames = tuple(meta.locals_names.keys())
argnames_null = tuple(meta.locals_null_keys)
# more locals may have been pruned after the unsupported instruction (e.g. branch)
reads = livevars_analysis(cur_tx.instructions, resume_inst)
all_argnames = tuple(
k
for k in cur_tx.symbolic_locals.keys()
if k in reads and k not in cur_tx.cell_and_freevars()
)
argnames_null_set = set(meta.locals_null_keys)
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
# compile_subgraph did not codegen any NULLs,
# so we should not count NullVariables
stack_len = len(cur_tx.stack) - len(meta.stack_null_idxes)
@ -2643,14 +2677,15 @@ class InstructionTranslatorBase(
# add resume function to the global scope
if new_code.co_freevars:
# expose code object for debugging purposes
cur_tx.output.install_global_unsafe(name, new_code)
cur_tx.output.install_global_unsafe(resume_name, new_code)
package_name = None
else:
# This is safe: we pre-generate a unique name
cur_tx.output.install_global_unsafe(
name, types.FunctionType(new_code, cur_tx.f_globals, name)
resume_name,
types.FunctionType(new_code, cur_tx.f_globals, resume_name),
)
package_name = name
package_name = resume_name
if cur_tx.package is not None:
cur_tx.package.add_resume_function(
@ -2687,10 +2722,10 @@ class InstructionTranslatorBase(
# [
# [resume N, ..., resume 2],
# [
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
# frame N stack + locals,
# ...,
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
# ], *(frame 1 stack + frame 1 non-cell locals)
# frame 2 stack + locals,
# ], *(frame 1 stack + locals)
# ]
cg.extend_output(
[
@ -2704,48 +2739,21 @@ class InstructionTranslatorBase(
# frames, frames[-1], frames
cg.create_load_const(-1),
create_instruction("DELETE_SUBSCR"),
# del frames[-1]; stack: frames, frames[-1]
create_dup_top(),
cg.create_load_const(0),
cg.create_binary_subscr(),
# frames, frames[-1], frames[-1][0]
*create_swap(2),
cg.create_load_const(1),
cg.create_binary_subscr(),
]
)
# resumes, frames, frames[-1][0], frames[-1][1]
for name in argnames:
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(
all_stack_locals_metadata[-1].locals_names[name]
),
cg.create_binary_subscr(),
*create_swap(2),
],
)
# resumes, frames, frames[-1][0], *(live locals), frames[-1][1]
# TOS: resumes, frames (popped), frame 1 stack + locals
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(argnames)),
*create_swap(4),
# live_locals, frames, frames[-1][0], resumes
create_instruction("BUILD_LIST", arg=1),
*create_swap(3),
# live_locals, [resumes], frames[-1][0], frames
create_instruction("LIST_APPEND", arg=2),
create_instruction("LIST_EXTEND", arg=1),
# live_locals, [resumes, frames, *stack]
*create_rot_n(3),
create_instruction("BUILD_LIST", arg=2),
*create_swap(2),
# [resumes, frames (popped)], frame 1 stack + locals
create_instruction("LIST_EXTEND", arg=1),
]
)
# [resumes, frames, *(stack + live locals)]
# TOS: [resumes, frames, *(frame 1 stack + locals)]
cg.extend_output(
[
create_instruction("CALL_FUNCTION_EX", arg=0),
@ -4391,10 +4399,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
return False # inlining functions is all-or-nothing
def create_call_resume_at(
self, inst: Instruction, push: int, all_stack_locals_metadata: Any
self, inst: Instruction, all_stack_locals_metadata: Any
) -> list[Instruction]:
if config.nested_graph_breaks:
return super().create_call_resume_at(inst, push, all_stack_locals_metadata)
return super().create_call_resume_at(inst, all_stack_locals_metadata)
unimplemented_v2(
gb_type="Graph break in inlined function",
context="",