Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[dynamo, nested graph breaks] clean up comments and codegen (#160138)
Fix comments to reflect that we no longer codegen cells to be sent to the resume function as inputs - they are instead codegen'd after the unsupported instruction in order to build resume functions that are closures. Also simplify some codegen.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160138
Approved by: https://github.com/anijain2305
ghstack dependencies: #159329, #159678, #159817
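To make the commit message concrete: a resume function built as a closure keeps sharing the original frame's cell with the code it resumes, whereas a resume function that only receives the cell's current value as an input works on a copy. The following is a minimal, hypothetical Python sketch of that distinction; it is not code from this PR, and the names (make_counter, bump, resume_with_input, resume_as_closure) are invented for the illustration.

def make_counter():
    count = 0  # lives in a cell because the nested functions close over it

    def bump():
        nonlocal count
        count += 1
        return count

    # Passing the cell's value in as an input: the resume function sees a copy
    # and cannot update the shared cell.
    def resume_with_input(count_value):
        return count_value + 100

    # Building the resume function as a closure: it reads and writes the same
    # cell that bump uses, so resumed execution keeps mutating shared state.
    def resume_as_closure():
        nonlocal count
        count += 100
        return count

    return bump, resume_with_input, resume_as_closure


bump, resume_with_input, resume_as_closure = make_counter()
bump()                       # shared cell now holds 1
print(resume_with_input(1))  # 101, but the shared cell still holds 1
print(resume_as_closure())   # 101, and the shared cell now holds 101
print(bump())                # 102: the closure-based resume affected shared state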
This commit is contained in:
parent d0a242e547
commit 6562646dab
@@ -105,6 +105,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 6)

def test_single_graph_break_repeat(self):
global f1, f2, f3

@@ -129,6 +130,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 3)
self.assertEqual(cnts.op_count, 10)

def test_doubly_nested_graph_break(self):
global f1, f2, f3

@@ -153,6 +155,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 3)
self.assertEqual(cnts.op_count, 7)

def test_differing_arg_nums(self):
global f1, f2, f3, f4

@@ -178,6 +181,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 10)

def test_differing_locals_nums(self):
global f1, f2, f3

@@ -206,6 +210,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 14)

def test_supported_ctx_manager(self):
global check, check_disabled, f1, f2, f3

@@ -248,6 +253,8 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 4)
# includes set_grad_enabled ops
self.assertEqual(cnts.op_count, 14)

def test_inactive_ctx_manager(self):
global check, f1, f2, f3

@@ -295,6 +302,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 3)
self.assertEqual(cnts.op_count, 7)

def test_cells(self):
def f1(x1):

@@ -327,6 +335,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 13)

def test_side_effects_cells(self):
cell1, cell2, cell3, cell4 = (torch.zeros(3),) * 4

@@ -364,6 +373,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):

self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 5)

def test_side_effects_globals(self):
global f1, f2, f3

@@ -401,6 +411,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):

self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 6)

def test_side_effects_globals_different_module(self):
global f1, f2, _test_nested_graph_breaks_helper

@@ -431,6 +442,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):

self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 7)

@unittest.expectedFailure
def test_nested_graph_break_in_loop(self):
@@ -508,6 +508,43 @@ def create_binary_slice(
]


def create_copy(i: int) -> list[Instruction]:
if sys.version_info >= (3, 11):
return [create_instruction("COPY", arg=i)]
# COPY 4
# 0 1 2 3
# 3 1 2 0
# 3 1 2 0 0
# 0 1 2 0 3
# 0 1 2 3 0
return [
*create_swap(i),
create_dup_top(),
*create_swap(i + 1),
*create_swap(2),
]


# mainly for debugging generated bytecode
def create_print_on_stack(depth: int) -> list[Instruction]:
return [
*add_push_null(create_instruction("LOAD_CONST", argval=print)),
*create_copy(depth + (2 if sys.version_info >= (3, 11) else 1)),
*create_call_function(1, False),
create_instruction("POP_TOP"),
]


# mainly for debugging generated bytecode
def create_print_value(value: Any) -> list[Instruction]:
return [
*add_push_null(create_instruction("LOAD_CONST", argval=print)),
create_instruction("LOAD_CONST", argval=value),
*create_call_function(1, False),
create_instruction("POP_TOP"),
]


def lnotab_writer(
lineno: int, byteno: int = 0
) -> tuple[list[int], Callable[[int, int], None]]:
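As an aside for readers of the create_copy helper added above: on Python < 3.11 it emulates the COPY instruction with the SWAP/DUP_TOP/SWAP/SWAP sequence traced in its comment. A small, hypothetical pure-Python model of that stack shuffle (not part of the PR; the right end of the list stands for the top of the stack) reproduces the same trace for COPY 4:

def swap(stack, i):
    # SWAP i: exchange the top of the stack with the i-th item from the top
    stack[-1], stack[-i] = stack[-i], stack[-1]

def dup_top(stack):
    # DUP_TOP: push a copy of the current top of the stack
    stack.append(stack[-1])

def emulate_copy(stack, i):
    # the same instruction sequence create_copy emits for Python < 3.11
    swap(stack, i)      # 0 1 2 3   -> 3 1 2 0
    dup_top(stack)      # 3 1 2 0   -> 3 1 2 0 0
    swap(stack, i + 1)  # 3 1 2 0 0 -> 0 1 2 0 3
    swap(stack, 2)      # 0 1 2 0 3 -> 0 1 2 3 0
    return stack

print(emulate_copy([0, 1, 2, 3], 4))  # [0, 1, 2, 3, 0]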
@@ -356,7 +356,6 @@ class StackLocalsMetadata:
locals_names: dict[str, int] = dc_field(
default_factory=dict
) # order of locals codegen'd to the stack
cell_and_freevars: dict[str, int] = dc_field(default_factory=dict)
stack_null_idxes: list[int] = dc_field(default_factory=list)
locals_null_keys: list[str] = dc_field(default_factory=list)
stack_ctx_args: list[tuple[int, tuple[Any, ...]]] = dc_field(default_factory=list)

@@ -1237,10 +1236,7 @@ class OutputGraph(OutputGraphGuardsState):

meta.num_stack = len(stack_values)

cell_and_freevars = dict.fromkeys(tx.cellvars() + tx.freevars())
meta.cell_and_freevars = {
name: i for i, name in enumerate(cell_and_freevars.keys())
}
cell_and_freevars = set(tx.cellvars() + tx.freevars())

# NB: Typically (i.e., for graph compile from RETURN_VALUE),
# symbolic_locals will be empty at this point, as prune_dead_locals

@@ -1256,7 +1252,8 @@ class OutputGraph(OutputGraphGuardsState):
# This will in turn result in spurious variables showing up in the graph.
# This was very tricky to debug. For an example, dump the graph at call_user_compiler
# while running test_subgraphs.py
# Do not load unmodified locals (load them at a later time) from the top frame
# Do not include top-frame unmodified locals here - otherwise, the compiled graph may
# erroneously include them as part of the return. We manually codegen them afterward.
if (
isinstance(v.source, LocalSource)
and v.source.local_name == k

@@ -1264,7 +1261,7 @@ class OutputGraph(OutputGraphGuardsState):
):
continue
# Do not load cell/free vars
if k in meta.cell_and_freevars:
if k in cell_and_freevars:
continue
# Do not load variable if it is NULL.
if sys.version_info >= (3, 12):
@@ -1338,12 +1335,12 @@ class OutputGraph(OutputGraphGuardsState):
prefix_insts.append(copy.copy(inst))

# stack values and restore vars for each frame are pushed in reverse order
# i.e. last element corresponds to root frame, first element corresponds to current frame
# i.e. last element corresponds to root frame (1),
# first element corresponds to current frame (N)
all_stack_values = []
all_stack_locals_metas = []
cur_tx: Optional[InstructionTranslatorBase] = tx
while True:
assert cur_tx is not None
while cur_tx is not None:
# this should have been checked by the caller
assert all(block.can_restore() for block in cur_tx.block_stack)

@@ -1352,8 +1349,11 @@ class OutputGraph(OutputGraphGuardsState):
)
all_stack_values.append(stack_values)
all_stack_locals_metas.append(meta)
if cur_tx is self.root_tx:
break

# Exit from all context manager variables to make sure global state is restored
for block in reversed(cur_tx.block_stack):
block.exit(cur_tx, is_graph_break=reason.graph_break)

cur_tx = cur_tx.parent

# "Garbage collect the heap".

@@ -1371,10 +1371,6 @@ class OutputGraph(OutputGraphGuardsState):
)
self.add_output_instructions(alias_insts)

# Exit from all context manager variables to make sure global state is restored
for block in reversed(self.root_tx.block_stack):
block.exit(self.root_tx, is_graph_break=reason.graph_break)

self.cleanup_graph()

# Use nn.Module "proxies" in the constructed GraphModule so that

@@ -1411,41 +1407,27 @@ class OutputGraph(OutputGraphGuardsState):
)
self.add_output_instructions(random_calls_instructions)

# FIXME: right now not dealing with cells because they're difficult to deal with
# codegen stack convention before the unsupported instruction
# NOTE: in this comment block, "cell" refers to a Python cell object - i.e. free and cell vars
# Codegen stack convention before the unsupported instruction
# NOTE: in these comment blocks, "locals" EXCLUDE free and cell vars.
# NOTE: stack and locals must be codegen'd BEFORE the unsupported instruction, since the latter
# can arbitrarily mutate the former.
# [
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# frame N locals,
# frame N-1 stack + locals,
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# ], top stack_pops values of frame N
# frame 1 stack + locals,
# ], frame N stack

# see symbolic_convert.py for
# codegen stack convention after the unsupported instruction
# before calling resume function
# NOTE: need to push result of unsupported instruction to frame N stack
# [
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
# ...,
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
# ], frame 1 stack + frame 1 non-cell locals
# NOTE: cells are loaded into continuation functions directly

# (frame 1 cells should be loaded into the continuation function directly
# as part of the closure)

# NOTE: move the top stack_pops values from frame N to the beginning of the flat list.
# This is to prevent packing NULLs into a list.

cur_num_stack = all_stack_locals_metas[0].num_stack
stack_values_flat = (
all_stack_values[0][cur_num_stack - stack_pops : cur_num_stack]
+ all_stack_values[0][: cur_num_stack - stack_pops]
+ all_stack_values[0][cur_num_stack:]
+ [val for vals in all_stack_values[1:] for val in vals]
)
# this determines the order that values are codegen'd to the stack
stack_values_flat = [val for vals in all_stack_values for val in vals]
stored_graph_output_var = False
graph_output_var = None

# call compiled fx graph and codegen everything - stack, locals, cells
# call compiled fx graph and codegen all values - stack and locals
if (
self.root_tx is tx # single frame
and stack_values_flat
@@ -1527,94 +1509,87 @@ class OutputGraph(OutputGraphGuardsState):
self.run_compiler_collective()
self.add_output_instructions(output + pass2.get_instructions())

# store all stack, locals, cells for each frame
# store all stack and locals for each frame
# current state of the stack:
# *(top stack_pops values), *(remaining stack_values_flat)
# *(frame N stack), *(frame N locals),
# ...,
# *(frame 1 stack), *(frame 1 locals)

self.add_output_instructions(
[
create_instruction(
"BUILD_LIST", arg=len(stack_values_flat) - stack_pops
"BUILD_LIST",
arg=len(stack_values_flat) - all_stack_locals_metas[0].num_stack,
),
]
)

# iterate current frame to root frame
# sliding window over frame stack/locals/cells
# current state of the stack:
# *(frame N stack), [
# *(frame N locals),
# *(frame N-1 stack), *(frame N-1 locals),
# ...
# *(frame 1 stack), *(frame 1 locals),
# ]
# iterate current frame (N) to root frame (1)
# sliding window over frame stack/locals
start_idx = 0
end_idx = 0
for i, meta in enumerate(all_stack_locals_metas):
# stack, locals, cells
# account for removed stack_pops values in current frame
num_stack = meta.num_stack - stack_pops if i == 0 else meta.num_stack
counts = (
num_stack,
len(meta.locals_names),
# len(meta.cell_and_freevars),
)
self.add_output_instructions([create_dup_top()])
# values, values
for j, cnt in enumerate(counts):
end_idx += cnt
if start_idx == end_idx:
self.add_output_instructions(
[
create_instruction("BUILD_LIST", arg=0),
*create_swap(2),
]
)
# [], values
else:
self.add_output_instructions(
[
create_dup_top(),
*create_binary_slice(start_idx, end_idx),
*create_swap(2),
]
)
# values[x:y], values
# add root frame's unmodified locals here
if i == len(all_stack_locals_metas) - 1 and j == 1:
root_cg = PyCodegen(self.root_tx)
unmodified_locals_names: dict[str, int] = {}
for k, v in self.root_tx.symbolic_locals.items():
if (
isinstance(v.source, LocalSource)
and v.source.local_name == k
):
root_cg.append_output(root_cg.create_load(k))
unmodified_locals_names[k] = len(meta.locals_names) + len(
unmodified_locals_names
)
self.add_output_instructions(
root_cg.get_instructions()
+ [
create_instruction(
"BUILD_LIST", arg=len(unmodified_locals_names)
),
# arg=2 because we already swapped the locals list back
create_instruction("LIST_EXTEND", arg=2),
]
)
meta.locals_names.update(unmodified_locals_names)
start_idx += cnt
# do not pack frame N's stack into the value list
n_vals = len(meta.locals_names)
if i != 0:
n_vals += meta.num_stack
if n_vals == 0:
self.add_output_instructions(
[
create_instruction("BUILD_LIST", arg=0),
*create_swap(2),
]
)
# [], stack_values_flat
else:
end_idx += n_vals
self.add_output_instructions(
[
create_dup_top(),
*create_binary_slice(start_idx, end_idx),
*create_swap(2),
]
)
start_idx += n_vals
# stack_values_flat[x:y], stack_values_flat

# pack stack, locals, cells together
# values, stack, locals, cells, values
self.add_output_instructions(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_TUPLE", arg=2),
*create_swap(2),
]
)
# (stack, locals, cells), values
# add root frame's unmodified locals here
if i == len(all_stack_locals_metas) - 1:
root_cg = PyCodegen(self.root_tx)
unmodified_locals_names: dict[str, int] = {}
for k, v in self.root_tx.symbolic_locals.items():
if isinstance(v.source, LocalSource) and v.source.local_name == k:
root_cg.append_output(root_cg.create_load(k))
unmodified_locals_names[k] = len(meta.locals_names) + len(
unmodified_locals_names
)
self.add_output_instructions(
root_cg.get_instructions()
+ [
create_instruction(
"BUILD_LIST", arg=len(unmodified_locals_names)
),
# arg=2 because we already swapped the locals list back
create_instruction("LIST_EXTEND", arg=2),
]
)
meta.locals_names.update(unmodified_locals_names)

# *(frame N stack), metas[0] stack + locals, ..., metas[i] stack + locals, stack_values_flat

# current state of the stack:
# *(top stack_pops values),
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# *(frame N stack)
# frame N locals,
# frame N-1 stack, frame N-1 locals,
# ...
# frame 1 stack, frame 1 locals,
# stack_values_flat
#

@@ -1622,16 +1597,17 @@ class OutputGraph(OutputGraphGuardsState):
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(all_stack_locals_metas)),
*create_rot_n(stack_pops + 1),
*create_rot_n(all_stack_locals_metas[0].num_stack + 1),
]
)

# final state of the stack before running the unsupported bytecode:
# [
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# [frame N locals],
# [frame N-1 stack + locals],
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# ], *(top stack_pops values of frame N)
# [frame 1 stack + locals],
# ], *(frame N stack)

if graph_output_var and stored_graph_output_var:
self.add_output_instructions(
@@ -25,7 +25,6 @@ from .bytecode_transformation import (
add_push_null,
bytecode_from_template,
create_call_function,
create_dup_top,
create_instruction,
create_jump_absolute,
create_load_const,

@@ -491,15 +490,8 @@ class ContinueExecutionCache:
# create [
# __nested_resume_fns,
# __nested_frame_values,
# *__nested_frame_values[-1][0],
# *__nested_frame_values[-1][1]],
# *__nested_frame_values[-1],
# ]
create_dup_top(),
create_instruction("LOAD_CONST", argval=0),
create_instruction("BINARY_SUBSCR"),
create_instruction("LIST_EXTEND", arg=2),
create_instruction("LOAD_CONST", argval=1),
create_instruction("BINARY_SUBSCR"),
create_instruction("LIST_EXTEND", arg=1),
# del __nested_frame_values[-1]
create_instruction("LOAD_FAST", argval="__nested_frame_values"),
@@ -72,10 +72,13 @@ from .bytecode_analysis import (
)
from .bytecode_transformation import (
cleaned_instructions,
create_binary_slice,
create_call_function,
create_copy,
create_dup_top,
create_instruction,
create_jump_absolute,
create_rot_n,
create_swap,
get_code_keys,
Instruction,

@@ -671,14 +674,12 @@ def generic_jump(
self.pop()

if_next = self.create_call_resume_at(
self.next_instruction, 0, all_stack_locals_metadata
self.next_instruction, all_stack_locals_metadata
)
if push:
self.push(value)
assert inst.target is not None
if_jump = self.create_call_resume_at(
inst.target, int(push), all_stack_locals_metadata
)
if_jump = self.create_call_resume_at(inst.target, all_stack_locals_metadata)

if sys.version_info >= (3, 13):
# 3.13 requires stack[-1] to be bool type

@@ -1011,7 +1012,7 @@ def break_graph_if_unsupported(
self.push(UnknownVariable())
self.output.add_output_instructions(
self.create_call_resume_at(
self.next_instruction, push, all_stack_locals_metadata
self.next_instruction, all_stack_locals_metadata
)
)

@@ -1426,17 +1427,16 @@ class InstructionTranslatorBase(
# load locals from frame values
# current frame state
# [
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# frame N locals,
# frame N-1 stack + locals,
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# frame 1 stack + locals,
# ],
cg = PyCodegen(self)
self.output.add_output_instructions(
[
cg.create_load_const(-1),
cg.create_binary_subscr(),
cg.create_load_const(1),
cg.create_binary_subscr(),
]
)
for local, idx in all_stack_locals_metadata[-1].locals_names.items():

@@ -2467,9 +2467,7 @@ class InstructionTranslatorBase(
self.output.add_output_instructions([copy.copy(inst)])
self.popn(2)
self.output.add_output_instructions(
self.create_call_resume_at(
self.next_instruction, 0, all_stack_locals_metadata
)
self.create_call_resume_at(self.next_instruction, all_stack_locals_metadata)
)

def DELETE_ATTR(self, inst: Instruction) -> None:

@@ -2481,7 +2479,7 @@ class InstructionTranslatorBase(
)

def create_call_resume_at(
self, inst: Instruction, push: int, all_stack_locals_metadata: Any
self, inst: Instruction, all_stack_locals_metadata: Any
) -> list[Instruction]:
self.instruction_pointer = None
@@ -2494,38 +2492,35 @@ class InstructionTranslatorBase(

# current frame state
# [
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# frame N locals,
# frame N-1 stack + locals,
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# ], `push` values from running the unsupported instruction
# frame 1 stack + locals
# ], frame N stack (post-instruction)

# move the `push` stack values to the frame N stack
# move frame N stack to the frame values list
current_num_stack = len(self.stack) - len(
all_stack_locals_metadata[0].stack_null_idxes
)
all_stack_locals_metadata[0].num_stack = current_num_stack
cg.extend_output(
[
create_instruction("BUILD_LIST", arg=push),
# frames_list, push_values_list
*create_swap(2),
create_dup_top(),
create_instruction("BUILD_LIST", arg=current_num_stack),
*create_copy(2),
# frame_values, frame N stack, frame_values
cg.create_load_const(0),
cg.create_binary_subscr(),
cg.create_load_const(0),
cg.create_binary_subscr(),
# push_values_list, frames_list, frames_list[0][0]
*create_swap(3),
# frames_list[0][0] += push_values_list
create_instruction("LIST_EXTEND", arg=2),
*create_swap(2),
# frames_list, frames_list[0][0]
create_instruction("POP_TOP"),
*create_binary_slice(0, 0, True),
# frame_values[0][0:0] = frame N stack
# frame_values left on top of stack
]
)

# current frame state
# [
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
# [frame N stack (fixed) + locals]
# ...,
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# [frame 1 stack + locals]
# ],

#

@@ -2541,12 +2536,11 @@ class InstructionTranslatorBase(
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
# NOTE: if the unsupported instruction modifies the inactive context variable, it may
# result in silent incorrectness!
argnames: tuple[str, ...] = ()
for i, meta in enumerate(all_stack_locals_metadata):
for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
# Replace the stack var with the context class
ctx = cast(ContextWrappingVariable, txes[i].stack[j_orig])
# frames[i][0][j] = reconstructed_ctx
# frames[i][j] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(

@@ -2554,8 +2548,6 @@ class InstructionTranslatorBase(
*create_swap(2),
cg.create_load_const(i),
cg.create_binary_subscr(),
cg.create_load_const(0),
cg.create_binary_subscr(),
cg.create_load_const(j),
create_instruction("STORE_SUBSCR"),
]

@@ -2564,7 +2556,7 @@ class InstructionTranslatorBase(
for name, _ in meta.locals_ctx_args:
# Replace the local with the context class
ctx = cast(ContextWrappingVariable, txes[i].symbolic_locals[name])
# frames[i][1][meta.locals_names[name]] = reconstructed_ctx
# frames[i][meta.num_stack +meta.locals_names[name]] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(

@@ -2572,9 +2564,7 @@ class InstructionTranslatorBase(
*create_swap(2),
cg.create_load_const(i),
cg.create_binary_subscr(),
cg.create_load_const(1),
cg.create_binary_subscr(),
cg.create_load_const(meta.locals_names[name]),
cg.create_load_const(meta.num_stack + meta.locals_names[name]),
create_instruction("STORE_SUBSCR"),
]
)
@@ -2595,21 +2585,65 @@ class InstructionTranslatorBase(
if is_jump_absolute(resume_inst):
assert resume_inst.target
resume_inst = resume_inst.target
name = unique_id(f"__resume_at_{resume_inst.offset}")
resume_names.append(name)
resume_name = unique_id(f"__resume_at_{resume_inst.offset}")
resume_names.append(resume_name)

# More locals may have been pruned in the current frame
# after the unsupported instruction (e.g. branch).
# There should not be any pruning in the other frames since
# the current instruction is a CALL.
if cur_tx is self:
reads = livevars_analysis(cur_tx.instructions, resume_inst)
all_argnames = tuple(
k
for k in cur_tx.symbolic_locals.keys()
if k in reads and k not in cur_tx.cell_and_freevars()
)
argnames_null_set = set(meta.locals_null_keys)
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)

# codegen filter for current frame's locals
# current stack state: frames
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(i),
cg.create_binary_subscr(),
create_dup_top(),
]
)
for arg in argnames:
# current stack state: frames, frames[i], *(prev locals), frames[i]
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(
meta.num_stack + meta.locals_names[arg]
),
cg.create_binary_subscr(),
*create_swap(2),
],
)
# current stack state: frames, frames[i], *(frame i live locals), frames[i]
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(argnames)),
*create_swap(2),
# frames, frames i live locals, frames[i]
*create_binary_slice(meta.num_stack, None, True),
# frames[i][num_stack:] = frame i live locals
]
)
# current stack state: frames
else:
argnames = tuple(meta.locals_names.keys())
argnames_null = tuple(meta.locals_null_keys)

# more locals may have been pruned after the unsupported instruction (e.g. branch)
reads = livevars_analysis(cur_tx.instructions, resume_inst)
all_argnames = tuple(
k
for k in cur_tx.symbolic_locals.keys()
if k in reads and k not in cur_tx.cell_and_freevars()
)
argnames_null_set = set(meta.locals_null_keys)
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"

# compile_subgraph did not codegen any NULLs,
# so we should not count NullVariables
stack_len = len(cur_tx.stack) - len(meta.stack_null_idxes)
@@ -2643,14 +2677,15 @@ class InstructionTranslatorBase(
# add resume function to the global scope
if new_code.co_freevars:
# expose code object for debugging purposes
cur_tx.output.install_global_unsafe(name, new_code)
cur_tx.output.install_global_unsafe(resume_name, new_code)
package_name = None
else:
# This is safe: we pre-generate a unique name
cur_tx.output.install_global_unsafe(
name, types.FunctionType(new_code, cur_tx.f_globals, name)
resume_name,
types.FunctionType(new_code, cur_tx.f_globals, resume_name),
)
package_name = name
package_name = resume_name

if cur_tx.package is not None:
cur_tx.package.add_resume_function(
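A side note on the types.FunctionType call in the hunk above: it is the standard CPython way to wrap a code object into a callable, which is how the resume function is materialized when it has no free variables (the co_freevars branch installs only the code object). A tiny, self-contained illustration; the names template, new_code, and resume_name are made up for this sketch and are not Dynamo's:

import types

def template(x):
    return x + 1

new_code = template.__code__  # stands in for the generated resume code object
resume_name = "__resume_at_0"

# same pattern as types.FunctionType(new_code, cur_tx.f_globals, resume_name)
resume_fn = types.FunctionType(new_code, globals(), resume_name)
print(resume_fn(41))       # 42
print(resume_fn.__name__)  # __resume_at_0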
@@ -2687,10 +2722,10 @@ class InstructionTranslatorBase(
# [
# [resume N, ..., resume 2],
# [
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
# frame N stack + locals,
# ...,
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
# ], *(frame 1 stack + frame 1 non-cell locals)
# frame 2 stack + locals,
# ], *(frame 1 stack + locals)
# ]
cg.extend_output(
[

@@ -2704,48 +2739,21 @@ class InstructionTranslatorBase(
# frames, frames[-1], frames
cg.create_load_const(-1),
create_instruction("DELETE_SUBSCR"),
# del frames[-1]; stack: frames, frames[-1]
create_dup_top(),
cg.create_load_const(0),
cg.create_binary_subscr(),
# frames, frames[-1], frames[-1][0]
*create_swap(2),
cg.create_load_const(1),
cg.create_binary_subscr(),
]
)

# resumes, frames, frames[-1][0], frames[-1][1]
for name in argnames:
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(
all_stack_locals_metadata[-1].locals_names[name]
),
cg.create_binary_subscr(),
*create_swap(2),
],
)
# resumes, frames, frames[-1][0], *(live locals), frames[-1][1]
# TOS: resumes, frames (popped), frame 1 stack + locals
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(argnames)),
*create_swap(4),
# live_locals, frames, frames[-1][0], resumes
create_instruction("BUILD_LIST", arg=1),
*create_swap(3),
# live_locals, [resumes], frames[-1][0], frames
create_instruction("LIST_APPEND", arg=2),
create_instruction("LIST_EXTEND", arg=1),
# live_locals, [resumes, frames, *stack]
*create_rot_n(3),
create_instruction("BUILD_LIST", arg=2),
*create_swap(2),
# [resumes, frames (popped)], frame 1 stack + locals
create_instruction("LIST_EXTEND", arg=1),
]
)
# [resumes, frames, *(stack + live locals)]

# TOS: [resumes, frames, *(frame 1 stack + locals)]
cg.extend_output(
[
create_instruction("CALL_FUNCTION_EX", arg=0),
@@ -4391,10 +4399,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
return False # inlining functions is all-or-nothing

def create_call_resume_at(
self, inst: Instruction, push: int, all_stack_locals_metadata: Any
self, inst: Instruction, all_stack_locals_metadata: Any
) -> list[Instruction]:
if config.nested_graph_breaks:
return super().create_call_resume_at(inst, push, all_stack_locals_metadata)
return super().create_call_resume_at(inst, all_stack_locals_metadata)
unimplemented_v2(
gb_type="Graph break in inlined function",
context="",