mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
`is_tracing` returns True during Dynamo tracing and False when run in eager mode. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90329. Approved by: https://github.com/jansel
125 lines
3.0 KiB
Python
125 lines
3.0 KiB
Python
from . import allowed_functions, convert_frame, eval_frame, resume_execution
|
|
from .convert_frame import replay
|
|
from .eval_frame import (
|
|
assume_constant_result,
|
|
disable,
|
|
explain,
|
|
export,
|
|
optimize,
|
|
optimize_assert,
|
|
OptimizedModule,
|
|
reset_code,
|
|
run,
|
|
skip,
|
|
)
|
|
from .external_utils import is_compiling
|
|
from .utils import compilation_metrics, guard_failures, orig_code_map
|
|
|
|
# Names exported by `from <this package> import *`. The order is kept
# exactly as originally listed.
__all__ = [
    # Graph membership control.
    "allow_in_graph", "assume_constant_result", "disallow_in_graph", "graph_break",
    # Compilation / tracing entry points.
    "optimize", "optimize_assert", "export", "explain", "run", "replay",
    # Disabling, cache management and introspection.
    "disable", "reset", "list_backends", "skip",
    "OptimizedModule", "is_compiling",
]
|
|
|
|
|
|
def reset():
    """Clear all compile caches and restore initial state.

    Each still-live code object tracked by ``convert_frame.input_codes`` and
    ``convert_frame.output_codes`` is handed to ``reset_code``. Then those
    trackers, ``orig_code_map``, ``guard_failures``, the
    ``ContinueExecutionCache`` cache and ``compilation_metrics`` are cleared,
    and ``eval_frame.most_recent_backend`` is set back to ``None``.
    """
    # Entries in `.seen` are called to obtain the code object (they appear to
    # be weak references); anything that comes back falsy is skipped.
    tracked_refs = convert_frame.input_codes.seen + convert_frame.output_codes.seen
    for ref in tracked_refs:
        target = ref()
        if not target:
            continue
        reset_code(target)

    for store in (
        convert_frame.input_codes,
        convert_frame.output_codes,
        orig_code_map,
        guard_failures,
        resume_execution.ContinueExecutionCache.cache,
    ):
        store.clear()
    eval_frame.most_recent_backend = None
    compilation_metrics.clear()
|
|
|
|
|
|
def list_backends():
    """
    Return the sorted names of the backends that can be passed to::

        @torch._dynamo.optimize(<backend>)
        def foo(...):
           ....

    The result is every name registered in ``optimizations.BACKENDS`` plus
    ``"inductor"``, sorted, with each name appearing at most once.
    """
    # Imported inside the function as in the original. Presumably this avoids
    # an import cycle or import-time cost; confirm before hoisting.
    from .optimizations import BACKENDS

    # `sorted` already returns a fresh list, so no extra unpacking is needed.
    # Collecting into a set first keeps "inductor" from appearing twice should
    # it also be registered in BACKENDS.
    return sorted({*BACKENDS.keys(), "inductor"})
|
|
|
|
|
|
def allow_in_graph(fn):
    """
    Customize which functions TorchDynamo will include in the generated
    graph. Similar to `torch.fx.wrap()`.
    ::

        torch._dynamo.allow_in_graph(my_custom_function)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = my_custom_function(x)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will capture a single graph containing `my_custom_function()`.

    Args:
        fn: a callable to mark as allowed, or a list/tuple of callables.

    Returns:
        ``fn`` itself, so this can also be used as a decorator. For a list or
        tuple input, a list holding the result for each element (a tuple
        input therefore yields a list).

    Raises:
        AssertionError: if ``fn`` (or any element of a list/tuple ``fn``) is
            not callable. This does not apply when assertions are disabled
            with ``-O``.
    """
    if isinstance(fn, (list, tuple)):
        return [allow_in_graph(x) for x in fn]
    assert callable(fn), "allow_in_graph expects a callable"
    # Functions are tracked by id(): record it as allowed and drop it from the
    # disallowed set.
    # NOTE(review): this relies on `_disallowed_function_ids.remove` tolerating
    # ids that were never added (a plain `set.remove` would raise KeyError).
    # Confirm against allowed_functions.
    allowed_functions._allowed_function_ids.add(id(fn))
    allowed_functions._disallowed_function_ids.remove(id(fn))
    return fn
|
|
|
|
|
|
def disallow_in_graph(fn):
    """
    Customize which functions TorchDynamo will exclude from the generated
    graph and force a graph break on.
    ::

        torch._dynamo.disallow_in_graph(torch.sub)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = torch.sub(x, 1)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will break the graph on `torch.sub`, and give two graphs each with a
    single `torch.add()` op.

    Args:
        fn: a callable to mark as disallowed, or a list/tuple of callables.

    Returns:
        ``fn`` itself, so this can also be used as a decorator. For a list or
        tuple input, a list holding the result for each element (a tuple
        input therefore yields a list).

    Raises:
        AssertionError: if ``fn`` (or any element of a list/tuple ``fn``) is
            not callable. This does not apply when assertions are disabled
            with ``-O``.
    """
    if isinstance(fn, (list, tuple)):
        return [disallow_in_graph(x) for x in fn]
    assert callable(fn), "disallow_in_graph expects a callable"
    # Functions are tracked by id(): drop it from the allowed set and record it
    # as disallowed.
    # NOTE(review): this relies on `_allowed_function_ids.remove` tolerating
    # ids that were never added (a plain `set.remove` would raise KeyError).
    # Confirm against allowed_functions.
    allowed_functions._allowed_function_ids.remove(id(fn))
    allowed_functions._disallowed_function_ids.add(id(fn))
    return fn
|
|
|
|
|
|
@disallow_in_graph
def graph_break():
    """Force a graph break.

    The body does nothing when called. Because the function is registered
    through ``disallow_in_graph`` at definition time, TorchDynamo excludes it
    from the generated graph, so a call to it splits the graph.
    """
|