diff --git a/docs/source/torch.compiler_faq.rst b/docs/source/torch.compiler_faq.rst
index 59cbaacad35..103d99a0c3d 100644
--- a/docs/source/torch.compiler_faq.rst
+++ b/docs/source/torch.compiler_faq.rst
@@ -60,7 +60,7 @@ Do I still need to export whole graphs?
 For the vast majority of models you probably don’t and you can use
 ``torch.compile()`` as is but there are a few situations where
 full graphs are necessary and you can ensure a full graph by simply
-running ``torch.compile(..., nopython=True)``. These situations include:
+running ``torch.compile(..., fullgraph=True)``. These situations include:
 
 * Large scale training runs, such as $250K+ that require pipeline parallelism
   and other advanced sharding strategies.
@@ -245,22 +245,29 @@ that are encountered. Here is an example usage:
         if b.sum() < 0:
             b = b * -1
         return x * b
-    explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
+    explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
     print(explanation)
     """
-    Dynamo produced 3 graphs, with 2 graph break and 6 ops.
-    Break reasons:
-    1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
-       File "t2.py", line 16, in toy_example
-        print("woo")
+    Graph Count: 3
+    Graph Break Count: 2
+    Op Count: 5
+    Break Reasons:
+      Break Reason 1:
+        Reason: builtin: print [] False
+        User Stack:
+
+      Break Reason 2:
+        Reason: generic_jump TensorVariable()
+        User Stack:
+
+    Ops per Graph:
+      ...
+    Out Guards:
+      ...
+    """
-
-    2. generic_jump
-       File "t2.py", line 17, in toy_example
-        if b.sum() < 0:
-    """
-
-To throw an error on the first graph break encountered you can use
-disable python fallback by using ``nopython=True``, this should be
+To throw an error on the first graph break encountered you can
+disable the Python fallback by using ``fullgraph=True``; this should be
 familiar if you’ve worked with export based compilers.
 
 .. code-block:: python
@@ -268,7 +275,7 @@ familiar if you’ve worked with export based compilers.
     def toy_example(a, b):
        ...
 
-    torch.compile(toy_example, fullgraph=True, backend=<compiler>)
+    torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)
 
 Why didn’t my code recompile when I changed it?
 -----------------------------------------------
diff --git a/docs/source/torch.compiler_troubleshooting.rst b/docs/source/torch.compiler_troubleshooting.rst
index 0614946340d..eec046bc92a 100644
--- a/docs/source/torch.compiler_troubleshooting.rst
+++ b/docs/source/torch.compiler_troubleshooting.rst
@@ -612,21 +612,26 @@ that are encountered. Here is an example usage:
         if b.sum() < 0:
             b = b * -1
         return x * b
-    explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = (
-        dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
-    )
-    print(explanation_verbose)
+    explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
+    print(explanation)
     """
-    Dynamo produced 3 graphs, with 2 graph breaks and 6 ops.
-    Break reasons:
-    1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
-       File "t2.py", line 16, in toy_example
-        print("woo")
-
-    2. generic_jump
-       File "t2.py", line 17, in toy_example
-        if b.sum() < 0:
-    """
+    Graph Count: 3
+    Graph Break Count: 2
+    Op Count: 5
+    Break Reasons:
+      Break Reason 1:
+        Reason: builtin: print [] False
+        User Stack:
+
+      Break Reason 2:
+        Reason: generic_jump TensorVariable()
+        User Stack:
+
+    Ops per Graph:
+      ...
+    Out Guards:
+      ...
+ """ Outputs include: @@ -634,7 +639,7 @@ Outputs include: - ``graphs`` - a list of graph modules which were successfully traced. - ``ops_per_graph`` - a list of lists where each sublist contains the ops that are run in the graph. -To throw an error on the first graph break encountered, use the ``nopython`` +To throw an error on the first graph break encountered, use the ``fullgraph`` mode. This mode disables TorchDynamo’s Python fallback, and only succeeds if the entire program is convertible into a single graph. Example usage: @@ -644,7 +649,7 @@ usage: def toy_example(a, b): ... - compiled_toy = torch.compile(toy_example, fullgraph=True, backend=) + compiled_toy = torch.compile(toy_example, fullgraph=True, backend=)(a, b) Excessive Recompilation ----------------------- diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 7c6a244e8ae..174e9f44e42 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -670,20 +670,6 @@ def explain(f, *extra_args, **extra_kwargs): opt_f(*args, **kwargs) graph_count = len(graphs) - - # For the explanation summary, dedupe reasons by the innermost stack frame and dedupe by it. - deduped_reasons = {} - for reason in break_reasons: - innermost_frame = reason.user_stack[-1] - # __repr__ uniquely identifies a FrameSummary so we can use it for deduping - deduped_reasons[repr(innermost_frame)] = reason - - formatted_list = "" - for idx, break_reason in enumerate(deduped_reasons.values()): - formatted_stack = "".join(traceback.format_list(break_reason.user_stack)) - msg = f"{idx + 1}. Reason: {break_reason.reason}\n User Stack: {formatted_stack}\n" - formatted_list += msg - graph_break_count = graph_count - 1 compile_time = compile_times(repr="str") diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 77edb451034..fff318155f8 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -718,7 +718,9 @@ class CleanupHook: name: str def __call__(self, *args): - CleanupManager.count -= 1 + # Make sure we're not shutting down + if CleanupManager is not None: + CleanupManager.count -= 1 del self.scope[self.name] @staticmethod