mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Profile guided optimization for automatic_dynamic (#139001)
Previously: https://github.com/pytorch/pytorch/pull/138052 but the implementation is done from scratch, so I open a new PR. This implements the ability to save and load profiles of automatic dynamic decisions, so on subsequent runs we can directly make something automatically dynamic. Unlike the previous implementation, this cache is never enabled by default; instead, you have to specify a "job id" that says it's OK to share results. We will be able to automatically populate this id for internal MAST jobs but for generic OSS users you will have to explicitly opt into it. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Differential Revision: [D65065497](https://our.internmc.facebook.com/intern/diff/D65065497) Pull Request resolved: https://github.com/pytorch/pytorch/pull/139001 Approved by: https://github.com/oulgen
This commit is contained in:
parent
9c2ffce71a
commit
a6630bcf87
9
docs/source/torch.compiler.config.rst
Normal file
9
docs/source/torch.compiler.config.rst
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
.. currentmodule:: torch.compiler.config
|
||||
|
||||
|
||||
torch.compiler.config
|
||||
=====================
|
||||
|
||||
.. automodule:: torch.compiler.config
|
||||
|
||||
.. autodata:: torch.compiler.config.job_id
|
||||
|
|
@ -85,6 +85,7 @@ Read More
|
|||
|
||||
torch.compiler_get_started
|
||||
torch.compiler_api
|
||||
torch.compiler.config
|
||||
torch.compiler_fine_grain_apis
|
||||
torch.compiler_aot_inductor
|
||||
torch.compiler_inductor_profiling
|
||||
|
|
|
|||
126
test/dynamo/test_pgo.py
Normal file
126
test/dynamo/test_pgo.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch._dynamo.config
|
||||
import torch._dynamo.test_case
|
||||
import torch.compiler.config
|
||||
from torch._dynamo.testing import CompileCounter
|
||||
from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
|
||||
|
||||
|
||||
# LOL. https://github.com/pytorch/pytorch/issues/139252
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"mock_cache", os.path.join(os.path.dirname(__file__), "../inductor/mock_cache.py")
|
||||
)
|
||||
mock_cache = importlib.util.module_from_spec(spec)
|
||||
sys.modules["mock_cache"] = mock_cache
|
||||
spec.loader.exec_module(mock_cache)
|
||||
|
||||
|
||||
class PgoTest(torch._dynamo.test_case.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._test_stack = contextlib.ExitStack()
|
||||
self._test_stack.enter_context(torch.compiler.config.patch(job_id=self.id()))
|
||||
self._test_stack.enter_context(
|
||||
torch._dynamo.config.patch(automatic_dynamic_local_pgo=True)
|
||||
)
|
||||
if os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1":
|
||||
self._test_stack.enter_context(fresh_inductor_cache())
|
||||
mock_cache.PatchCaches.setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
torch._dynamo.reset()
|
||||
self._test_stack.close()
|
||||
mock_cache.PatchCaches.tearDown()
|
||||
|
||||
def reset(self):
|
||||
torch._dynamo.reset()
|
||||
clear_inductor_caches()
|
||||
|
||||
def test_basic(self):
|
||||
cnts = CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
||||
f(torch.randn(2, 3))
|
||||
f(torch.randn(2, 4))
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
self.reset()
|
||||
cnts.clear()
|
||||
|
||||
f(torch.randn(2, 5))
|
||||
f(torch.randn(2, 6))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
def test_distinct_compile_id(self):
|
||||
cnts = CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
||||
with torch.compiler.config.patch(job_id="foo"):
|
||||
f(torch.randn(2, 3))
|
||||
f(torch.randn(2, 4))
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
self.reset()
|
||||
cnts.clear()
|
||||
|
||||
with torch.compiler.config.patch(job_id="bar"):
|
||||
f(torch.randn(2, 5))
|
||||
f(torch.randn(2, 6))
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
torch._dynamo.reset()
|
||||
clear_inductor_caches()
|
||||
cnts.clear()
|
||||
|
||||
with torch.compiler.config.patch(job_id="foo"):
|
||||
f(torch.randn(2, 7))
|
||||
f(torch.randn(2, 8))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
# TODO: to test local need to ensure the local filesystem gets cleared out
|
||||
@torch._dynamo.config.patch(
|
||||
automatic_dynamic_remote_pgo=True, automatic_dynamic_local_pgo=False
|
||||
)
|
||||
def test_remote_basic(self):
|
||||
cnts = CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
||||
with mock_cache.PatchCaches():
|
||||
f(torch.randn(2, 3))
|
||||
f(torch.randn(2, 4))
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
self.assertEqual(
|
||||
mock_cache.global_stats.dynamo_pgo, mock_cache.Stats(2, 0, 1)
|
||||
)
|
||||
|
||||
self.reset()
|
||||
cnts.clear()
|
||||
|
||||
f(torch.randn(2, 5))
|
||||
f(torch.randn(2, 6))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(
|
||||
mock_cache.global_stats.dynamo_pgo, mock_cache.Stats(2, 1, 1)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
|
|
@ -80,6 +80,7 @@ class _GlobalStats(threading.local):
|
|||
self.fx_graph = _GlobalItemStats()
|
||||
self.triton = _GlobalItemStats()
|
||||
self.aot_autograd = _GlobalItemStats()
|
||||
self.dynamo_pgo = _GlobalItemStats()
|
||||
|
||||
def reset(self) -> None:
|
||||
self.autotune_local.reset()
|
||||
|
|
@ -88,6 +89,7 @@ class _GlobalStats(threading.local):
|
|||
self.fx_graph.reset()
|
||||
self.triton.reset()
|
||||
self.aot_autograd.reset()
|
||||
self.dynamo_pgo.reset()
|
||||
|
||||
def get_stat(self, name: str) -> _GlobalItemStats:
|
||||
return getattr(self, name)
|
||||
|
|
@ -100,6 +102,7 @@ class _GlobalStats(threading.local):
|
|||
("fx_graph", self.fx_graph),
|
||||
("triton", self.triton),
|
||||
("aot_autograd", self.aot_autograd),
|
||||
("dynamo_pgo", self.dynamo_pgo),
|
||||
)
|
||||
|
||||
print("Cache Stats:", file=sys.stderr)
|
||||
|
|
@ -215,6 +218,12 @@ class PatchCaches(contextlib.AbstractContextManager):
|
|||
)
|
||||
self._stack.enter_context(ctx)
|
||||
|
||||
ctx = patch(
|
||||
"torch._inductor.remote_cache.RemoteDynamoPGOCache.backend_override_cls",
|
||||
MockBackend.with_name("dynamo_pgo"),
|
||||
)
|
||||
self._stack.enter_context(ctx)
|
||||
|
||||
if config.is_fbcode():
|
||||
ctx = patch(
|
||||
"torch._inductor.fb.remote_cache.FbRemoteAutotuneCache.backend_override_cls",
|
||||
|
|
@ -246,6 +255,12 @@ class PatchCaches(contextlib.AbstractContextManager):
|
|||
)
|
||||
self._stack.enter_context(ctx)
|
||||
|
||||
ctx = patch(
|
||||
"torch._inductor.fb.remote_cache.FbRemoteDynamoPGOCache.backend_override_cls",
|
||||
MockBackend.with_name("dynamo_pgo"),
|
||||
)
|
||||
self._stack.enter_context(ctx)
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from .eval_frame import (
|
|||
)
|
||||
from .external_utils import is_compiling
|
||||
from .mutation_guard import GenerationTracker
|
||||
from .pgo import CODE_STATE
|
||||
from .pgo import reset_code_state
|
||||
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
|
||||
|
||||
|
||||
|
|
@ -79,11 +79,24 @@ if torch.manual_seed is torch.random.manual_seed:
|
|||
|
||||
|
||||
def reset() -> None:
|
||||
"""Clear all compile caches and restore initial state"""
|
||||
"""
|
||||
Clear all compile caches and restore initial state. This function is intended
|
||||
to reset Dynamo's state *as if* you had started a fresh process invocation, which
|
||||
makes it good for testing scenarios where you want to behave as if you started
|
||||
a new process. It does NOT affect any file system caches.
|
||||
|
||||
NB: this does NOT reset logging state. Don't use this to test logging
|
||||
initialization/reinitialization.
|
||||
"""
|
||||
# TODO: https://github.com/pytorch/pytorch/issues/139200
|
||||
import logging
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.info("torch._dynamo.reset")
|
||||
with convert_frame.compile_lock:
|
||||
reset_code_caches()
|
||||
convert_frame.input_codes.clear()
|
||||
CODE_STATE.clear()
|
||||
reset_code_state()
|
||||
convert_frame.output_codes.clear()
|
||||
orig_code_map.clear()
|
||||
guard_failures.clear()
|
||||
|
|
@ -102,9 +115,19 @@ def reset() -> None:
|
|||
|
||||
|
||||
def reset_code_caches() -> None:
|
||||
"""
|
||||
Clears in-memory code cache, which is what stores compiled products. This
|
||||
resets less state than :func:`reset` and is mostly only used for testing
|
||||
purposes.
|
||||
"""
|
||||
# TODO: https://github.com/pytorch/pytorch/issues/139200
|
||||
import logging
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.info("torch._dynamo.reset_code_caches")
|
||||
"""Clear compile caches that are keyed by code objects"""
|
||||
with convert_frame.compile_lock:
|
||||
CODE_STATE.clear()
|
||||
reset_code_state()
|
||||
for weak_code in (
|
||||
convert_frame.input_codes.seen + convert_frame.output_codes.seen
|
||||
):
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Unio
|
|||
|
||||
import torch
|
||||
from torch._environment import is_fbcode
|
||||
from torch.utils._config_module import get_tristate_env, install_config_module
|
||||
|
||||
|
||||
# to configure logging for dynamo, aot, and inductor
|
||||
|
|
@ -485,6 +486,36 @@ compiled_autograd_kwargs_override: Dict[str, Any] = {}
|
|||
# NCCL timeout.
|
||||
enable_compiler_collectives = os.environ.get("TORCH_COMPILER_COLLECTIVES", "0") == "1"
|
||||
|
||||
# Enables a local, filesystem "profile" which can be used for automatic
|
||||
# dynamic decisions, analogous to profile-guided optimization. This config
|
||||
# ONLY has an effect if torch.compiler.config.workflow_id is specified,
|
||||
# which specifies the name of the profile we will save/load.
|
||||
#
|
||||
# The idea is that if we observe that a particular input is dynamic over
|
||||
# multiple iterations on one run, we can save a profile with this information
|
||||
# so the next time we run we can just make it dynamic the first time around,
|
||||
# skipping an unnecessary static compilation. The profile can be soundly
|
||||
# stale, if it is wrong, it just means we may make more things dynamic than
|
||||
# was actually necessary (NB: this /can/ cause a failure if making something
|
||||
# dynamic causes the compiler to stop working because you tickled a latent
|
||||
# bug.)
|
||||
#
|
||||
# The profile is ONLY guaranteed to work if the user source code is 100%
|
||||
# unchanged. Applying the profile if there are user code changes is only
|
||||
# best effort otherwise. In particular, we identify particular code objects
|
||||
# by filename, line number and name of their function, so adding/removing newlines
|
||||
# will typically cause cache misses. We continuously update the profile,
|
||||
# so if we only discover something is dynamic on the second run, we will update
|
||||
# the profile for subsequent runs.
|
||||
automatic_dynamic_local_pgo: bool = (
|
||||
os.environ.get("TORCH_DYNAMO_AUTOMATIC_DYNAMIC_LOCAL_PGO", "0") == "1"
|
||||
)
|
||||
|
||||
# Like above, but using remote cache
|
||||
automatic_dynamic_remote_pgo: Optional[bool] = get_tristate_env(
|
||||
"TORCH_DYNAMO_AUTOMATIC_DYNAMIC_REMOTE_PGO"
|
||||
)
|
||||
|
||||
# HACK: this is for testing custom ops profiling only
|
||||
_custom_ops_profile: Optional[Any] = None
|
||||
|
||||
|
|
@ -495,7 +526,4 @@ if TYPE_CHECKING:
|
|||
...
|
||||
|
||||
|
||||
from torch.utils._config_module import install_config_module
|
||||
|
||||
|
||||
install_config_module(sys.modules[__name__])
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@ from .guards import (
|
|||
GuardedCode,
|
||||
)
|
||||
from .hooks import Hooks
|
||||
from .pgo import put_code_state
|
||||
from .replay_record import ExecutionRecord
|
||||
from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
|
||||
from .symbolic_convert import (
|
||||
|
|
@ -1019,6 +1020,8 @@ def _compile(
|
|||
f"{type(e).__qualname__}: {str(e)}"
|
||||
).with_traceback(e.__traceback__) from None
|
||||
finally:
|
||||
put_code_state()
|
||||
|
||||
if tracer:
|
||||
tracer.output.local_scope = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,25 +1,43 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import dataclasses
|
||||
import enum
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import DefaultDict, Optional, Tuple, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import Self
|
||||
|
||||
from torch._dynamo.utils import get_chromium_event_logger
|
||||
import torch._dynamo.config
|
||||
import torch._utils_internal
|
||||
import torch.compiler.config
|
||||
import torch.distributed as dist
|
||||
from torch._dynamo.utils import dynamo_timed, get_chromium_event_logger, warn_once
|
||||
from torch._environment import is_fbcode
|
||||
from torch._inductor.remote_cache import create_cache
|
||||
from torch._logging._internal import trace_structured_artifact
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import types
|
||||
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from torch._inductor.remote_cache import JsonDataTy, RemoteCache
|
||||
|
||||
|
||||
class ReservedWorkflowIdUserError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
LOCK_TIMEOUT = 10
|
||||
|
||||
# How does in memory representation work? Concretely, this module is
|
||||
# responsible for holding GLOBAL state representing the state it holds, no
|
||||
# other copies permitted. So we retire frame_state entirely and store it
|
||||
|
|
@ -29,6 +47,48 @@ log = logging.getLogger(__name__)
|
|||
# don't mind leaking it.
|
||||
|
||||
|
||||
# How exactly did we design the cache key? Here are some of the questions:
|
||||
#
|
||||
# - JOB_ID: Do we have a unique identifier for the "training run" (such that
|
||||
# it stays the same if we're running the same code, and changes if we're
|
||||
# running something different).
|
||||
#
|
||||
# - RANK: Are we sharing the cache across ranks, or does each rank get
|
||||
# an individual cache?
|
||||
#
|
||||
# We choose to require job_id for PGO cache. This is to prevent
|
||||
# situations where unrelated invocations of PyTorch unpredictably cause
|
||||
# changes to each other's behavior. With a job_id, at least you know there
|
||||
# is some "state" associated with it. (State dict might be another way to
|
||||
# tell if a run is related or not.) You can opt-in to YOLO everything
|
||||
# aliases everything by passing a shared job_id for all your invocations.
|
||||
#
|
||||
# We choose to NOT share PGO cache across ranks. With no RANK_SHARING, there
|
||||
# is never contention between runs, so we can leisurely update a bundle with
|
||||
# information we need. Because we are grouped by job_id, we can have a single
|
||||
# consolidated bundle for everything (or not; maybe worry about O(n^2) IO if
|
||||
# we updated every compile--let's just instrument this.) Can even take a
|
||||
# filelock for extra safety (expect no contention); expect 50ns overhead from
|
||||
# uncontended filelock.
|
||||
#
|
||||
# If we did share ranks, everyone is storming to modify the same cache files.
|
||||
# We can do this by having folks atomic write to a CAS-store and then having
|
||||
# readers do on-the-fly merging (this can be implemented in remote using
|
||||
# prefix iteration). As an optional optimization, one rank can be elected to
|
||||
# handling bundling post facto (ideally, this is done async, after quiescence,
|
||||
# without compiler collective need to wait for everyone to finish writing
|
||||
# their bits.) Not sure how you can avoid a listdir because if some rank shows
|
||||
# up with some new entries we need to pull them in ASAP (unless you want to
|
||||
# delay bundling).
|
||||
#
|
||||
# But compiler collectives fill a similar niche: compilers chat with each
|
||||
# other so rank 0 has collected everything. So elect rank 0 only to write the
|
||||
# bundle. Don't even need CAS-store atomic write; just one rank writing an
|
||||
# updating bundles. The point is that use compiler collectives to share
|
||||
# profiles across ranks, but use the PGO cache to persist profiles per rank
|
||||
# across attempts. No need to have one mechanism to do everything.
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CodeId:
|
||||
filename: str
|
||||
|
|
@ -47,7 +107,8 @@ class CodeState:
|
|||
)
|
||||
|
||||
|
||||
CODE_STATE: DefaultDict[CodeId, CodeState] = defaultdict(CodeState)
|
||||
_INIT_CODE_STATE: Optional[DefaultDict[CodeId, CodeState]] = None
|
||||
_CODE_STATE: Optional[DefaultDict[CodeId, CodeState]] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
|
|
@ -195,7 +256,7 @@ def update_automatic_dynamic(
|
|||
is_unspecialized_nn_module: bool = False,
|
||||
) -> FrameStateSizeEntry:
|
||||
code_id = CodeId.make(tx.f_code)
|
||||
frame_state = CODE_STATE[code_id]
|
||||
frame_state = get_code_state()[code_id]
|
||||
is_update = name in frame_state.automatic_dynamic
|
||||
mut_entry = frame_state.automatic_dynamic[name]
|
||||
old_entry = copy.copy(mut_entry)
|
||||
|
|
@ -332,3 +393,284 @@ def process_automatic_dynamic(
|
|||
)
|
||||
assert res is not None
|
||||
return res
|
||||
|
||||
|
||||
def get_cache_key() -> Optional[str]:
|
||||
# TODO: info versions of these logs that log only once
|
||||
if torch._inductor.config.force_disable_caches:
|
||||
warn_once(
|
||||
"dynamo_pgo force disabled by torch._inductor.config.force_disable_caches"
|
||||
)
|
||||
return None
|
||||
|
||||
# NB: We always use global rank for keys, even though they are overkill
|
||||
# for local only cache
|
||||
rank = None
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
rank = dist.get_rank()
|
||||
|
||||
# NB: We namespace the cache keys so that only user-specified job id
|
||||
# can alias with each other.
|
||||
if (r := torch.compiler.config.job_id) is not None:
|
||||
if r.startswith("mast:"):
|
||||
raise ReservedWorkflowIdUserError(
|
||||
"torch.compiler.config.job_id with prefix 'mast:' is reserved for "
|
||||
"automatically generated job id associated with a specific MAST job "
|
||||
"name and version."
|
||||
)
|
||||
return f"{r}:{rank}"
|
||||
|
||||
if (name_version := torch._utils_internal.get_mast_job_name_version()) is not None:
|
||||
mast_job_name, mast_job_version = name_version
|
||||
return f"mast:{mast_job_name}:{mast_job_version}:{rank}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# This solely controls local PGO
|
||||
def code_state_path(cache_key: str) -> Optional[str]:
|
||||
if not torch._dynamo.config.automatic_dynamic_local_pgo:
|
||||
log.debug("automatic_dynamic_local_pgo not enabled")
|
||||
return None
|
||||
|
||||
from torch._inductor.runtime.runtime_utils import cache_dir
|
||||
|
||||
return os.path.join(cache_dir(), "dynamo", f"code_state_{cache_key}.pkl")
|
||||
|
||||
|
||||
def should_use_remote_dynamo_pgo_cache() -> bool:
|
||||
if torch._inductor.config.force_disable_caches:
|
||||
return False
|
||||
|
||||
if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None:
|
||||
return r
|
||||
|
||||
if not is_fbcode():
|
||||
return False
|
||||
|
||||
if torch._utils_internal.is_fb_unit_test():
|
||||
return False
|
||||
|
||||
try:
|
||||
from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
|
||||
return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
|
||||
"pytorch/remote_cache:dynamo_pgo_version"
|
||||
)
|
||||
|
||||
|
||||
def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
|
||||
if not should_use_remote_dynamo_pgo_cache():
|
||||
return None
|
||||
|
||||
return create_cache(
|
||||
"dynamo-pgo",
|
||||
is_fbcode(),
|
||||
"FbRemoteDynamoPGOCache",
|
||||
"RemoteDynamoPGOCache",
|
||||
)
|
||||
|
||||
|
||||
# TODO: this dump format sucks but apparently it's very difficult to json.dumps
|
||||
# while not indenting inner lists SIGH
|
||||
|
||||
|
||||
def _key_asdict(x: object) -> object:
|
||||
if isinstance(x, CodeId):
|
||||
return f"{x.filename}:{x.firstlineno}:{x.name}"
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def _asdict(x: object) -> object:
|
||||
if isinstance(x, (dict, defaultdict)):
|
||||
return {_key_asdict(k): _asdict(v) for k, v in x.items()}
|
||||
elif isinstance(x, (list, tuple)):
|
||||
return [_asdict(v) for v in x]
|
||||
elif dataclasses.is_dataclass(x):
|
||||
return {
|
||||
field.name: _asdict(getattr(x, field.name))
|
||||
for field in dataclasses.fields(x)
|
||||
}
|
||||
elif x is auto_unset:
|
||||
return "auto_unset"
|
||||
elif x is auto_dynamic:
|
||||
return "auto_dynamic"
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def get_code_state() -> DefaultDict[CodeId, CodeState]:
|
||||
global _CODE_STATE, _INIT_CODE_STATE
|
||||
if _CODE_STATE is not None:
|
||||
return _CODE_STATE
|
||||
|
||||
chromium_log = get_chromium_event_logger()
|
||||
|
||||
# Initialize it (even if we don't look up profile)
|
||||
_CODE_STATE = defaultdict(CodeState)
|
||||
|
||||
cache_key = get_cache_key()
|
||||
if cache_key is None:
|
||||
return _CODE_STATE
|
||||
|
||||
def hit(ty: str) -> DefaultDict[CodeId, CodeState]:
|
||||
global _INIT_CODE_STATE
|
||||
assert isinstance(_CODE_STATE, defaultdict)
|
||||
log.info("get_code_state %s hit %s, %d entries", path, ty, len(_CODE_STATE))
|
||||
trace_structured_artifact(
|
||||
f"get_{ty}_code_state",
|
||||
"string",
|
||||
lambda: json.dumps(_asdict(_CODE_STATE), indent=1),
|
||||
)
|
||||
_INIT_CODE_STATE = copy.deepcopy(_CODE_STATE)
|
||||
return _CODE_STATE
|
||||
|
||||
# Attempt local
|
||||
path = code_state_path(cache_key)
|
||||
if path is not None and os.path.exists(path):
|
||||
with dynamo_timed(
|
||||
name := "pgo.get_local_code_state", log_pt2_compile_event=True
|
||||
):
|
||||
chromium_log.add_event_data(name, cache_key=cache_key)
|
||||
# Read lock not necessary as we always write atomically write to
|
||||
# the actual location
|
||||
with open(path, "rb") as f:
|
||||
try:
|
||||
_CODE_STATE = pickle.load(f)
|
||||
except Exception:
|
||||
log.warning(
|
||||
"get_code_state failed while reading %s", path, exc_info=True
|
||||
)
|
||||
else:
|
||||
return hit("local")
|
||||
|
||||
# Attempt remote
|
||||
remote_cache = get_remote_cache()
|
||||
if remote_cache is not None:
|
||||
with dynamo_timed(
|
||||
name := "pgo.get_remote_code_state", log_pt2_compile_event=True
|
||||
):
|
||||
chromium_log.add_event_data(name, cache_key=cache_key)
|
||||
# TODO: I don't really understand why there's a JSON container format
|
||||
try:
|
||||
cache_data = remote_cache.get(cache_key)
|
||||
except Exception:
|
||||
log.warning(
|
||||
"get_code_state failed remote read on %s", cache_key, exc_info=True
|
||||
)
|
||||
else:
|
||||
if cache_data is not None:
|
||||
try:
|
||||
assert isinstance(cache_data, dict)
|
||||
data = cache_data["data"]
|
||||
assert isinstance(data, str)
|
||||
payload = base64.b64decode(data)
|
||||
_CODE_STATE = pickle.loads(payload)
|
||||
except Exception:
|
||||
log.warning(
|
||||
"get_code_state failed parsing remote result on %s",
|
||||
cache_key,
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
return hit("remote")
|
||||
else:
|
||||
log.info("get_code_state remote miss on %s", cache_key)
|
||||
|
||||
log.info("get_code_state using default")
|
||||
|
||||
assert _CODE_STATE is not None
|
||||
return _CODE_STATE
|
||||
|
||||
|
||||
def put_code_state() -> None:
|
||||
if _CODE_STATE is None:
|
||||
log.info("put_code_state: never initialized, will not write")
|
||||
return
|
||||
|
||||
if _CODE_STATE == _INIT_CODE_STATE:
|
||||
log.info("put_code_state: no change, skipping")
|
||||
return
|
||||
|
||||
cache_key = get_cache_key()
|
||||
if cache_key is None:
|
||||
log.info("put_code_state: no cache key, skipping")
|
||||
return
|
||||
|
||||
put_local_code_state(cache_key)
|
||||
put_remote_code_state(cache_key)
|
||||
|
||||
|
||||
def put_local_code_state(cache_key: str) -> None:
|
||||
with dynamo_timed(name := "pgo.put_local_code_state", log_pt2_compile_event=True):
|
||||
chromium_log = get_chromium_event_logger()
|
||||
chromium_log.add_event_data(name, cache_key=cache_key)
|
||||
assert _CODE_STATE is not None
|
||||
|
||||
path = code_state_path(cache_key)
|
||||
|
||||
if path is None:
|
||||
log.info("put_code_state: local cache disabled")
|
||||
return
|
||||
|
||||
# If the user isn't misusing our API, we should have exclusive access to
|
||||
# this directory. But it's not too hard
|
||||
|
||||
tmp_path = path + ".tmp"
|
||||
lock_path = path + ".lock"
|
||||
# We /mostly/ don't need the lock but the tmp file could be clobbered
|
||||
# TODO: use a safe tempfile create to eliminate lock
|
||||
from filelock import FileLock
|
||||
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
|
||||
with FileLock(lock_path, timeout=LOCK_TIMEOUT):
|
||||
with open(tmp_path, "wb") as f:
|
||||
pickle.dump(_CODE_STATE, f)
|
||||
os.rename(tmp_path, path)
|
||||
log.info(
|
||||
"put_code_state: wrote local %s, %d entries", path, len(_CODE_STATE)
|
||||
)
|
||||
trace_structured_artifact(
|
||||
"put_local_code_state",
|
||||
"string",
|
||||
lambda: json.dumps(_asdict(_CODE_STATE), indent=1),
|
||||
)
|
||||
|
||||
|
||||
def put_remote_code_state(cache_key: str) -> None:
|
||||
with dynamo_timed(name := "pgo.put_remote_code_state", log_pt2_compile_event=True):
|
||||
chromium_log = get_chromium_event_logger()
|
||||
chromium_log.add_event_data(name, cache_key=cache_key)
|
||||
assert _CODE_STATE is not None
|
||||
|
||||
remote_cache = get_remote_cache()
|
||||
|
||||
if remote_cache is None:
|
||||
log.info("put_code_state: remote cache disabled")
|
||||
return
|
||||
|
||||
content = pickle.dumps(_CODE_STATE)
|
||||
cache_data: JsonDataTy = {
|
||||
"data": base64.b64encode(content).decode("ascii"),
|
||||
}
|
||||
remote_cache.put(cache_key, cache_data)
|
||||
log.info(
|
||||
"put_code_state: wrote remote %s, %d entries", cache_key, len(_CODE_STATE)
|
||||
)
|
||||
# TODO: don't log this multiple times
|
||||
trace_structured_artifact(
|
||||
"put_remote_code_state",
|
||||
"string",
|
||||
lambda: json.dumps(_asdict(_CODE_STATE), indent=1),
|
||||
)
|
||||
|
||||
|
||||
# NB: this does NOT reset the cached code state on disk
|
||||
def reset_code_state() -> None:
|
||||
global _CODE_STATE, _INIT_CODE_STATE
|
||||
_CODE_STATE = None
|
||||
_INIT_CODE_STATE = None
|
||||
|
|
|
|||
|
|
@ -5,31 +5,23 @@ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
|
|||
import torch
|
||||
import torch._inductor.custom_graph_pass
|
||||
from torch._environment import is_fbcode
|
||||
|
||||
|
||||
def _get_tristate_env(name: str) -> Optional[bool]:
|
||||
value = os.environ.get(name)
|
||||
if value == "1":
|
||||
return True
|
||||
if value == "0":
|
||||
return False
|
||||
return None
|
||||
from torch.utils._config_module import get_tristate_env, install_config_module
|
||||
|
||||
|
||||
def fx_graph_remote_cache_default() -> Optional[bool]:
|
||||
return _get_tristate_env("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE")
|
||||
return get_tristate_env("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE")
|
||||
|
||||
|
||||
def autotune_remote_cache_default() -> Optional[bool]:
|
||||
return _get_tristate_env("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE")
|
||||
return get_tristate_env("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE")
|
||||
|
||||
|
||||
def bundled_autotune_remote_cache_default() -> Optional[bool]:
|
||||
return _get_tristate_env("TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE")
|
||||
return get_tristate_env("TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE")
|
||||
|
||||
|
||||
def bundle_triton_into_fx_graph_cache_default() -> Optional[bool]:
|
||||
return _get_tristate_env("TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE")
|
||||
return get_tristate_env("TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE")
|
||||
|
||||
|
||||
# Enable auto_functionalized_v2 (enabled by default)
|
||||
|
|
@ -1340,8 +1332,6 @@ class test_configs:
|
|||
if TYPE_CHECKING:
|
||||
from torch.utils._config_typing import * # noqa: F401, F403
|
||||
|
||||
from torch.utils._config_module import install_config_module
|
||||
|
||||
|
||||
# adds patch, save_config, etc
|
||||
install_config_module(sys.modules[__name__])
|
||||
|
|
|
|||
|
|
@ -306,6 +306,10 @@ class RemoteAOTAutogradCache(RedisRemoteCache):
|
|||
pass
|
||||
|
||||
|
||||
class RemoteDynamoPGOCache(RedisRemoteCache):
|
||||
pass
|
||||
|
||||
|
||||
def create_cache(
|
||||
key: str,
|
||||
is_fbcode: bool,
|
||||
|
|
|
|||
|
|
@ -1130,6 +1130,21 @@ def get_structured_logging_overhead() -> Optional[float]:
|
|||
return None
|
||||
|
||||
|
||||
def trace_structured_artifact(
|
||||
name: str, # this will go in metadata
|
||||
encoding: str,
|
||||
payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
|
||||
) -> None:
|
||||
trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
"name": name,
|
||||
"encoding": encoding,
|
||||
},
|
||||
payload_fn=payload_fn,
|
||||
)
|
||||
|
||||
|
||||
def trace_structured(
|
||||
name: str,
|
||||
# NB: metadata expected to be dict so adding more info is forward compatible
|
||||
|
|
@ -1140,7 +1155,7 @@ def trace_structured(
|
|||
suppress_context: bool = False,
|
||||
expect_trace_id: bool = True, # Whether or not we expect to have a current trace id
|
||||
record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
metadata is an arbitrary JSON compatible struct, but it's expected to not be
|
||||
too long (e.g., less than 1MB)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import logging
|
|||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
|
|
@ -344,6 +344,10 @@ def max_clock_rate():
|
|||
return 1100
|
||||
|
||||
|
||||
def get_mast_job_name_version() -> Optional[Tuple[str, int]]:
|
||||
return None
|
||||
|
||||
|
||||
TEST_MASTER_ADDR = "127.0.0.1"
|
||||
TEST_MASTER_PORT = 29500
|
||||
# USE_GLOBAL_DEPS controls whether __init__.py tries to load
|
||||
|
|
|
|||
62
torch/compiler/config.py
Normal file
62
torch/compiler/config.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""
|
||||
This is the top-level configuration module for the compiler, containing
|
||||
cross-cutting configuration options that affect all parts of the compiler
|
||||
stack.
|
||||
|
||||
You may also be interested in the per-component configuration modules, which
|
||||
contain configuration options that affect only a specific part of the compiler:
|
||||
|
||||
* :mod:`torch._dynamo.config`
|
||||
* :mod:`torch._inductor.config`
|
||||
* :mod:`torch._functorch.config`
|
||||
* :mod:`torch.fx.experimental.config`
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
|
||||
__all__ = [
|
||||
"job_id",
|
||||
]
|
||||
|
||||
|
||||
# NB: Docblocks go UNDER variable definitions! Use spacing to make the
|
||||
# grouping clear.
|
||||
|
||||
# FB-internal note: you do NOT have to specify this explicitly specify this if
|
||||
# you run on MAST, we will automatically default this to
|
||||
# mast:MAST_JOB_NAME:MAST_JOB_VERSION.
|
||||
job_id: Optional[str] = os.environ.get("TORCH_COMPILE_JOB_ID", None)
|
||||
"""
|
||||
Semantically, this should be an identifier that uniquely identifies, e.g., a
|
||||
training job. You might have multiple attempts of the same job, e.g., if it was
|
||||
preempted or needed to be restarted, but each attempt should be running
|
||||
substantially the same workload with the same distributed topology. You can
|
||||
set this by environment variable with :envvar:`TORCH_COMPILE_JOB_ID`.
|
||||
|
||||
Operationally, this controls the effect of profile-guided optimization related
|
||||
persistent state. PGO state can affect how we perform compilation across
|
||||
multiple invocations of PyTorch, e.g., the first time you run your program we
|
||||
may compile twice as we discover what inputs are dynamic, and then PGO will
|
||||
save this state so subsequent invocations only need to compile once, because
|
||||
they remember it is dynamic. This profile information, however, is sensitive
|
||||
to what workload you are running, so we require you to tell us that two jobs
|
||||
are *related* (i.e., are the same workload) before we are willing to reuse
|
||||
this information. Notably, PGO does nothing (even if explicitly enabled)
|
||||
unless a valid ``job_id`` is available. In some situations, PyTorch can
|
||||
configured to automatically compute a ``job_id`` based on the environment it
|
||||
is running in.
|
||||
|
||||
Profiles are always collected on a per rank basis, so different ranks may have
|
||||
different profiles. If you know your workload is truly SPMD, you can run with
|
||||
:data:`torch._dynamo.config.enable_compiler_collectives` to ensure nodes get
|
||||
consistent profiles across all ranks.
|
||||
"""
|
||||
|
||||
|
||||
from torch.utils._config_module import install_config_module
|
||||
|
||||
|
||||
install_config_module(sys.modules[__name__])
|
||||
|
|
@ -3,6 +3,7 @@ import copy
|
|||
import hashlib
|
||||
import inspect
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
import tokenize
|
||||
import unittest
|
||||
|
|
@ -60,7 +61,8 @@ def install_config_module(module: ModuleType) -> None:
|
|||
"""
|
||||
|
||||
class ConfigModuleInstance(ConfigModule):
|
||||
_bypass_keys = set({"_is_dirty", "_hash_digest"})
|
||||
# __annotations__ is written to by Sphinx autodoc
|
||||
_bypass_keys = set({"_is_dirty", "_hash_digest", "__annotations__"})
|
||||
|
||||
def visit(
|
||||
source: Union[ModuleType, type],
|
||||
|
|
@ -496,3 +498,12 @@ def patch_object(obj: object, name: str, value: object) -> object:
|
|||
if isinstance(obj, ConfigModule):
|
||||
return obj.patch(name, value)
|
||||
return mock.patch.object(obj, name, value)
|
||||
|
||||
|
||||
def get_tristate_env(name: str) -> Optional[bool]:
|
||||
value = os.environ.get(name)
|
||||
if value == "1":
|
||||
return True
|
||||
if value == "0":
|
||||
return False
|
||||
return None
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user