fix dynamo.explain examples (#122745)

`dynamo.explain()` was updated to return a structure but the docs weren't updated to match.

- Update the docs to use the new API
- Remove some dead code left when `explain` was updated.
- Drive-by: Fix some `nopython` uses that I noticed
- Drive-by: I noticed an ignored error coming from CleanupHook on shutdown - make it check the global before setting it.

Fixes #122573

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122745
Approved by: https://github.com/jansel
Aaron Orenstein 2024-03-27 08:14:55 -07:00 committed by PyTorch MergeBot
parent a54ea7bbd8
commit a8b7480f0d
4 changed files with 45 additions and 45 deletions


@@ -60,7 +60,7 @@ Do I still need to export whole graphs?
 For the vast majority of models you probably don't and you can use
 ``torch.compile()`` as is but there are a few situations where
 full graphs are necessary and you can ensure a full graph by simply
-running ``torch.compile(..., nopython=True)``. These situations include:
+running ``torch.compile(..., fullgraph=True)``. These situations include:
 * Large scale training runs, such as $250K+ that require pipeline parallelism
   and other advanced sharding strategies.
@@ -245,22 +245,29 @@ that are encountered. Here is an example usage:
         if b.sum() < 0:
             b = b * -1
         return x * b
-    explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
+    explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
     print(explanation)
     """
-    Dynamo produced 3 graphs, with 2 graph break and 6 ops.
-    Break reasons:
-    1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
-       File "t2.py", line 16, in toy_example
-        print("woo")
-    2. generic_jump
-       File "t2.py", line 17, in toy_example
-        if b.sum() < 0:
+    Graph Count: 3
+    Graph Break Count: 2
+    Op Count: 5
+    Break Reasons:
+      Break Reason 1:
+        Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
+        User Stack:
+          <FrameSummary file foo.py, line 5 in toy_example>
+      Break Reason 2:
+        Reason: generic_jump TensorVariable()
+        User Stack:
+          <FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
+    Ops per Graph:
+      ...
+    Out Guards:
+      ...
     """

-To throw an error on the first graph break encountered you can use
-disable python fallback by using ``nopython=True``, this should be
+To throw an error on the first graph break encountered you can
+disable python fallbacks by using ``fullgraph=True``, this should be
 familiar if you've worked with export based compilers.
@@ -268,7 +275,7 @@ familiar if you've worked with export based compilers.
 .. code-block:: python

     def toy_example(a, b):
        ...
-    torch.compile(toy_example, fullgraph=True, backend=<compiler>)
+    torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)

 Why didn't my code recompile when I changed it?
 -----------------------------------------------


@@ -612,21 +612,26 @@ that are encountered. Here is an example usage:
         if b.sum() < 0:
             b = b * -1
         return x * b
-    explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = (
-        dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
-    )
-    print(explanation_verbose)
+    explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
+    print(explanation)
     """
-    Dynamo produced 3 graphs, with 2 graph breaks and 6 ops.
-    Break reasons:
-    1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
-       File "t2.py", line 16, in toy_example
-        print("woo")
-    2. generic_jump
-       File "t2.py", line 17, in toy_example
-        if b.sum() < 0:
+    Graph Count: 3
+    Graph Break Count: 2
+    Op Count: 5
+    Break Reasons:
+      Break Reason 1:
+        Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
+        User Stack:
+          <FrameSummary file foo.py, line 5 in toy_example>
+      Break Reason 2:
+        Reason: generic_jump TensorVariable()
+        User Stack:
+          <FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
+    Ops per Graph:
+      ...
+    Out Guards:
+      ...
     """

 Outputs include:
@@ -634,7 +639,7 @@ Outputs include:
 - ``graphs`` - a list of graph modules which were successfully traced.
 - ``ops_per_graph`` - a list of lists where each sublist contains the ops that are run in the graph.

-To throw an error on the first graph break encountered, use the ``nopython``
+To throw an error on the first graph break encountered, use the ``fullgraph``
 mode. This mode disables TorchDynamo's Python fallback, and only
 succeeds if the entire program is convertible into a single graph. Example
 usage:
@@ -644,7 +649,7 @@ usage:
     def toy_example(a, b):
        ...
-    compiled_toy = torch.compile(toy_example, fullgraph=True, backend=<compiler>)
+    compiled_toy = torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)
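The snippet above can't run verbatim because ``<compiler>`` is a placeholder; a runnable sketch, substituting the built-in ``"eager"`` backend (an illustrative choice, not the docs' recommendation) and showing the break surfacing as an error:

```python
import torch

def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:  # data-dependent branch: a graph break in default mode
        b = b * -1
    return x * b

compiled_toy = torch.compile(toy_example, fullgraph=True, backend="eager")
try:
    compiled_toy(torch.randn(10), torch.randn(10))
    print("no graph break")
except Exception as err:
    print("fullgraph=True raised:", type(err).__name__)
```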
Excessive Recompilation
-----------------------


@@ -670,20 +670,6 @@ def explain(f, *extra_args, **extra_kwargs):
         opt_f(*args, **kwargs)

         graph_count = len(graphs)

-        # For the explanation summary, dedupe reasons by the innermost stack frame and dedupe by it.
-        deduped_reasons = {}
-        for reason in break_reasons:
-            innermost_frame = reason.user_stack[-1]
-            # __repr__ uniquely identifies a FrameSummary so we can use it for deduping
-            deduped_reasons[repr(innermost_frame)] = reason
-
-        formatted_list = ""
-        for idx, break_reason in enumerate(deduped_reasons.values()):
-            formatted_stack = "".join(traceback.format_list(break_reason.user_stack))
-            msg = f"{idx + 1}. Reason: {break_reason.reason}\n   User Stack: {formatted_stack}\n"
-            formatted_list += msg
-
         graph_break_count = graph_count - 1
         compile_time = compile_times(repr="str")


@@ -718,7 +718,9 @@ class CleanupHook:
     name: str

     def __call__(self, *args):
-        CleanupManager.count -= 1
+        # Make sure we're not shutting down
+        if CleanupManager is not None:
+            CleanupManager.count -= 1
         del self.scope[self.name]

     @staticmethod
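A pure-Python sketch (illustrative names, not PyTorch internals) of the shutdown race this last hunk guards against: during interpreter shutdown a module's globals can be cleared to ``None``, so a late-running hook must check the class global before touching it:

```python
class CleanupManager:
    count = 1  # pretend one hook is currently registered

class CleanupHook:
    def __init__(self, scope, name):
        self.scope = scope
        self.name = name

    def __call__(self, *args):
        # Mirrors the patch: skip the decrement if the global was already
        # torn down at shutdown (it would be None, not the class).
        if CleanupManager is not None:
            CleanupManager.count -= 1
        del self.scope[self.name]

scope = {"param": object()}
CleanupHook(scope, "param")()
print(CleanupManager.count, "param" in scope)
```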