[dynamo] refactor resume_execution.py to use bytecode templates (#136483)

Use bytecode from template instead of hardcoding bytecode in resume_execution.py. Gets rid of a lot of Python-version dependent bytecode generation. Also makes resume_execution.py easier to support in future Python version updates.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136483
Approved by: https://github.com/jansel, https://github.com/anijain2305
This commit is contained in:
William Wen 2024-09-24 17:08:37 +00:00 committed by PyTorch MergeBot
parent 36f0e61166
commit ae80bce496
3 changed files with 149 additions and 325 deletions

View File

@ -933,6 +933,32 @@ def strip_extended_args(instructions: List[Instruction]) -> None:
instructions[:] = [i for i in instructions if i.opcode != dis.EXTENDED_ARG]
# Overwrites old_inst with a sequence of new instructions.
# This is necessary in order to preserve jump targets to the old
# instruction, exception table entries, and positions.
# Returns the modified sequence of instructions (including the modified
# old instruction!) that can be manipulated elsewhere.
def overwrite_instruction(old_inst, new_insts):
# update old_inst.exnt_tab_entry.end if necessary
if (
old_inst.exn_tab_entry
and old_inst.exn_tab_entry.end is old_inst
and len(new_insts) > 1
):
old_inst.exn_tab_entry.end = new_insts[-1]
# preserve exception table entries and positions
for inst in new_insts[1:]:
inst.exn_tab_entry = copy.copy(old_inst.exn_tab_entry)
inst.positions = old_inst.positions
# modify old_inst in-place to preserve jump target
old_inst.opcode = new_insts[0].opcode
old_inst.opname = new_insts[0].opname
old_inst.arg = new_insts[0].arg
old_inst.argval = new_insts[0].argval
old_inst.target = new_insts[0].target
return [old_inst] + new_insts[1:]
def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction]:
"""LOAD_METHOD puts a NULL on the stack which causes issues, so remove it"""
assert sys.version_info < (3, 11)
@ -947,11 +973,11 @@ def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction
def remove_jump_if_none(instructions: List[Instruction]) -> None:
new_insts = []
for inst in instructions:
new_insts.append(inst)
if "_NONE" in inst.opname:
is_op = create_instruction("IS_OP", arg=int("NOT" in inst.opname))
# need both argval and arg set correctly now (not later)
is_op.argval = is_op.arg
is_op.positions = inst.positions
if sys.version_info < (3, 12):
jump_op = create_instruction(
"POP_JUMP_FORWARD_IF_TRUE"
@ -961,19 +987,15 @@ def remove_jump_if_none(instructions: List[Instruction]) -> None:
)
else:
jump_op = create_instruction("POP_JUMP_IF_TRUE", target=inst.target)
jump_op.positions = inst.positions
# update inst.exn_tab_entry.end if necessary
if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
inst.exn_tab_entry.end = jump_op
# preserve exception table entries
is_op.exn_tab_entry = copy.copy(inst.exn_tab_entry)
jump_op.exn_tab_entry = copy.copy(inst.exn_tab_entry)
# modify inst in-place to preserve jump target
inst.opcode = dis.opmap["LOAD_CONST"]
inst.opname = "LOAD_CONST"
inst.arg = None
inst.argval = None
new_insts.extend([is_op, jump_op])
replace_insts = [
create_instruction("LOAD_CONST", argval=None),
is_op,
jump_op,
]
new_insts.extend(overwrite_instruction(inst, replace_insts))
else:
new_insts.append(inst)
instructions[:] = new_insts
@ -1007,24 +1029,17 @@ FUSED_INSTS = {
def remove_fused_load_store(instructions: List[Instruction]) -> None:
new_insts = []
for inst in instructions:
new_insts.append(inst)
if inst.opname in FUSED_INSTS:
inst0, inst1 = FUSED_INSTS[inst.opname]
argval0, argval1 = inst.argval
# modify inst in-place to preserve jump target
inst.opcode = dis.opmap[inst0]
inst.opname = inst0
inst.argval = argval0
new_inst = create_instruction(inst1, argval=argval1)
# update inst.exn_tab_entry.end if necessary
if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
inst.exn_tab_entry.end = new_inst
# preserve exception table entries
new_inst.exn_tab_entry = copy.copy(inst.exn_tab_entry)
new_insts.append(new_inst)
replace_insts = [
create_instruction(inst0, argval=argval0),
create_instruction(inst1, argval=argval1),
]
new_insts.append(overwrite_instruction(inst, replace_insts))
else:
new_insts.append(inst)
instructions[:] = new_insts
@ -1435,7 +1450,9 @@ def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True):
For example, local variables in `fn` can be replaced with
new names that are generated by `OutputGraph.new_var`.
noreturn: remove all RETURN_* bytecodes and replace them with a jump
to the end of the bytecode.
to the end of the bytecode. NOTE: any items pushed to the stack
for return WILL remain on the stack! Append a POP_TOP if you don't want
that item to be present.
noprefix: remove prefix bytecodes (all bytecode before the first RESUME, inclusive).
"""
insts = cleaned_instructions(fn.__code__)

View File

@ -6,15 +6,12 @@ import types
from typing import Any, cast, Dict, List, Optional, Tuple
from .bytecode_transformation import (
add_push_null,
bytecode_from_template,
create_call_function,
create_call_method,
create_dup_top,
create_instruction,
create_jump_absolute,
create_load_method,
Instruction,
InstructionExnTabEntry,
overwrite_instruction,
transform_code_object,
unique_id,
)
@ -44,6 +41,50 @@ def _initial_push_null(insts):
insts.append(create_instruction("SWAP", arg=2))
# Generates bytecode from template and splits the code where LOAD_FAST dummy is present.
def _bytecode_from_template_with_split(template, stack_index, varname_map=None):
template_code = bytecode_from_template(template, varname_map=varname_map)
template_code.append(create_instruction("POP_TOP"))
# adjust exception table entry depth
for inst in template_code:
if inst.exn_tab_entry:
inst.exn_tab_entry.depth += stack_index
# search for LOAD_FAST dummy and replace it with 2 NOPs (we can break up the bytecode between them)
dummy_idx, dummy_inst = next(
(
(i, inst)
for i, inst in enumerate(template_code)
if inst.opname == "LOAD_FAST" and inst.argval == "dummy"
),
(None, None),
)
assert dummy_idx is not None
# replace LOAD_FAST dummy with first NOP marking exception area
overwrite_instruction(dummy_inst, [create_instruction("NOP")])
# POP_TOP follows LOAD_FAST dummy - replace with NOP marking end of exception area
assert template_code[dummy_idx + 1].opname == "POP_TOP"
overwrite_instruction(template_code[dummy_idx + 1], [create_instruction("NOP")])
return template_code[: dummy_idx + 1], template_code[dummy_idx + 1 :]
def _try_except_tf_mode_template(dummy, stack_var_name):
# NOTE: Make sure this name matches what is generated by symbolic_convert:import_source
# on torch._dynamo.utils.
global __import_torch_dot__dynamo_dot_utils
try:
dummy
except: # noqa: E722, B001
__import_torch_dot__dynamo_dot_utils.set_torch_function_mode_stack( # type: ignore[name-defined]
stack_var_name
)
raise
@dataclasses.dataclass(frozen=True)
class ReenterWith:
stack_index: int
@ -55,106 +96,24 @@ class ReenterWith:
try:
(rest)
except:
(restore previous stack)
(restore previous tf mode stack)
raise
"""
from .variables.torch_function import get_prev_stack_var_name
except_jump_target = create_instruction(
"NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO"
setup_try_except, epilogue = _bytecode_from_template_with_split(
_try_except_tf_mode_template,
self.stack_index,
varname_map={"stack_var_name": get_prev_stack_var_name()},
)
cleanup_complete_jump_target = create_instruction("NOP")
setup_finally: List[Instruction] = []
if sys.version_info < (3, 11):
setup_finally.append(
create_instruction("SETUP_FINALLY", target=except_jump_target)
)
else:
exn_tab_begin = create_instruction("NOP")
exn_tab_end = create_instruction("NOP")
exn_tab_begin.exn_tab_entry = InstructionExnTabEntry(
exn_tab_begin,
exn_tab_end,
except_jump_target,
self.stack_index + 1,
False,
)
setup_finally.append(exn_tab_begin)
def create_reset():
insts = [
create_instruction(
"LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils"
),
create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"),
]
add_push_null(insts)
return [
*insts,
create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()),
*create_call_function(1, False),
create_instruction("POP_TOP"),
]
if sys.version_info < (3, 9):
epilogue = [
create_instruction("POP_BLOCK"),
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
except_jump_target,
*create_reset(),
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
*create_reset(),
create_instruction("RAISE_VARARGS", argval=0),
create_instruction("POP_EXCEPT", argval=0),
create_instruction("END_FINALLY"),
cleanup_complete_jump_target,
]
elif sys.version_info < (3, 11):
epilogue = [
create_instruction("POP_BLOCK"),
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
except_jump_target,
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
*create_reset(),
create_instruction("RAISE_VARARGS", argval=0),
create_instruction("POP_EXCEPT", argval=0),
cleanup_complete_jump_target,
]
else:
finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0)
finally_exn_tab_target = create_instruction("COPY", arg=3)
except_jump_target.exn_tab_entry = InstructionExnTabEntry(
except_jump_target,
finally_exn_tab_end,
finally_exn_tab_target,
self.stack_index + 2,
True,
)
epilogue = [
exn_tab_end,
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
except_jump_target, # PUSH_EXC_INFO
create_instruction("POP_TOP"),
*create_reset(),
finally_exn_tab_end,
finally_exn_tab_target, # COPY 3
create_instruction("POP_EXCEPT"),
create_instruction("RERAISE", arg=1), # RERAISE 1
cleanup_complete_jump_target,
]
cleanup[:] = epilogue + cleanup
return setup_finally
return setup_try_except
# If we do not want to destroy the stack, we can do the same thing as a
# `SETUP_WITH` block, only that we store the context manager in a local_symbol
def try_except(self, code_options, cleanup: List[Instruction]):
def try_finally(self, code_options, cleanup: List[Instruction]):
"""
Codegen based off of:
load args
@ -178,97 +137,28 @@ class ReenterWith:
if name not in code_options["co_names"]:
code_options["co_names"] += (name,)
except_jump_target = create_instruction(
"NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO"
)
cleanup_complete_jump_target = create_instruction("NOP")
setup_finally: List[Instruction] = []
_initial_push_null(setup_finally)
# TODO(williamwen42) call method order is wrong for 3.13+ - will fix later
setup_finally.extend(
create_ctx: List[Instruction] = []
_initial_push_null(create_ctx)
create_ctx.extend(
[
*load_args,
*create_call_function(len(load_args), False),
create_instruction("STORE_FAST", argval=ctx_name),
create_instruction("LOAD_FAST", argval=ctx_name),
create_load_method("__enter__"),
*create_call_method(0),
create_instruction("POP_TOP"),
]
)
if sys.version_info < (3, 11):
setup_finally.append(
create_instruction("SETUP_FINALLY", target=except_jump_target)
)
else:
exn_tab_begin = create_instruction("NOP")
exn_tab_end = create_instruction("NOP")
exn_tab_begin.exn_tab_entry = InstructionExnTabEntry(
exn_tab_begin,
exn_tab_end,
except_jump_target,
self.stack_index + 1,
False,
)
setup_finally.append(exn_tab_begin)
def create_reset():
return [
create_instruction("LOAD_FAST", argval=ctx_name),
create_load_method("__exit__"),
create_instruction("LOAD_CONST", argval=None),
create_dup_top(),
create_dup_top(),
*create_call_method(3),
create_instruction("POP_TOP"),
]
if sys.version_info < (3, 9):
epilogue = [
create_instruction("POP_BLOCK"),
create_instruction("BEGIN_FINALLY"),
except_jump_target,
*create_reset(),
create_instruction("END_FINALLY"),
]
elif sys.version_info < (3, 11):
epilogue = [
create_instruction("POP_BLOCK"),
*create_reset(),
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
except_jump_target,
*create_reset(),
create_instruction("RERAISE"),
cleanup_complete_jump_target,
]
else:
finally_exn_tab_end = create_instruction("RERAISE", arg=0)
finally_exn_tab_target = create_instruction("COPY", arg=3)
except_jump_target.exn_tab_entry = InstructionExnTabEntry(
except_jump_target,
finally_exn_tab_end,
finally_exn_tab_target,
self.stack_index + 2,
True,
)
epilogue = [
exn_tab_end,
*create_reset(),
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
except_jump_target, # PUSH_EXC_INFO
*create_reset(),
finally_exn_tab_end, # RERAISE 0
finally_exn_tab_target, # COPY 3
create_instruction("POP_EXCEPT"),
create_instruction("RERAISE", arg=1),
cleanup_complete_jump_target,
]
def _template(ctx, dummy):
ctx.__enter__()
try:
dummy
finally:
ctx.__exit__(None, None, None)
setup_try_finally, epilogue = _bytecode_from_template_with_split(
_template, self.stack_index, varname_map={"ctx": ctx_name}
)
cleanup[:] = epilogue + cleanup
return setup_finally
return create_ctx + setup_try_finally
def __call__(self, code_options, cleanup):
"""
@ -283,129 +173,46 @@ class ReenterWith:
create_instruction("LOAD_CONST", argval=val)
for val in self.target_values
]
if sys.version_info < (3, 9):
with_cleanup_start = create_instruction("WITH_CLEANUP_START")
begin_finally = create_instruction("BEGIN_FINALLY")
cleanup[:] = [
create_instruction("POP_BLOCK"),
begin_finally,
with_cleanup_start,
create_instruction("WITH_CLEANUP_FINISH"),
create_instruction("END_FINALLY"),
] + cleanup
return [
create_ctx: List[Instruction] = []
_initial_push_null(create_ctx)
create_ctx.extend(
[
*load_args,
create_instruction("CALL_FUNCTION", arg=len(load_args)),
create_instruction("SETUP_WITH", target=with_cleanup_start),
create_instruction("POP_TOP"),
], None
elif sys.version_info < (3, 11):
with_except_start = create_instruction("WITH_EXCEPT_START")
pop_top_after_with_except_start = create_instruction("POP_TOP")
*create_call_function(len(load_args), False),
]
)
cleanup_complete_jump_target = create_instruction("NOP")
def _template(ctx, dummy):
with ctx:
dummy
cleanup[:] = [
create_instruction("POP_BLOCK"),
create_instruction("LOAD_CONST", argval=None),
create_instruction("DUP_TOP"),
create_instruction("DUP_TOP"),
create_instruction("CALL_FUNCTION", arg=3),
create_instruction("POP_TOP"),
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
with_except_start,
create_instruction(
"POP_JUMP_IF_TRUE", target=pop_top_after_with_except_start
),
create_instruction("RERAISE"),
pop_top_after_with_except_start,
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
create_instruction("POP_EXCEPT"),
create_instruction("POP_TOP"),
cleanup_complete_jump_target,
] + cleanup
setup_with, epilogue = _bytecode_from_template_with_split(
_template, self.stack_index
)
cleanup[:] = epilogue + cleanup
return [
*load_args,
create_instruction("CALL_FUNCTION", arg=len(load_args)),
create_instruction("SETUP_WITH", target=with_except_start),
create_instruction("POP_TOP"),
], None
else:
pop_top_after_with_except_start = create_instruction("POP_TOP")
cleanup_complete_jump_target = create_instruction("NOP")
load_fast_ctx_inst = next(
(
inst
for inst in setup_with
if inst.opname == "LOAD_FAST" and inst.argval == "ctx"
),
None,
)
assert load_fast_ctx_inst is not None
# ctx already loaded on stack before the template - no need to LOAD_FAST
overwrite_instruction(load_fast_ctx_inst, [create_instruction("NOP")])
def create_load_none():
return create_instruction("LOAD_CONST", argval=None)
# 3.11+ only
push_exc_info_gen = (
inst for inst in epilogue if inst.opname == "PUSH_EXC_INFO"
)
push_exc_info_inst = next(push_exc_info_gen, None)
# expect only 1 PUSH_EXC_INFO in epilogue
assert next(push_exc_info_gen, None) is None
exn_tab_1_begin = create_instruction("POP_TOP")
exn_tab_1_end = create_instruction("NOP")
exn_tab_1_target = create_instruction("PUSH_EXC_INFO")
exn_tab_2_end = create_instruction("RERAISE", arg=2)
exn_tab_2_target = create_instruction("COPY", arg=3)
exn_tab_1_begin.exn_tab_entry = InstructionExnTabEntry(
exn_tab_1_begin,
exn_tab_1_end,
exn_tab_1_target,
self.stack_index + 1,
True,
)
exn_tab_1_target.exn_tab_entry = InstructionExnTabEntry(
exn_tab_1_target,
exn_tab_2_end,
exn_tab_2_target,
self.stack_index + 3,
True,
)
pop_top_after_with_except_start.exn_tab_entry = InstructionExnTabEntry(
pop_top_after_with_except_start,
pop_top_after_with_except_start,
exn_tab_2_target,
self.stack_index + 3,
True,
)
cleanup[:] = [
exn_tab_1_end,
create_load_none(),
create_load_none(),
create_load_none(),
*create_call_function(2, False),
create_instruction("POP_TOP"),
create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target),
exn_tab_1_target, # PUSH_EXC_INFO
create_instruction("WITH_EXCEPT_START"),
create_instruction(
"POP_JUMP_FORWARD_IF_TRUE"
if sys.version_info < (3, 12)
else "POP_JUMP_IF_TRUE",
target=pop_top_after_with_except_start,
),
exn_tab_2_end, # RERAISE 2
exn_tab_2_target, # COPY 3
create_instruction("POP_EXCEPT"),
create_instruction("RERAISE", arg=1),
pop_top_after_with_except_start,
create_instruction("POP_EXCEPT"),
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
cleanup_complete_jump_target,
] + cleanup
ret: List[Instruction] = []
_initial_push_null(ret)
ret.extend(
[
*load_args,
*create_call_function(len(load_args), False),
create_instruction("BEFORE_WITH"),
exn_tab_1_begin, # POP_TOP
]
)
return ret, exn_tab_1_target
return create_ctx + setup_with, push_exc_info_inst
@dataclasses.dataclass

View File

@ -650,7 +650,7 @@ def break_graph_if_unsupported(*, push):
assert b.with_context is not None
assert isinstance(b.with_context, (ContextWrappingVariable))
b.with_context.reconstruct_type(cg)
cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup))
self.output.add_output_instructions(cg.get_instructions())
del cg