diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
index c82e75ac98a..801bb169254 100644
--- a/test/inductor/test_flex_decoding.py
+++ b/test/inductor/test_flex_decoding.py
@@ -1814,9 +1814,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             )
             # Ensure no more re-compilation after the second automatic dynamic shape version.
             if i == 0:
-                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
-            else:
                 self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
+            else:
+                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 4)
 
     @supported_platform
     @common_utils.parametrize("dtype", test_dtypes_fast)
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 68954fb32a0..6517d716846 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -64,6 +64,7 @@ from torch.testing._internal.common_utils import (
     skipIfCrossRef,
     skipIfRocm,
     skipIfTorchDynamo,
+    skipIfWindows,
     TemporaryFileName,
     TEST_WITH_TORCHDYNAMO,
     TestCase,
@@ -2226,6 +2227,9 @@ class FakeTensorDispatchCache(TestCase):
             lambda: torch.ops.aten.index(x, [None, idx_tensor1]),
         )
 
+    @skipIfWindows(
+        msg="weird bug - cache may not be cleared after https://github.com/pytorch/pytorch/pull/154283"
+    )
     @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
     def test_invoke_subgraph(self):
         """
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index 3b1c0d93109..b46251a8c7c 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -325,6 +325,11 @@ skip_torchrec = True
 
 # Don't apply most trace_rules.py rules
 dont_skip_tracing = False
 
+# If True, enforce fullgraph=True - raise errors on graph break
+# NOTE: do not set manually - this is modified internally by Dynamo.
+# Use the fullgraph option of torch.compile instead.
+error_on_graph_break = False
+
 # No longer used
 optimize_ddp_lazy_compile = False
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index a48e16325a7..71832c725b8 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -654,7 +654,7 @@ def convert_frame_assert(
     export_constraints: Optional[typing.Never] = None,
     package: Optional[CompilePackage] = None,
 ) -> ConvertFrameAssert:
-    """Fully convert a frame into an FX graph"""
+    """Fully convert a frame into an FX graph, raising an exception if we fail."""
     return ConvertFrameAssert(
         compiler_fn, one_graph, export, export_constraints, package
     )
@@ -862,8 +862,10 @@ def _compile(
             code.co_filename,
             code.co_firstlineno,
         )
-        if one_graph:
-            log.debug("No graph captured with one_graph=True")
+        if one_graph or config.error_on_graph_break:
+            log.debug(
+                "No graph captured with one_graph=True or torch._dynamo.config.error_on_graph_break=True"
+            )
         return ConvertFrameReturn()
 
     assert distributed_state is None or distributed_state.all_states is not None, (  # type: ignore[has-type]
@@ -1029,9 +1031,10 @@ def _compile(
             raise FailOnRecompileLimitHit(
                 f"{limit_type} reached, because fail_on_recompile_limit_hit = True this is a HARD failure"
             )
-        elif one_graph:
+        elif one_graph or config.error_on_graph_break:
             raise FailOnRecompileLimitHit(
-                f"{limit_type} reached with one_graph=True. Excessive recompilations can degrade "
+                f"{limit_type} reached with one_graph=True or torch._dynamo.config.error_on_graph_break=True. "
+                "Excessive recompilations can degrade "
                 "performance due to the compilation overhead of each recompilation. To monitor "
                 "recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider "
                 "increasing torch._dynamo.config.cache_size_limit to an appropriate value."
@@ -1245,6 +1248,7 @@ class ConvertFrame:
         self,
         compiler_fn: CompilerFn,
         hooks: Hooks,
+        error_on_graph_break: bool,
         package: Optional[CompilePackage] = None,
     ) -> None:
         self._torchdynamo_orig_callable = compiler_fn
@@ -1252,10 +1256,13 @@ class ConvertFrame:
             compiler_fn, one_graph=False, package=package
         )
         self._hooks = hooks
+        self._error_on_graph_break = error_on_graph_break
 
     @property
     def _clone_with_backend(self) -> Callable[[WrapBackendDebug], ConvertFrame]:
-        return lambda backend: convert_frame(backend, self._hooks)
+        return lambda backend: convert_frame(
+            backend, self._hooks, self._error_on_graph_break
+        )
 
     def __call__(
         self,
@@ -1267,13 +1274,17 @@ class ConvertFrame:
     ) -> ConvertFrameReturn:
         input_codes.add(frame.f_code)
         counters["frames"]["total"] += 1
+        prev_error_on_graph_break = config.error_on_graph_break
         try:
+            config.error_on_graph_break = self._error_on_graph_break
             result = self._inner_convert(
                 frame, cache_entry, hooks, frame_state, skip=skip + 1
            )
             counters["frames"]["ok"] += 1
             return result
         except Exception as e:
+            if config.error_on_graph_break:
+                raise
             # These two exception types are "soft" failure, in the sense that
             # we know this is due to something we didn't implement all the
             # way, scare the user less about it. That being said, if you
@@ -1349,15 +1360,24 @@ class ConvertFrame:
                         FrameAction.RUN_ONLY, FrameAction.RUN_ONLY
                     )
                 )
+        finally:
+            config.error_on_graph_break = prev_error_on_graph_break
 
         return ConvertFrameReturn()
 
 
 def convert_frame(
-    compiler_fn: CompilerFn, hooks: Hooks, package: Optional[CompilePackage] = None
+    compiler_fn: CompilerFn,
+    hooks: Hooks,
+    error_on_graph_break: bool,
+    package: Optional[CompilePackage] = None,
 ) -> ConvertFrame:
-    """Try to convert a frame into an FX graph, if error leave frame unmodified"""
-    return ConvertFrame(compiler_fn, hooks, package=package)
+    """Try to convert a frame into an FX graph, if error leave frame unmodified
+
+    If error_on_graph_break=True, graph breaks become errors (resulting in an unmodified frame).
+    If error_on_graph_break=False, we will attempt to generate optimized and resume functions.
+    """
+    return ConvertFrame(compiler_fn, hooks, error_on_graph_break, package=package)
 
 
 # TODO mlazos: add support for same args, or record them
@@ -1370,7 +1390,9 @@ def replay(filename: str) -> None:
         record = ExecutionRecord.load(in_file)
     record.globals = dict(itertools.chain(record.globals.items(), globals().items()))
 
+    prev_error_on_graph_break = config.error_on_graph_break
     try:
+        config.error_on_graph_break = False
         _compile(
             record.code,
             record.globals,
@@ -1390,6 +1412,7 @@ def replay(filename: str) -> None:
         )
     finally:
         config.replay_record_enabled = original_replay_val
+        config.error_on_graph_break = prev_error_on_graph_break
 
 
 def first_real_inst_idx(code: CodeType) -> int:
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 380291f9f5b..c0b602d7546 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -227,6 +227,7 @@ def _create_wrapped_callback(compiler_fn):
         convert_frame.convert_frame(  # type: ignore[arg-type]
             compiler_fn,
             hooks,
+            False,
         ),
         hooks,
     )
@@ -1080,15 +1081,6 @@ def _optimize(
     ):
         return _NullDecorator()
 
-    if nopython:
-        return optimize_assert(
-            backend,
-            dynamic=dynamic,
-            hooks=hooks,
-            rebuild_ctx=rebuild_ctx,
-            package=package,
-        )
-
     backend = get_compiler_fn(backend)
 
     # Find if backend has any extra context manager
@@ -1098,7 +1090,7 @@ def _optimize(
     # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
     # be used by eval_frame.c to insert a guard on the backend.
     return _optimize_catch_errors(
-        convert_frame.convert_frame(backend, hooks=hooks, package=package),
+        convert_frame.convert_frame(backend, hooks, nopython, package=package),
         hooks,
         backend_ctx_ctor,
         dynamic=dynamic,
@@ -2002,7 +1994,11 @@ def _optimize_assert(
     package=None,
 ):
     """
-    The same as `torch._dynamo.optimize(backend, nopython=True)`
+    The same as `torch._dynamo.optimize(backend, nopython=True)`,
+    but ignores the config.error_on_graph_break setting.
+
+    Used for export, since we must always error on graph breaks and ignore
+    config.error_on_graph_break.
     """
 
     backend = get_compiler_fn(backend)
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 25eec7f3167..6ca10527d47 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -3243,6 +3243,9 @@ class InstructionTranslatorBase(
         self.num_calls: dict[str, int] = {}
         # Flag to indicate whether tracing is used for export.
         self.export = export
+        # NOTE: one_graph is used for export/debugging to always force errors on graph breaks.
+        # To allow fullgraph to be toggled during normal compiles, config.error_on_graph_break
+        # is used instead.
         self.one_graph = False
 
         self.current_speculation = None
@@ -3507,6 +3510,7 @@ class InstructionTranslator(InstructionTranslatorBase):
         return (
             all(b.can_restore() for b in self.block_stack)
             and not self.one_graph
+            and not config.error_on_graph_break
             and not self.active_generic_context_managers
         )
 
@@ -3641,6 +3645,7 @@ class InstructionTranslator(InstructionTranslatorBase):
             and not self.symbolic_locals_contain_module_class()
             and not self.export
            and not self.one_graph
+            and not config.error_on_graph_break
         ):
             raise exc.SkipFrame("because no content in function call")
 
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index f368a1bd4f0..3101b3b9e87 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -1006,7 +1006,7 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
     ) -> "VariableTracker":
         if name == "queue_callback":
            if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
-                assert tx.one_graph, (
+                assert tx.one_graph or config.error_on_graph_break, (
                     "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
                 )
                 return variables.UserFunctionVariable(
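
For context, a minimal sketch of the user-facing behavior that these changes route through config.error_on_graph_break, assuming the "eager" backend and torch._dynamo.graph_break() to force a break (the exact exception type raised may vary across versions):

import torch


def fn(x):
    x = x + 1
    torch._dynamo.graph_break()  # deliberate graph break
    return x * 2


# fullgraph=False: Dynamo splits the graph at the break and still runs the frame.
out = torch.compile(fn, backend="eager", fullgraph=False)(torch.ones(4))

# fullgraph=True: ConvertFrame now sets config.error_on_graph_break while the
# frame compiles, so the same break surfaces as an error instead of a fallback.
try:
    torch.compile(fn, backend="eager", fullgraph=True)(torch.ones(4))
except Exception as err:  # torch._dynamo.exc.Unsupported in practice
    print(type(err).__name__)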