diff --git a/docs/source/torch.compiler.config.rst b/docs/source/torch.compiler.config.rst new file mode 100644 index 00000000000..c40b41fdb5d --- /dev/null +++ b/docs/source/torch.compiler.config.rst @@ -0,0 +1,9 @@ +.. currentmodule:: torch.compiler.config + + +torch.compiler.config +===================== + +.. automodule:: torch.compiler.config + +.. autodata:: torch.compiler.config.job_id diff --git a/docs/source/torch.compiler.rst b/docs/source/torch.compiler.rst index c2c457c0b07..7f5e854f0a6 100644 --- a/docs/source/torch.compiler.rst +++ b/docs/source/torch.compiler.rst @@ -85,6 +85,7 @@ Read More torch.compiler_get_started torch.compiler_api + torch.compiler.config torch.compiler_fine_grain_apis torch.compiler_aot_inductor torch.compiler_inductor_profiling diff --git a/test/dynamo/test_pgo.py b/test/dynamo/test_pgo.py new file mode 100644 index 00000000000..35069078253 --- /dev/null +++ b/test/dynamo/test_pgo.py @@ -0,0 +1,126 @@ +# Owner(s): ["module: dynamo"] + +import contextlib +import importlib +import os +import sys + +import torch._dynamo.config +import torch._dynamo.test_case +import torch.compiler.config +from torch._dynamo.testing import CompileCounter +from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache + + +# LOL. https://github.com/pytorch/pytorch/issues/139252 +spec = importlib.util.spec_from_file_location( + "mock_cache", os.path.join(os.path.dirname(__file__), "../inductor/mock_cache.py") +) +mock_cache = importlib.util.module_from_spec(spec) +sys.modules["mock_cache"] = mock_cache +spec.loader.exec_module(mock_cache) + + +class PgoTest(torch._dynamo.test_case.TestCase): + def setUp(self): + super().setUp() + self._test_stack = contextlib.ExitStack() + self._test_stack.enter_context(torch.compiler.config.patch(job_id=self.id())) + self._test_stack.enter_context( + torch._dynamo.config.patch(automatic_dynamic_local_pgo=True) + ) + if os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1": + self._test_stack.enter_context(fresh_inductor_cache()) + mock_cache.PatchCaches.setUp() + + def tearDown(self): + super().tearDown() + torch._dynamo.reset() + self._test_stack.close() + mock_cache.PatchCaches.tearDown() + + def reset(self): + torch._dynamo.reset() + clear_inductor_caches() + + def test_basic(self): + cnts = CompileCounter() + + @torch.compile(backend=cnts, fullgraph=True) + def f(x): + return x * 2 + + f(torch.randn(2, 3)) + f(torch.randn(2, 4)) + self.assertEqual(cnts.frame_count, 2) + + self.reset() + cnts.clear() + + f(torch.randn(2, 5)) + f(torch.randn(2, 6)) + self.assertEqual(cnts.frame_count, 1) + + def test_distinct_compile_id(self): + cnts = CompileCounter() + + @torch.compile(backend=cnts, fullgraph=True) + def f(x): + return x * 2 + + with torch.compiler.config.patch(job_id="foo"): + f(torch.randn(2, 3)) + f(torch.randn(2, 4)) + self.assertEqual(cnts.frame_count, 2) + + self.reset() + cnts.clear() + + with torch.compiler.config.patch(job_id="bar"): + f(torch.randn(2, 5)) + f(torch.randn(2, 6)) + self.assertEqual(cnts.frame_count, 2) + + torch._dynamo.reset() + clear_inductor_caches() + cnts.clear() + + with torch.compiler.config.patch(job_id="foo"): + f(torch.randn(2, 7)) + f(torch.randn(2, 8)) + self.assertEqual(cnts.frame_count, 1) + + # TODO: to test local need to ensure the local filesystem gets cleared out + @torch._dynamo.config.patch( + automatic_dynamic_remote_pgo=True, automatic_dynamic_local_pgo=False + ) + def test_remote_basic(self): + cnts = CompileCounter() + + @torch.compile(backend=cnts, fullgraph=True) + 
def f(x): + return x * 2 + + with mock_cache.PatchCaches(): + f(torch.randn(2, 3)) + f(torch.randn(2, 4)) + self.assertEqual(cnts.frame_count, 2) + self.assertEqual( + mock_cache.global_stats.dynamo_pgo, mock_cache.Stats(2, 0, 1) + ) + + self.reset() + cnts.clear() + + f(torch.randn(2, 5)) + f(torch.randn(2, 6)) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual( + mock_cache.global_stats.dynamo_pgo, mock_cache.Stats(2, 1, 1) + ) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/inductor/mock_cache.py b/test/inductor/mock_cache.py index 9c6d0ad7736..e0aa6c61a3a 100644 --- a/test/inductor/mock_cache.py +++ b/test/inductor/mock_cache.py @@ -80,6 +80,7 @@ class _GlobalStats(threading.local): self.fx_graph = _GlobalItemStats() self.triton = _GlobalItemStats() self.aot_autograd = _GlobalItemStats() + self.dynamo_pgo = _GlobalItemStats() def reset(self) -> None: self.autotune_local.reset() @@ -88,6 +89,7 @@ class _GlobalStats(threading.local): self.fx_graph.reset() self.triton.reset() self.aot_autograd.reset() + self.dynamo_pgo.reset() def get_stat(self, name: str) -> _GlobalItemStats: return getattr(self, name) @@ -100,6 +102,7 @@ class _GlobalStats(threading.local): ("fx_graph", self.fx_graph), ("triton", self.triton), ("aot_autograd", self.aot_autograd), + ("dynamo_pgo", self.dynamo_pgo), ) print("Cache Stats:", file=sys.stderr) @@ -215,6 +218,12 @@ class PatchCaches(contextlib.AbstractContextManager): ) self._stack.enter_context(ctx) + ctx = patch( + "torch._inductor.remote_cache.RemoteDynamoPGOCache.backend_override_cls", + MockBackend.with_name("dynamo_pgo"), + ) + self._stack.enter_context(ctx) + if config.is_fbcode(): ctx = patch( "torch._inductor.fb.remote_cache.FbRemoteAutotuneCache.backend_override_cls", @@ -246,6 +255,12 @@ class PatchCaches(contextlib.AbstractContextManager): ) self._stack.enter_context(ctx) + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteDynamoPGOCache.backend_override_cls", + MockBackend.with_name("dynamo_pgo"), + ) + self._stack.enter_context(ctx) + return self def __exit__( diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index c3197d75109..986e5dd0900 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -33,7 +33,7 @@ from .eval_frame import ( ) from .external_utils import is_compiling from .mutation_guard import GenerationTracker -from .pgo import CODE_STATE +from .pgo import reset_code_state from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count @@ -79,11 +79,24 @@ if torch.manual_seed is torch.random.manual_seed: def reset() -> None: - """Clear all compile caches and restore initial state""" + """ + Clear all compile caches and restore initial state. This function is intended + to reset Dynamo's state *as if* you had started a fresh process invocation, which + makes it good for testing scenarios where you want to behave as if you started + a new process. It does NOT affect any file system caches. + + NB: this does NOT reset logging state. Don't use this to test logging + initialization/reinitialization. 
+ """ + # TODO: https://github.com/pytorch/pytorch/issues/139200 + import logging + + log = logging.getLogger(__name__) + log.info("torch._dynamo.reset") with convert_frame.compile_lock: reset_code_caches() convert_frame.input_codes.clear() - CODE_STATE.clear() + reset_code_state() convert_frame.output_codes.clear() orig_code_map.clear() guard_failures.clear() @@ -102,9 +115,19 @@ def reset() -> None: def reset_code_caches() -> None: + """ + Clears in-memory code cache, which is what stores compiled products. This + resets less state than :func:`reset` and is mostly only used for testing + purposes. + """ + # TODO: https://github.com/pytorch/pytorch/issues/139200 + import logging + + log = logging.getLogger(__name__) + log.info("torch._dynamo.reset_code_caches") """Clear compile caches that are keyed by code objects""" with convert_frame.compile_lock: - CODE_STATE.clear() + reset_code_state() for weak_code in ( convert_frame.input_codes.seen + convert_frame.output_codes.seen ): diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 8b2759d181c..5c36654ae5d 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Unio import torch from torch._environment import is_fbcode +from torch.utils._config_module import get_tristate_env, install_config_module # to configure logging for dynamo, aot, and inductor @@ -485,6 +486,36 @@ compiled_autograd_kwargs_override: Dict[str, Any] = {} # NCCL timeout. enable_compiler_collectives = os.environ.get("TORCH_COMPILER_COLLECTIVES", "0") == "1" +# Enables a local, filesystem "profile" which can be used for automatic +# dynamic decisions, analogous to profile-guided optimization. This config +# ONLY has an effect if torch.compiler.config.job_id is specified, +# which determines the name of the profile we will save/load. +# +# The idea is that if we observe that a particular input is dynamic over +# multiple iterations on one run, we can save a profile with this information +# so the next time we run we can just make it dynamic the first time around, +# skipping an unnecessary static compilation. It is sound to apply a +# stale profile: if it is wrong, it just means we may make more things dynamic than +# was actually necessary (NB: this /can/ cause a failure if making something +# dynamic causes the compiler to stop working because you tickled a latent +# bug.) +# +# The profile is ONLY guaranteed to work if the user source code is 100% +# unchanged. If the user code has changed, applying the profile is +# best effort only. In particular, we identify particular code objects +# by filename, first line number and function name, so adding/removing newlines +# will typically cause cache misses. We continuously update the profile, +# so if we only discover something is dynamic on the second run, we will update +# the profile for subsequent runs. +automatic_dynamic_local_pgo: bool = ( + os.environ.get("TORCH_DYNAMO_AUTOMATIC_DYNAMIC_LOCAL_PGO", "0") == "1" +) + +# Like above, but using remote cache +automatic_dynamic_remote_pgo: Optional[bool] = get_tristate_env( + "TORCH_DYNAMO_AUTOMATIC_DYNAMIC_REMOTE_PGO" +) + # HACK: this is for testing custom ops profiling only _custom_ops_profile: Optional[Any] = None @@ -495,7 +526,4 @@ if TYPE_CHECKING: ...
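As a rough end-to-end sketch of how these knobs are meant to be used (modeled on ``test/dynamo/test_pgo.py`` above; the ``job_id`` value and tensor shapes are arbitrary placeholders, and ``patch`` is the context manager that ``install_config_module`` attaches to config modules):

    import torch
    import torch._dynamo
    import torch._dynamo.config
    import torch.compiler.config

    @torch.compile(fullgraph=True)
    def f(x):
        return x * 2

    with torch.compiler.config.patch(job_id="my_train_job"), \
            torch._dynamo.config.patch(automatic_dynamic_local_pgo=True):
        f(torch.randn(2, 3))   # first compile: sizes treated as static
        f(torch.randn(2, 4))   # size change: recompile with a dynamic dim;
                               # the profile records this observation

        torch._dynamo.reset()  # behave as if this were a fresh process
                               # (in-memory only; the on-disk profile survives)

        f(torch.randn(2, 5))   # profile reloaded: compiles dynamically right away
        f(torch.randn(2, 6))   # no further recompile

The local profile lands at ``os.path.join(cache_dir(), "dynamo", f"code_state_{cache_key}.pkl")``; see ``code_state_path`` in ``torch/_dynamo/pgo.py`` further down in this patch.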
-from torch.utils._config_module import install_config_module - - install_config_module(sys.modules[__name__]) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 912cff7d9d1..847cbf7e725 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -93,6 +93,7 @@ from .guards import ( GuardedCode, ) from .hooks import Hooks +from .pgo import put_code_state from .replay_record import ExecutionRecord from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX from .symbolic_convert import ( @@ -1019,6 +1020,8 @@ def _compile( f"{type(e).__qualname__}: {str(e)}" ).with_traceback(e.__traceback__) from None finally: + put_code_state() + if tracer: tracer.output.local_scope = {} diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index f43e87d93f0..300f9a58976 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -1,25 +1,42 @@ from __future__ import annotations +import base64 import copy import dataclasses import enum +import json import logging +import os +import pickle import time from collections import defaultdict from typing import DefaultDict, Optional, Tuple, TYPE_CHECKING, TypeVar, Union from typing_extensions import Self -from torch._dynamo.utils import get_chromium_event_logger +import torch._dynamo.config +import torch._utils_internal +import torch.compiler.config +import torch.distributed as dist +from torch._dynamo.utils import dynamo_timed, get_chromium_event_logger, warn_once +from torch._environment import is_fbcode +from torch._logging._internal import trace_structured_artifact if TYPE_CHECKING: import types from torch._dynamo.symbolic_convert import InstructionTranslator + from torch._inductor.remote_cache import JsonDataTy, RemoteCache + + +class ReservedWorkflowIdUserError(ValueError): + pass log = logging.getLogger(__name__) +LOCK_TIMEOUT = 10 + # How does in memory representation work? Concretely, this module is # responsible for holding GLOBAL state representing the state it holds, no # other copies permitted. So we retire frame_state entirely and store it @@ -29,6 +46,48 @@ log = logging.getLogger(__name__) # don't mind leaking it. +# How exactly did we design the cache key? Here are some of the questions: +# +# - JOB_ID: Do we have a unique identifier for the "training run" (such that +# it stays the same if we're running the same code, and changes if we're +# running something different). +# +# - RANK: Are we sharing the cache across ranks, or does each rank get +# an individual cache? +# +# We choose to require job_id for PGO cache. This is to prevent +# situations where unrelated invocations of PyTorch unpredictably cause +# changes to each other's behavior. With a job_id, at least you know there +# is some "state" associated with it. (State dict might be another way to +# tell if a run is related or not.) You can opt-in to YOLO everything +# aliases everything by passing a shared job_id for all your invocations. +# +# We choose to NOT share PGO cache across ranks. With no RANK_SHARING, there +# is never contention between runs, so we can leisurely update a bundle with +# information we need. Because we are grouped by job_id, we can have a single +# consolidated bundle for everything (or not; maybe worry about O(n^2) IO if +# we updated every compile--let's just instrument this.) Can even take a +# filelock for extra safety (expect no contention); expect 50ns overhead from +# uncontended filelock. +# +# If we did share ranks, everyone is storming to modify the same cache files. 
+# We can do this by having folks atomically write to a CAS-store and then having +# readers do on-the-fly merging (this can be implemented in remote using +# prefix iteration). As an optional optimization, one rank can be elected to +# handle bundling post facto (ideally, this is done async, after quiescence, +# without a compiler collective needing to wait for everyone to finish writing +# their bits.) Not sure how you can avoid a listdir because if some rank shows +# up with some new entries we need to pull them in ASAP (unless you want to +# delay bundling). +# +# But compiler collectives fill a similar niche: compilers chat with each +# other so rank 0 has collected everything. So elect rank 0 only to write the +# bundle. Don't even need CAS-store atomic write; just one rank writing and +# updating bundles. The point is to use compiler collectives to share +# profiles across ranks, but use the PGO cache to persist profiles per rank +# across attempts. No need to have one mechanism to do everything. + + @dataclasses.dataclass(frozen=True) class CodeId: filename: str @@ -47,7 +106,8 @@ class CodeState: ) -CODE_STATE: DefaultDict[CodeId, CodeState] = defaultdict(CodeState) +_INIT_CODE_STATE: Optional[DefaultDict[CodeId, CodeState]] = None +_CODE_STATE: Optional[DefaultDict[CodeId, CodeState]] = None @@ -195,7 +255,7 @@ def update_automatic_dynamic( is_unspecialized_nn_module: bool = False, ) -> FrameStateSizeEntry: code_id = CodeId.make(tx.f_code) - frame_state = CODE_STATE[code_id] + frame_state = get_code_state()[code_id] is_update = name in frame_state.automatic_dynamic mut_entry = frame_state.automatic_dynamic[name] old_entry = copy.copy(mut_entry) @@ -332,3 +392,286 @@ def process_automatic_dynamic( ) assert res is not None return res + + +def get_cache_key() -> Optional[str]: + # TODO: info versions of these logs that log only once + if torch._inductor.config.force_disable_caches: + warn_once( + "dynamo_pgo force disabled by torch._inductor.config.force_disable_caches" + ) + return None + + # NB: We always use the global rank for keys, even though it is overkill + # for the local-only cache + rank = None + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + + # NB: We namespace the cache keys so that only user-specified job ids + # can alias with each other. + if (r := torch.compiler.config.job_id) is not None: + if r.startswith("mast:"): + raise ReservedWorkflowIdUserError( + "torch.compiler.config.job_id with prefix 'mast:' is reserved for " + "automatically generated job id associated with a specific MAST job " + "name and version."
+ ) + return f"{r}:{rank}" + + if (name_version := torch._utils_internal.get_mast_job_name_version()) is not None: + mast_job_name, mast_job_version = name_version + return f"mast:{mast_job_name}:{mast_job_version}:{rank}" + + return None + + +# This solely controls local PGO +def code_state_path(cache_key: str) -> Optional[str]: + if not torch._dynamo.config.automatic_dynamic_local_pgo: + log.debug("automatic_dynamic_local_pgo not enabled") + return None + + from torch._inductor.runtime.runtime_utils import cache_dir + + return os.path.join(cache_dir(), "dynamo", f"code_state_{cache_key}.pkl") + + +def should_use_remote_dynamo_pgo_cache() -> bool: + if torch._inductor.config.force_disable_caches: + return False + + if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None: + return r + + if not is_fbcode(): + return False + + if torch._utils_internal.is_fb_unit_test(): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:dynamo_pgo_version" + ) + + +def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: + from torch._inductor.remote_cache import create_cache + + if not should_use_remote_dynamo_pgo_cache(): + return None + + return create_cache( + "dynamo-pgo", + is_fbcode(), + "FbRemoteDynamoPGOCache", + "RemoteDynamoPGOCache", + ) + + +# TODO: this dump format sucks but apparently it's very difficult to json.dumps +# while not indenting inner lists SIGH + + +def _key_asdict(x: object) -> object: + if isinstance(x, CodeId): + return f"{x.filename}:{x.firstlineno}:{x.name}" + else: + return x + + +def _asdict(x: object) -> object: + if isinstance(x, (dict, defaultdict)): + return {_key_asdict(k): _asdict(v) for k, v in x.items()} + elif isinstance(x, (list, tuple)): + return [_asdict(v) for v in x] + elif dataclasses.is_dataclass(x): + return { + field.name: _asdict(getattr(x, field.name)) + for field in dataclasses.fields(x) + } + elif x is auto_unset: + return "auto_unset" + elif x is auto_dynamic: + return "auto_dynamic" + else: + return x + + +def get_code_state() -> DefaultDict[CodeId, CodeState]: + global _CODE_STATE, _INIT_CODE_STATE + if _CODE_STATE is not None: + return _CODE_STATE + + chromium_log = get_chromium_event_logger() + + # Initialize it (even if we don't look up profile) + _CODE_STATE = defaultdict(CodeState) + + cache_key = get_cache_key() + if cache_key is None: + return _CODE_STATE + + def hit(ty: str) -> DefaultDict[CodeId, CodeState]: + global _INIT_CODE_STATE + assert isinstance(_CODE_STATE, defaultdict) + log.info("get_code_state %s hit %s, %d entries", path, ty, len(_CODE_STATE)) + trace_structured_artifact( + f"get_{ty}_code_state", + "string", + lambda: json.dumps(_asdict(_CODE_STATE), indent=1), + ) + _INIT_CODE_STATE = copy.deepcopy(_CODE_STATE) + return _CODE_STATE + + # Attempt local + path = code_state_path(cache_key) + if path is not None and os.path.exists(path): + with dynamo_timed( + name := "pgo.get_local_code_state", log_pt2_compile_event=True + ): + chromium_log.add_event_data(name, cache_key=cache_key) + # Read lock not necessary as we always write atomically write to + # the actual location + with open(path, "rb") as f: + try: + _CODE_STATE = pickle.load(f) + except Exception: + log.warning( + "get_code_state failed while reading %s", path, exc_info=True + ) + else: + return hit("local") + + # Attempt remote + remote_cache = get_remote_cache() 
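+    # The remote entry mirrors what put_remote_code_state writes below: a
+    # JSON-compatible dict whose "data" field holds a base64-encoded pickle
+    # of the code state, roughly {"data": "<base64(pickle(_CODE_STATE))>"},
+    # which is decoded here in the reverse order.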
+ if remote_cache is not None: + with dynamo_timed( + name := "pgo.get_remote_code_state", log_pt2_compile_event=True + ): + chromium_log.add_event_data(name, cache_key=cache_key) + # TODO: I don't really understand why there's a JSON container format + try: + cache_data = remote_cache.get(cache_key) + except Exception: + log.warning( + "get_code_state failed remote read on %s", cache_key, exc_info=True + ) + else: + if cache_data is not None: + try: + assert isinstance(cache_data, dict) + data = cache_data["data"] + assert isinstance(data, str) + payload = base64.b64decode(data) + _CODE_STATE = pickle.loads(payload) + except Exception: + log.warning( + "get_code_state failed parsing remote result on %s", + cache_key, + exc_info=True, + ) + else: + return hit("remote") + else: + log.info("get_code_state remote miss on %s", cache_key) + + log.info("get_code_state using default") + + assert _CODE_STATE is not None + return _CODE_STATE + + +def put_code_state() -> None: + if _CODE_STATE is None: + log.info("put_code_state: never initialized, will not write") + return + + if _CODE_STATE == _INIT_CODE_STATE: + log.info("put_code_state: no change, skipping") + return + + cache_key = get_cache_key() + if cache_key is None: + log.info("put_code_state: no cache key, skipping") + return + + put_local_code_state(cache_key) + put_remote_code_state(cache_key) + + +def put_local_code_state(cache_key: str) -> None: + with dynamo_timed(name := "pgo.put_local_code_state", log_pt2_compile_event=True): + chromium_log = get_chromium_event_logger() + chromium_log.add_event_data(name, cache_key=cache_key) + assert _CODE_STATE is not None + + path = code_state_path(cache_key) + + if path is None: + log.info("put_code_state: local cache disabled") + return + + # If the user isn't misusing our API, we should have exclusive access to + # this directory. 
But it's not too hard + + tmp_path = path + ".tmp" + lock_path = path + ".lock" + # We /mostly/ don't need the lock but the tmp file could be clobbered + # TODO: use a safe tempfile create to eliminate lock + from filelock import FileLock + + os.makedirs(os.path.dirname(path), exist_ok=True) + + with FileLock(lock_path, timeout=LOCK_TIMEOUT): + with open(tmp_path, "wb") as f: + pickle.dump(_CODE_STATE, f) + os.rename(tmp_path, path) + log.info( + "put_code_state: wrote local %s, %d entries", path, len(_CODE_STATE) + ) + trace_structured_artifact( + "put_local_code_state", + "string", + lambda: json.dumps(_asdict(_CODE_STATE), indent=1), + ) + + +def put_remote_code_state(cache_key: str) -> None: + with dynamo_timed(name := "pgo.put_remote_code_state", log_pt2_compile_event=True): + chromium_log = get_chromium_event_logger() + chromium_log.add_event_data(name, cache_key=cache_key) + assert _CODE_STATE is not None + + remote_cache = get_remote_cache() + + if remote_cache is None: + log.info("put_code_state: remote cache disabled") + return + + content = pickle.dumps(_CODE_STATE) + cache_data: JsonDataTy = { + "data": base64.b64encode(content).decode("ascii"), + } + remote_cache.put(cache_key, cache_data) + log.info( + "put_code_state: wrote remote %s, %d entries", cache_key, len(_CODE_STATE) + ) + # TODO: don't log this multiple times + trace_structured_artifact( + "put_remote_code_state", + "string", + lambda: json.dumps(_asdict(_CODE_STATE), indent=1), + ) + + +# NB: this does NOT reset the cached code state on disk +def reset_code_state() -> None: + global _CODE_STATE, _INIT_CODE_STATE + _CODE_STATE = None + _INIT_CODE_STATE = None diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index ef12f16a1f4..e9d32251963 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -5,31 +5,23 @@ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union import torch import torch._inductor.custom_graph_pass from torch._environment import is_fbcode - - -def _get_tristate_env(name: str) -> Optional[bool]: - value = os.environ.get(name) - if value == "1": - return True - if value == "0": - return False - return None +from torch.utils._config_module import get_tristate_env, install_config_module def fx_graph_remote_cache_default() -> Optional[bool]: - return _get_tristate_env("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") + return get_tristate_env("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") def autotune_remote_cache_default() -> Optional[bool]: - return _get_tristate_env("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") + return get_tristate_env("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") def bundled_autotune_remote_cache_default() -> Optional[bool]: - return _get_tristate_env("TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE") + return get_tristate_env("TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE") def bundle_triton_into_fx_graph_cache_default() -> Optional[bool]: - return _get_tristate_env("TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE") + return get_tristate_env("TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE") # Enable auto_functionalized_v2 (enabled by default) @@ -1326,8 +1318,6 @@ class test_configs: if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 -from torch.utils._config_module import install_config_module - # adds patch, save_config, etc install_config_module(sys.modules[__name__]) diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 8f647e6827d..9886ff83b19 100644 --- a/torch/_inductor/remote_cache.py +++ 
b/torch/_inductor/remote_cache.py @@ -306,6 +306,10 @@ class RemoteAOTAutogradCache(RedisRemoteCache): pass +class RemoteDynamoPGOCache(RedisRemoteCache): + pass + + def create_cache( key: str, is_fbcode: bool, diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index fcf9caecc13..70bbb27bfa2 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1130,6 +1130,21 @@ def get_structured_logging_overhead() -> Optional[float]: return None +def trace_structured_artifact( + name: str, # this will go in metadata + encoding: str, + payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None, +) -> None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": name, + "encoding": encoding, + }, + payload_fn=payload_fn, + ) + + def trace_structured( name: str, # NB: metadata expected to be dict so adding more info is forward compatible @@ -1140,7 +1155,7 @@ def trace_structured( suppress_context: bool = False, expect_trace_id: bool = True, # Whether or not we expect to have a current trace id record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging -): +) -> None: """ metadata is an arbitrary JSON compatible struct, but it's expected to not be too long (e.g., less than 1MB) diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 81b24f16235..9a7b99000c4 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -4,7 +4,7 @@ import logging import os import sys import tempfile -from typing import Any, Callable, Dict, List, Optional, TypeVar +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar from typing_extensions import ParamSpec import torch @@ -344,6 +344,10 @@ def max_clock_rate(): return 1100 +def get_mast_job_name_version() -> Optional[Tuple[str, int]]: + return None + + TEST_MASTER_ADDR = "127.0.0.1" TEST_MASTER_PORT = 29500 # USE_GLOBAL_DEPS controls whether __init__.py tries to load diff --git a/torch/compiler/config.py b/torch/compiler/config.py new file mode 100644 index 00000000000..9485b34fac2 --- /dev/null +++ b/torch/compiler/config.py @@ -0,0 +1,62 @@ +""" +This is the top-level configuration module for the compiler, containing +cross-cutting configuration options that affect all parts of the compiler +stack. + +You may also be interested in the per-component configuration modules, which +contain configuration options that affect only a specific part of the compiler: + +* :mod:`torch._dynamo.config` +* :mod:`torch._inductor.config` +* :mod:`torch._functorch.config` +* :mod:`torch.fx.experimental.config` +""" + +import os +import sys +from typing import Optional + + +__all__ = [ + "job_id", +] + + +# NB: Docblocks go UNDER variable definitions! Use spacing to make the +# grouping clear. + +# FB-internal note: you do NOT have to specify this explicitly specify this if +# you run on MAST, we will automatically default this to +# mast:MAST_JOB_NAME:MAST_JOB_VERSION. +job_id: Optional[str] = os.environ.get("TORCH_COMPILE_JOB_ID", None) +""" +Semantically, this should be an identifier that uniquely identifies, e.g., a +training job. You might have multiple attempts of the same job, e.g., if it was +preempted or needed to be restarted, but each attempt should be running +substantially the same workload with the same distributed topology. You can +set this by environment variable with :envvar:`TORCH_COMPILE_JOB_ID`. + +Operationally, this controls the effect of profile-guided optimization related +persistent state. 
PGO state can affect how we perform compilation across +multiple invocations of PyTorch, e.g., the first time you run your program we +may compile twice as we discover what inputs are dynamic, and then PGO will +save this state so subsequent invocations only need to compile once, because +they remember it is dynamic. This profile information, however, is sensitive +to what workload you are running, so we require you to tell us that two jobs +are *related* (i.e., are the same workload) before we are willing to reuse +this information. Notably, PGO does nothing (even if explicitly enabled) +unless a valid ``job_id`` is available. In some situations, PyTorch can be +configured to automatically compute a ``job_id`` based on the environment it +is running in. + +Profiles are always collected on a per-rank basis, so different ranks may have +different profiles. If you know your workload is truly SPMD, you can run with +:data:`torch._dynamo.config.enable_compiler_collectives` to ensure all ranks get +consistent profiles. +""" + + +from torch.utils._config_module import install_config_module + + +install_config_module(sys.modules[__name__]) diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 421e7c68032..c2b33017f6b 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -3,6 +3,7 @@ import copy import hashlib import inspect import io +import os import pickle import tokenize import unittest @@ -60,7 +61,8 @@ def install_config_module(module: ModuleType) -> None: """ class ConfigModuleInstance(ConfigModule): - _bypass_keys = set({"_is_dirty", "_hash_digest"}) + # __annotations__ is written to by Sphinx autodoc + _bypass_keys = set({"_is_dirty", "_hash_digest", "__annotations__"}) def visit( source: Union[ModuleType, type], @@ -496,3 +498,12 @@ def patch_object(obj: object, name: str, value: object) -> object: if isinstance(obj, ConfigModule): return obj.patch(name, value) return mock.patch.object(obj, name, value) + + +def get_tristate_env(name: str) -> Optional[bool]: + value = os.environ.get(name) + if value == "1": + return True + if value == "0": + return False + return None
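To make the tri-state environment variable semantics concrete, here is a small usage sketch (the variable name is one wired up by this patch; the assertions simply restate ``get_tristate_env``'s logic):

    import os

    from torch.utils._config_module import get_tristate_env

    var = "TORCH_DYNAMO_AUTOMATIC_DYNAMIC_REMOTE_PGO"

    os.environ[var] = "1"
    assert get_tristate_env(var) is True    # explicitly enabled

    os.environ[var] = "0"
    assert get_tristate_env(var) is False   # explicitly disabled

    os.environ.pop(var, None)
    assert get_tristate_env(var) is None    # unset: defer to the caller's default,
                                            # e.g. the fbcode/JustKnobs check in
                                            # should_use_remote_dynamo_pgo_cache

Note that ``torch._dynamo.config.automatic_dynamic_remote_pgo`` is evaluated once at import time, so to take effect via the environment the variable should be set before ``torch`` is imported (it can also be overridden later through the config module itself).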