diff --git a/test/dynamo/test_bytecode_utils.py b/test/dynamo/test_bytecode_utils.py index 3bbf7270b06..0e813c88378 100644 --- a/test/dynamo/test_bytecode_utils.py +++ b/test/dynamo/test_bytecode_utils.py @@ -8,7 +8,7 @@ import unittest import torch import torch._dynamo.test_case from torch._dynamo import bytecode_analysis, bytecode_transformation -from torch._dynamo.testing import skipIfNotPy311 +from torch._dynamo.testing import skipIfNotPy311, skipIfNotPy312 class BytecodeTests(torch._dynamo.test_case.TestCase): @@ -414,6 +414,119 @@ def fn(): self.assertEqual(tab[0].end, 4) self.assertEqual(tab[0].target, 6) + def test_bytecode_from_template(self): + def fn(d1): + for k, v in d1.items(): + d2[k] = v + + varname_map = {"d1": "var1", "d2": "var2", "k": "var3", "v": "var4"} + insts = bytecode_transformation.bytecode_from_template(fn, varname_map) + for inst in insts: + self.assertIsNone(inst.starts_line) + if inst.opname.startswith("LOAD"): + self.assertNotIn(inst.argval, varname_map) + if inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR"): + self.assertIsNone(inst.arg) + self.assertFalse(inst.opname.startswith("RETURN")) + + @skipIfNotPy311 + def test_bytecode_from_template_noprefix(self): + # Test that 3.11+ prefix instructions are removed + def gen_fn(): + cl = None + + def fn(): + return cl + + return fn + + fn = gen_fn() + + dis_insts = list(dis.get_instructions(fn)) + names = {inst.opname for inst in dis_insts} + self.assertIn("RESUME", names) + self.assertIn("COPY_FREE_VARS", names) + + insts = bytecode_transformation.bytecode_from_template(fn) + names = {inst.opname for inst in insts} + self.assertNotIn("RESUME", names) + self.assertNotIn("COPY_FREE_VARS", names) + + def test_bytecode_from_template_noreturn1(self): + # Test that functions with multiple returns will have their + # returns replaced with jumps to the end + def fn(): + if x: + return y + z = 3 + return z + + dis_insts = list(dis.get_instructions(fn)) + dis_returns = list(filter(lambda x: x.opname.startswith("RETURN"), dis_insts)) + self.assertGreater(len(dis_returns), 1) + self.assertTrue(dis_insts[-1].opname.startswith("RETURN")) + + insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) + self.assertEqual(insts[-1].opname, "NOP") + self.assertEqual(len(dis_insts), len(insts)) + for i0, i1 in zip(dis_insts, insts): + if i0.opname.startswith("RETURN"): + if i1 is insts[-1]: + continue + self.assertIn("JUMP", i1.opname) + self.assertIs(i1.target, insts[-1]) + + # Should work with 3.10, but testing with 3.11+ is sufficient. + # In 3.8, `fn` ends with a RETURN_VALUE. + @skipIfNotPy311 + def test_bytecode_from_template_noreturn2(self): + # Test function that doesn't end with RETURN_VALUE + def fn(): + if x: + return x + if x: + return x + raise RuntimeError + + dis_insts = list(dis.get_instructions(fn)) + self.assertFalse(dis_insts[-1].opname.startswith("RETURN")) + + insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) + self.assertEqual(insts[-1].opname, "NOP") + self.assertEqual(insts[-2].opname, dis_insts[-1].opname) + self.assertEqual(len(dis_insts) + 1, len(insts)) + for i0, i1 in zip(dis_insts, insts): + if i0.opname.startswith("RETURN"): + self.assertIn("JUMP", i1.opname) + self.assertIs(i1.target, insts[-1]) + + @skipIfNotPy312 + def test_bytecode_from_template_noreturn_const(self): + # Test 3.12+ RETURN_CONST + def fn(): + if x: + return 1 + return 0 + + dis_insts = list(dis.get_instructions(fn)) + dis_return_consts = list( + filter(lambda x: x.opname == "RETURN_CONST", dis_insts) + ) + self.assertGreater(len(dis_return_consts), 1) + self.assertTrue(dis_insts[-1].opname == "RETURN_CONST") + + insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) + self.assertEqual(insts[-1].opname, "NOP") + insts_i = 0 + for i, inst in enumerate(dis_insts): + if inst.opname == "RETURN_CONST": + self.assertEqual(insts[insts_i].opname, "LOAD_CONST") + insts_i += 1 + if insts_i != len(insts) - 1: + self.assertIn("JUMP", insts[insts_i].opname) + self.assertIs(insts[insts_i].target, insts[-1]) + insts_i += 1 + class BytecodeHookTests(torch._dynamo.test_case.TestCase): def test_bytecode_hook(self): diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py index dec673b0e91..f07fe1c7a0e 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -1117,6 +1117,23 @@ def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=N instructions[i].arg = idx +def clear_instruction_args(instructions): + # Clear the instruction arg for instructions that have argvals. + # Useful for using dis'd bytecode within generated bytecode. + for inst in instructions: + if ( + inst.argval is not _NotProvided + and ( + inst.opcode in HAS_LOCAL + or inst.opcode in HAS_NAME + or inst.opcode in HAS_FREE + or inst.opcode in HAS_CONST + ) + and inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_SUPER_ATTR") + ): + inst.arg = None + + def get_code_keys() -> List[str]: # Python 3.11 changes to code keys are not fully documented. # See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24 @@ -1247,3 +1264,100 @@ def unique_id(name) -> str: def is_generator(code: types.CodeType) -> bool: co_generator = 0x20 return (code.co_flags & co_generator) > 0 + + +def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True): + """Generates bytecode from a template function `fn` for use in + dynamo bytecode generation. + + For example, we can generate Python-version-independent bytecode + for looping through a dictionary and copying the values to a new dictionary. + + def template(d1, d2): + for k, v in d1.items(): + d2[k] = v + + + or a try block: + + def template(): + try: + dummy1 + except: + dummy2 + raise + dummy3 + + Args: + fn: a function template to generate bytecode from + varname_map: a mapping of `fn`'s varnames to new names. This + map will be applied to the generated bytecode's varnames. + For example, local variables in `fn` can be replaced with + new names that are generated by `OutputGraph.new_var`. + noreturn: remove all RETURN_* bytecodes and replace them with a jump + to the end of the bytecode. + noprefix: remove prefix bytecodes (all bytecode before the first RESUME, inclusive). + """ + insts = cleaned_instructions(fn.__code__) + clear_instruction_args(insts) + + if noprefix: + for i, inst in enumerate(insts): + if inst.opname == "RESUME": + insts = insts[i + 1 :] + break + + for inst in insts: + # If we don't reset starts_line, then the generated + # bytecode's line number will be based on fn's. + inst.starts_line = None + if varname_map and inst.argval in varname_map: + inst.argval = varname_map[inst.argval] + + if noreturn: + if sys.version_info >= (3, 12): + # replace RETURN_CONST with LOAD_CONST RETURN_VALUE + new_insts = [] + for inst in insts: + if inst.opname == "RETURN_CONST": + inst.opcode = dis.opmap["LOAD_CONST"] + inst.opname = "LOAD_CONST" + new_insts.append(inst) + # no need to propagate target/exn table + new_insts.append(create_instruction("RETURN_VALUE")) + else: + new_insts.append(inst) + insts = new_insts + + returns = [] + for inst in insts: + if inst.opname == "RETURN_VALUE": + returns.append(inst) + + if len(returns) == 1 and returns[0] is insts[-1]: + # only 1 return at the end - just pop it + insts.pop(-1) + elif len(returns) > 0: + # create jump target - if the last inst is a return, + # we can replace it with a NOP and make that the jump target. + if insts[-1] is returns[-1]: + insts[-1].opname = "NOP" + insts[-1].opcode = dis.opmap["NOP"] + insts[-1].arg = None + insts[-1].argval = _NotProvided + returns.pop(-1) + else: + insts.append(create_instruction("NOP")) + + # replace returns with jumps + for inst in returns: + # don't replace inst with new instruction + # due to targetting/exn table/etc. + jump_inst = create_jump_absolute(insts[-1]) + inst.opname = jump_inst.opname + inst.opcode = jump_inst.opcode + inst.arg = jump_inst.arg + inst.argval = jump_inst.argval + inst.target = jump_inst.target + + return insts diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 9e9abe84228..99b6607afea 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -343,6 +343,12 @@ def skipIfNotPy311(fn): return unittest.skip(fn) +def skipIfNotPy312(fn): + if sys.version_info >= (3, 12): + return fn + return unittest.skip(fn) + + def xfailIfPy312(fn): if sys.version_info >= (3, 12): return unittest.expectedFailure(fn)