diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 4172e1cb980..78723c975e3 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -67,7 +67,7 @@ from torch.multiprocessing.reductions import StorageWeakRef
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
-from . import config, exc, logging as torchdynamo_logging, variables
+from . import config, exc, graph_break_hints, logging as torchdynamo_logging, variables
 from .backends.registry import CompiledFn, CompilerFn
 from .bytecode_transformation import (
     create_call_function,
@@ -1253,6 +1253,18 @@ class OutputGraph(OutputGraphGuardsState):
 
         log.debug("COMPILING GRAPH due to %s", reason)
 
+        if not all(block.can_restore() for block in tx.block_stack):
+            unimplemented_v2(
+                gb_type="Attempt to compile graph with unrecoverable block in the block stack",
+                context="",
+                explanation="Dynamo does not support graph breaking on context managers in "
+                "nested function calls. For Python <= 3.10, this graph break may have instead been "
+                "caused by attempting to graph break in a try block.",
+                hints=[
+                    *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
+                ],
+            )
+
         # prefix instructions (Python 3.11+)
         prefix_insts: list[Instruction] = []
         if sys.version_info >= (3, 11):
@@ -1295,8 +1307,6 @@ class OutputGraph(OutputGraphGuardsState):
         cur_tx: Optional[InstructionTranslatorBase] = tx
         while True:
             assert cur_tx is not None
-            # this should have been checked by the caller
-            assert all(block.can_restore() for block in cur_tx.block_stack)
            stack_values, restore_vars, meta = self._get_stack_values_to_restore(
                cur_tx, stack_pops
            )
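
For context, here is a minimal repro sketch (not taken from the patch or its tests) of the kind of user code the new check targets: a graph break raised while a generic context manager is live in a nested frame. Assuming the context manager cannot be reconstructed across the break, compiling such code should now surface the graph break message added above instead of tripping the removed assertion.

```python
# Hypothetical repro sketch -- not part of the PR or its test suite.
import contextlib

import torch


@contextlib.contextmanager
def my_ctx():
    # A generic context manager that Dynamo cannot restore across a graph break.
    yield


def inner(x):
    with my_ctx():
        torch._dynamo.graph_break()  # force a graph break while the block is active
        return x + 1


@torch.compile(backend="eager")
def outer(x):
    # The context manager lives in a nested call (inner), which is the case the
    # new unimplemented_v2 explanation describes.
    return inner(x)


outer(torch.ones(3))
```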