pytorch/torch/_dynamo/__init__.py
ydwu4 240f4b2d25 make __lookup_backend return None when cache misses (#114766)
Fixes #114674. The error occurs because cached_backends is a thread-local object: when it is accessed from another thread, we get a cache miss. The naive fix is to just return None and recompile on a cache miss. This is also related to making dynamo more thread-safe, but I'm not sure whether there is an ongoing effort for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114766
Approved by: https://github.com/IvanYashchuk, https://github.com/Neilblaze, https://github.com/jansel
2023-12-07 00:25:01 +00:00
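
A minimal sketch of the failure mode described above, with hypothetical names (the real cache lives inside Dynamo's eval_frame machinery): attributes set on a threading.local() are visible only to the thread that set them, so a lookup from any other thread must treat the miss as "compile again" rather than an error.

import threading

cached_backends = threading.local()  # per-thread storage, as in the bug report

def lookup_backend(key):
    # Return None on a cache miss so the caller can fall back to recompiling,
    # mirroring the fix in this commit.
    return getattr(cached_backends, "backends", {}).get(key)

def store_backend(key, backend):
    if not hasattr(cached_backends, "backends"):
        cached_backends.backends = {}
    cached_backends.backends[key] = backend

store_backend("inductor", object())
assert lookup_backend("inductor") is not None  # hit on the owning thread

results = []
t = threading.Thread(target=lambda: results.append(lookup_backend("inductor")))
t.start()
t.join()
assert results[0] is None  # miss from another thread -> caller recompiles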

import torch
from . import allowed_functions, convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, lookup_backend, register_backend
from .code_context import code_context
from .convert_frame import replay
from .decorators import (
    allow_in_graph,
    assume_constant_result,
    disable,
    disallow_in_graph,
    forbid_in_graph,
    graph_break,
    mark_dynamic,
    mark_static,
    mark_static_address,
    maybe_mark_dynamic,
    run,
)
from .eval_frame import (
    _reset_guarded_backend_cache,
    explain,
    export,
    is_dynamo_supported,
    optimize,
    optimize_assert,
    OptimizedModule,
    reset_code,
)
from .external_utils import is_compiling
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count

__all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "disallow_in_graph",
    "forbid_in_graph",
    "graph_break",
    "mark_dynamic",
    "maybe_mark_dynamic",
    "mark_static",
    "mark_static_address",
    "optimize",
    "optimize_assert",
    "export",
    "explain",
    "run",
    "replay",
    "disable",
    "reset",
    "OptimizedModule",
    "is_compiling",
    "register_backend",
    "list_backends",
    "lookup_backend",
]

if torch.manual_seed is torch.random.manual_seed:
    import torch.jit._builtins

    # Wrap manual_seed with the disable decorator.
    # Can't do it at its implementation due to dependency issues.
    torch.manual_seed = disable(torch.manual_seed)

    # Add the new manual_seed to the builtin registry.
    torch.jit._builtins._register_builtin(torch.manual_seed, "aten::manual_seed")


def reset() -> None:
    """Clear all compile caches and restore initial state"""
    with eval_frame.compile_lock:
        for weak_code in (
            convert_frame.input_codes.seen + convert_frame.output_codes.seen
        ):
            code = weak_code()
            if code:
                reset_code(code)
        convert_frame.input_codes.clear()
        convert_frame.output_codes.clear()
        orig_code_map.clear()
        guard_failures.clear()
        graph_break_reasons.clear()
        resume_execution.ContinueExecutionCache.cache.clear()
        _reset_guarded_backend_cache()
        reset_frame_count()
        torch._C._dynamo.compiled_autograd.clear_cache()
        code_context.clear()
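
In practice, reset() is the entry point torch._dynamo exposes for dropping all compiled code between runs, e.g. in tests or benchmarks. A short usage sketch (not part of this file; assumes a default torch.compile setup):

import torch

@torch.compile
def f(x):
    return x + 1

f(torch.ones(3))        # first call triggers compilation
torch._dynamo.reset()   # clear all compile caches and guard state
f(torch.ones(3))        # compiles again from a clean slate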