pytorch/torch/_dynamo/__init__.py
Michael Lazos 9c4189f82d [dynamo] Add is_compiling for dynamo (#90329)
`is_compiling` returns True during dynamo tracing and False when run in eager mode

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90329
Approved by: https://github.com/jansel
2022-12-09 20:19:41 +00:00
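
A minimal usage sketch (illustrative; the function `fn` and the "eager" backend choice are assumptions, not part of this commit): code can branch on `torch._dynamo.is_compiling()` to detect whether it is currently being traced by dynamo.

    import torch

    def fn(x):
        if torch._dynamo.is_compiling():
            return x + 1   # path captured while dynamo traces this function
        return x - 1       # path taken under plain eager execution

    opt_fn = torch._dynamo.optimize("eager")(fn)
    opt_fn(torch.ones(3))  # traced: is_compiling() evaluates to True
    fn(torch.ones(3))      # eager: is_compiling() returns False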


from . import allowed_functions, convert_frame, eval_frame, resume_execution
from .convert_frame import replay
from .eval_frame import (
    assume_constant_result,
    disable,
    explain,
    export,
    optimize,
    optimize_assert,
    OptimizedModule,
    reset_code,
    run,
    skip,
)
from .external_utils import is_compiling
from .utils import compilation_metrics, guard_failures, orig_code_map

__all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "disallow_in_graph",
    "graph_break",
    "optimize",
    "optimize_assert",
    "export",
    "explain",
    "run",
    "replay",
    "disable",
    "reset",
    "list_backends",
    "skip",
    "OptimizedModule",
    "is_compiling",
]


def reset():
    """Clear all compile caches and restore initial state"""
    for weak_code in convert_frame.input_codes.seen + convert_frame.output_codes.seen:
        code = weak_code()
        if code:
            reset_code(code)
    convert_frame.input_codes.clear()
    convert_frame.output_codes.clear()
    orig_code_map.clear()
    guard_failures.clear()
    resume_execution.ContinueExecutionCache.cache.clear()
    eval_frame.most_recent_backend = None
    compilation_metrics.clear()
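
# Usage sketch (illustrative, not part of the original file): reset() drops all
# compiled-code caches and guard state, so a previously optimized function is
# recompiled on its next call, e.g.
#
#     opt_fn = torch._dynamo.optimize("eager")(fn)
#     opt_fn(x)                # compiles on first call
#     torch._dynamo.reset()    # clear caches / restore initial state
#     opt_fn(x)                # recompiles from scratch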


def list_backends():
    """
    Return valid strings that can be passed to::

        @torch._dynamo.optimize(<backend>)
        def foo(...):
            ....
    """
    from .optimizations import BACKENDS

    return [*sorted([*BACKENDS.keys(), "inductor"])]
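
# Usage sketch (illustrative, not part of the original file): pick a backend
# string at runtime; the "inductor"/"eager" fallback below is an assumption,
# not something this module prescribes.
#
#     backends = torch._dynamo.list_backends()
#     backend = "inductor" if "inductor" in backends else "eager"
#     opt_fn = torch._dynamo.optimize(backend)(fn)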


def allow_in_graph(fn):
    """
    Customize which functions TorchDynamo will include in the generated
    graph. Similar to `torch.fx.wrap()`.
    ::

        torch._dynamo.allow_in_graph(my_custom_function)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = my_custom_function(x)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will capture a single graph containing `my_custom_function()`.
    """
    if isinstance(fn, (list, tuple)):
        return [allow_in_graph(x) for x in fn]
    assert callable(fn), "allow_in_graph expects a callable"
    allowed_functions._allowed_function_ids.add(id(fn))
    allowed_functions._disallowed_function_ids.remove(id(fn))
    return fn


def disallow_in_graph(fn):
    """
    Customize which functions TorchDynamo will exclude from the generated
    graph and force a graph break on.
    ::

        torch._dynamo.disallow_in_graph(torch.sub)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = torch.sub(x, 1)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will break the graph on `torch.sub`, and give two graphs each with a
    single `torch.add()` op.
    """
    if isinstance(fn, (list, tuple)):
        return [disallow_in_graph(x) for x in fn]
    assert callable(fn), "disallow_in_graph expects a callable"
    allowed_functions._allowed_function_ids.remove(id(fn))
    allowed_functions._disallowed_function_ids.add(id(fn))
    return fn


@disallow_in_graph
def graph_break():
    """Force a graph break"""
    pass
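
# Usage sketch (illustrative, not part of the original file): an explicit
# graph_break() splits the captured trace, so the function below compiles
# into two graphs, one per torch.add call.
#
#     @torch._dynamo.optimize("eager")
#     def fn(a):
#         x = torch.add(a, 1)
#         torch._dynamo.graph_break()
#         return torch.add(x, 1)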