Previously: https://github.com/pytorch/pytorch/pull/138052, but the implementation was redone from scratch, so I opened a new PR.

This implements the ability to save and load profiles of automatic dynamic decisions, so that on subsequent runs we can directly make something automatically dynamic. Unlike the previous implementation, this cache is never enabled by default; instead, you have to specify a "job id" that says it is OK to share results. We will be able to populate this id automatically for internal MAST jobs, but generic OSS users will have to opt in explicitly.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139001
Approved by: https://github.com/oulgen
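A minimal sketch of the opt-in flow described above, assuming the "job id" is exposed as `torch.compiler.config.job_id` (the attribute name is an assumption for illustration; this diff only shows `torch/_dynamo/__init__.py`, not the config surface):

    import torch

    # Hypothetical opt-in knob (name assumed, see above): any string that is
    # stable across runs of the same job. With a job id set, automatic dynamic
    # decisions recorded on run N can be reloaded on run N+1, marking the
    # corresponding dimensions dynamic up front.
    torch.compiler.config.job_id = "my-train-job"

    @torch.compile
    def step(x):
        return x * 2

    # Run 1: varying input sizes trigger recompiles and are recorded in the
    # profile keyed by the job id.
    # Run 2+: the saved profile makes those dimensions dynamic immediately,
    # avoiding the extra recompiles.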
import torch

from . import convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, lookup_backend, register_backend
from .callback import callback_handler, on_compile_end, on_compile_start
from .code_context import code_context
from .convert_frame import replay
from .decorators import (
    allow_in_graph,
    assume_constant_result,
    disable,
    disallow_in_graph,
    forbid_in_graph,
    graph_break,
    mark_dynamic,
    mark_static,
    mark_static_address,
    maybe_mark_dynamic,
    run,
    set_stance,
    substitute_in_graph,
)
from .eval_frame import (
    _reset_guarded_backend_cache,
    explain,
    export,
    is_dynamo_supported,
    is_inductor_supported,
    optimize,
    optimize_assert,
    OptimizedModule,
    reset_code,
)
from .external_utils import is_compiling
from .mutation_guard import GenerationTracker
from .pgo import reset_code_state
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count


# Register polyfill functions
from .polyfills import loader as _  # usort: skip # noqa: F401

__all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "disallow_in_graph",
    "forbid_in_graph",
    "substitute_in_graph",
    "graph_break",
    "mark_dynamic",
    "maybe_mark_dynamic",
    "mark_static",
    "mark_static_address",
    "optimize",
    "optimize_assert",
    "export",
    "explain",
    "run",
    "replay",
    "disable",
    "set_stance",
    "reset",
    "OptimizedModule",
    "is_compiling",
    "register_backend",
    "list_backends",
    "lookup_backend",
]
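
# Illustrative usage of the exports above (not part of this module). The PR's
# PGO cache roughly automates what mark_dynamic does by hand: pre-declaring a
# dimension dynamic so Dynamo does not specialize on its first-seen size:
#
#     import torch
#     import torch._dynamo
#
#     x = torch.randn(8, 32)
#     torch._dynamo.mark_dynamic(x, 0)  # dim 0 compiles as dynamic immediately
#     y = torch.compile(lambda t: t.relu())(x)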

if torch.manual_seed is torch.random.manual_seed:
    import torch.jit._builtins

    # Wrap manual_seed with the disable decorator.
    # Can't do it at its implementation due to dependency issues.
    torch.manual_seed = torch._disable_dynamo(torch.manual_seed)
    # Add the new manual_seed to the builtin registry.
    torch.jit._builtins._register_builtin(torch.manual_seed, "aten::manual_seed")
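
# Illustrative effect of the patch above (not part of this module):
# torch._disable_dynamo makes Dynamo skip tracing the wrapped function, so a
# compiled region that calls torch.manual_seed graph-breaks around the call
# and runs it eagerly instead of trying to trace RNG state mutation:
#
#     @torch.compile
#     def f(x):
#         torch.manual_seed(0)  # runs eagerly; not captured in the graph
#         return x + torch.rand_like(x)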


def reset() -> None:
    """
    Clear all compile caches and restore initial state. This function is
    intended to reset Dynamo's state *as if* you had started a fresh process
    invocation, which makes it useful for tests that need fresh-process
    behavior. It does NOT affect any file system caches.

    NB: this does NOT reset logging state. Don't use this to test logging
    initialization/reinitialization.
    """
    # TODO: https://github.com/pytorch/pytorch/issues/139200
    import logging

    log = logging.getLogger(__name__)
    log.info("torch._dynamo.reset")
    with convert_frame.compile_lock:
        # Clear compiled-code caches and the PGO code state added by this PR.
        reset_code_caches()
        convert_frame.input_codes.clear()
        reset_code_state()
        convert_frame.output_codes.clear()
        orig_code_map.clear()
        guard_failures.clear()
        graph_break_reasons.clear()
        resume_execution.ContinueExecutionCache.cache.clear()
        _reset_guarded_backend_cache()
        reset_frame_count()
        torch._C._dynamo.compiled_autograd.clear_cache()
        # Reset frame/compile counters and any registered compile callbacks.
        convert_frame.FRAME_COUNTER = 0
        convert_frame.FRAME_COMPILE_COUNTER.clear()
        callback_handler.clear()
        GenerationTracker.clear()
        torch._dynamo.utils.warn_once_cache.clear()
        torch._dynamo.utils.user_obj_id_to_weakref.clear()
        torch._C._autograd._saved_tensors_hooks_set_tracing(False)


def reset_code_caches() -> None:
    """
    Clears the in-memory code cache, which is what stores compiled products.
    This resets less state than :func:`reset` and is mostly only used for
    testing purposes.
    """
    # TODO: https://github.com/pytorch/pytorch/issues/139200
    import logging

    log = logging.getLogger(__name__)
    log.info("torch._dynamo.reset_code_caches")
    with convert_frame.compile_lock:
        reset_code_state()
        for weak_code in (
            convert_frame.input_codes.seen + convert_frame.output_codes.seen
        ):
            code = weak_code()
            if code:
                reset_code(code)
        code_context.clear()
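
# Illustrative usage (not part of this module): reset() isolates compilation
# state between test cases, as if each ran in a fresh process:
#
#     import torch
#     import torch._dynamo
#
#     def test_fresh_compile():
#         torch._dynamo.reset()  # drop all in-memory compile caches/counters
#         fn = torch.compile(lambda x: x + 1)
#         assert torch.equal(fn(torch.zeros(3)), torch.ones(3))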