mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
fix dynamo.explain examples (#122745)
`dynamo.explain()` was updated to return a structure but the docs weren't updated to match. - Update the docs to use the new API - Remove some dead code left when `explain` was updated. - Drive-by: Fix some `nopython` uses that I noticed - Drive-by: I noticed an ignored error coming from CleanupHook on shutdown - make it check the global before setting it. Fixes #122573 Pull Request resolved: https://github.com/pytorch/pytorch/pull/122745 Approved by: https://github.com/jansel
This commit is contained in:
parent
a54ea7bbd8
commit
a8b7480f0d
|
|
@ -60,7 +60,7 @@ Do I still need to export whole graphs?
|
|||
For the vast majority of models you probably don’t and you can use
|
||||
``torch.compile()`` as is but there are a few situations where
|
||||
full graphs are necessary and you can ensure a full graph by simply
|
||||
running ``torch.compile(..., nopython=True)``. These situations include:
|
||||
running ``torch.compile(..., fullgraph=True)``. These situations include:
|
||||
|
||||
* Large scale training runs, such as $250K+ that require pipeline parallelism
|
||||
and other advanced sharding strategies.
|
||||
|
|
@ -245,22 +245,29 @@ that are encountered. Here is an example usage:
|
|||
if b.sum() < 0:
|
||||
b = b * -1
|
||||
return x * b
|
||||
explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
|
||||
explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
|
||||
print(explanation)
|
||||
"""
|
||||
Dynamo produced 3 graphs, with 2 graph break and 6 ops.
|
||||
Break reasons:
|
||||
1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
|
||||
File "t2.py", line 16, in toy_example
|
||||
print("woo")
|
||||
Graph Count: 3
|
||||
Graph Break Count: 2
|
||||
Op Count: 5
|
||||
Break Reasons:
|
||||
Break Reason 1:
|
||||
Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
|
||||
User Stack:
|
||||
<FrameSummary file foo.py, line 5 in toy_example>
|
||||
Break Reason 2:
|
||||
Reason: generic_jump TensorVariable()
|
||||
User Stack:
|
||||
<FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
|
||||
Ops per Graph:
|
||||
...
|
||||
Out Guards:
|
||||
...
|
||||
"""
|
||||
|
||||
2. generic_jump
|
||||
File "t2.py", line 17, in toy_example
|
||||
if b.sum() < 0:
|
||||
"""
|
||||
|
||||
To throw an error on the first graph break encountered you can use
|
||||
disable python fallback by using ``nopython=True``, this should be
|
||||
To throw an error on the first graph break encountered you can
|
||||
disable Python fallbacks by using ``fullgraph=True``; this should be
|
||||
familiar if you’ve worked with export based compilers.
|
||||
|
||||
.. code-block:: python
|
||||
|
|
@ -268,7 +275,7 @@ familiar if you’ve worked with export based compilers.
|
|||
def toy_example(a, b):
|
||||
...
|
||||
|
||||
torch.compile(toy_example, fullgraph=True, backend=<compiler>)
|
||||
torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)
|
||||
|
||||
Why didn’t my code recompile when I changed it?
|
||||
-----------------------------------------------
|
||||
|
|
|
|||
|
|
@ -612,21 +612,26 @@ that are encountered. Here is an example usage:
|
|||
if b.sum() < 0:
|
||||
b = b * -1
|
||||
return x * b
|
||||
explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = (
|
||||
dynamo.explain(toy_example, torch.randn(10), torch.randn(10))
|
||||
)
|
||||
explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
|
||||
print(explanation)
|
||||
"""
|
||||
Dynamo produced 3 graphs, with 2 graph breaks and 6 ops.
|
||||
Break reasons:
|
||||
1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}
|
||||
File "t2.py", line 16, in toy_example
|
||||
print("woo")
|
||||
|
||||
2. generic_jump
|
||||
File "t2.py", line 17, in toy_example
|
||||
if b.sum() < 0:
|
||||
"""
|
||||
Graph Count: 3
|
||||
Graph Break Count: 2
|
||||
Op Count: 5
|
||||
Break Reasons:
|
||||
Break Reason 1:
|
||||
Reason: builtin: print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
|
||||
User Stack:
|
||||
<FrameSummary file foo.py, line 5 in toy_example>
|
||||
Break Reason 2:
|
||||
Reason: generic_jump TensorVariable()
|
||||
User Stack:
|
||||
<FrameSummary file foo.py, line 6 in torch_dynamo_resume_in_toy_example_at_5>
|
||||
Ops per Graph:
|
||||
...
|
||||
Out Guards:
|
||||
...
|
||||
"""
|
||||
|
||||
Outputs include:
|
||||
|
||||
|
|
@ -634,7 +639,7 @@ Outputs include:
|
|||
- ``graphs`` - a list of graph modules which were successfully traced.
|
||||
- ``ops_per_graph`` - a list of lists where each sublist contains the ops that are run in the graph.
|
||||
|
||||
To throw an error on the first graph break encountered, use the ``nopython``
|
||||
To throw an error on the first graph break encountered, use the ``fullgraph``
|
||||
mode. This mode disables TorchDynamo’s Python fallback, and only
|
||||
succeeds if the entire program is convertible into a single graph. Example
|
||||
usage:
|
||||
|
|
@ -644,7 +649,7 @@ usage:
|
|||
def toy_example(a, b):
|
||||
...
|
||||
|
||||
compiled_toy = torch.compile(toy_example, fullgraph=True, backend=<compiler>)
|
||||
compiled_toy = torch.compile(toy_example, fullgraph=True, backend=<compiler>)(a, b)
|
||||
|
||||
Excessive Recompilation
|
||||
-----------------------
|
||||
|
|
|
|||
|
|
@ -670,20 +670,6 @@ def explain(f, *extra_args, **extra_kwargs):
|
|||
opt_f(*args, **kwargs)
|
||||
|
||||
graph_count = len(graphs)
|
||||
|
||||
# For the explanation summary, dedupe reasons by the innermost stack frame and dedupe by it.
|
||||
deduped_reasons = {}
|
||||
for reason in break_reasons:
|
||||
innermost_frame = reason.user_stack[-1]
|
||||
# __repr__ uniquely identifies a FrameSummary so we can use it for deduping
|
||||
deduped_reasons[repr(innermost_frame)] = reason
|
||||
|
||||
formatted_list = ""
|
||||
for idx, break_reason in enumerate(deduped_reasons.values()):
|
||||
formatted_stack = "".join(traceback.format_list(break_reason.user_stack))
|
||||
msg = f"{idx + 1}. Reason: {break_reason.reason}\n User Stack: {formatted_stack}\n"
|
||||
formatted_list += msg
|
||||
|
||||
graph_break_count = graph_count - 1
|
||||
compile_time = compile_times(repr="str")
|
||||
|
||||
|
|
|
|||
|
|
@ -718,7 +718,9 @@ class CleanupHook:
|
|||
name: str
|
||||
|
||||
def __call__(self, *args):
|
||||
CleanupManager.count -= 1
|
||||
# Make sure we're not shutting down
|
||||
if CleanupManager is not None:
|
||||
CleanupManager.count -= 1
|
||||
del self.scope[self.name]
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user