mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fixes #114674. The error occurs because cached_backends is a thread-local object: when it is accessed from another thread, we get a cache miss. The naive fix is to just return None and recompile on a cache miss. This could also be related to making dynamo more thread-safe, but I'm not sure whether there is an ongoing effort. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114766 Approved by: https://github.com/IvanYashchuk, https://github.com/Neilblaze, https://github.com/jansel
86 lines
2.3 KiB
Python
86 lines
2.3 KiB
Python
import torch
|
|
from . import allowed_functions, convert_frame, eval_frame, resume_execution
|
|
from .backends.registry import list_backends, lookup_backend, register_backend
|
|
from .code_context import code_context
|
|
from .convert_frame import replay
|
|
from .decorators import (
|
|
allow_in_graph,
|
|
assume_constant_result,
|
|
disable,
|
|
disallow_in_graph,
|
|
forbid_in_graph,
|
|
graph_break,
|
|
mark_dynamic,
|
|
mark_static,
|
|
mark_static_address,
|
|
maybe_mark_dynamic,
|
|
run,
|
|
)
|
|
from .eval_frame import (
|
|
_reset_guarded_backend_cache,
|
|
explain,
|
|
export,
|
|
is_dynamo_supported,
|
|
optimize,
|
|
optimize_assert,
|
|
OptimizedModule,
|
|
reset_code,
|
|
)
|
|
from .external_utils import is_compiling
|
|
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
|
|
|
|
# Public names exported by a star-import of this package. The order below is
# intentional and unchanged; each entry notes where the name is imported from.
__all__ = [
    "allow_in_graph",  # .decorators
    "assume_constant_result",  # .decorators
    "disallow_in_graph",  # .decorators
    "forbid_in_graph",  # .decorators
    "graph_break",  # .decorators
    "mark_dynamic",  # .decorators
    "maybe_mark_dynamic",  # .decorators
    "mark_static",  # .decorators
    "mark_static_address",  # .decorators
    "optimize",  # .eval_frame
    "optimize_assert",  # .eval_frame
    "export",  # .eval_frame
    "explain",  # .eval_frame
    "run",  # .decorators
    "replay",  # .convert_frame
    "disable",  # .decorators
    "reset",  # defined in this module
    "OptimizedModule",  # .eval_frame
    "is_compiling",  # .external_utils
    "register_backend",  # .backends.registry
    "list_backends",  # .backends.registry
    "lookup_backend",  # .backends.registry
]
|
|
|
|
if torch.manual_seed is torch.random.manual_seed:
    # The identity check means the patch below is applied only while the
    # public alias still points at the undecorated function (presumably
    # `disable` returns a new wrapper object, so a second pass is skipped).
    import torch.jit._builtins

    # `disable` cannot be applied where manual_seed is implemented because of
    # dependency issues, so the public alias is rebound to a wrapped version.
    torch.manual_seed = disable(torch.manual_seed)

    # Record the wrapped callable in the builtin registry under the same
    # builtin name the original used.
    torch.jit._builtins._register_builtin(torch.manual_seed, "aten::manual_seed")
|
|
|
|
|
|
def reset() -> None:
    """Clear all compile caches and restore initial state.

    Everything happens while holding ``eval_frame.compile_lock``:

    1. ``reset_code`` is called on every still-alive entry tracked in
       ``convert_frame.input_codes.seen`` and ``convert_frame.output_codes.seen``.
    2. The tracking containers, cache tables, frame counter, compiled-autograd
       cache and ``code_context`` are emptied, in a fixed order.
    """
    with eval_frame.compile_lock:
        # Entries of ``.seen`` are dereferenced by calling them (they appear
        # to be weak references); a falsy result means the target is gone and
        # it is skipped. Each entry is dereferenced just before its reset.
        tracked_refs = (
            convert_frame.input_codes.seen + convert_frame.output_codes.seen
        )
        for live_code in (ref() for ref in tracked_refs):
            if live_code:
                reset_code(live_code)

        # Empty the containers imported from convert_frame and utils.
        convert_frame.input_codes.clear()
        convert_frame.output_codes.clear()
        orig_code_map.clear()
        guard_failures.clear()
        graph_break_reasons.clear()

        # Resume-function cache, guarded-backend cache and frame counter.
        resume_execution.ContinueExecutionCache.cache.clear()
        _reset_guarded_backend_cache()
        reset_frame_count()

        # State owned by the C extension and by code_context.
        torch._C._dynamo.compiled_autograd.clear_cache()
        code_context.clear()
|