Profile guided optimization for automatic_dynamic (#139001)

Previously: https://github.com/pytorch/pytorch/pull/138052, but the implementation is done from scratch, so I opened a new PR.

This implements the ability to save and load profiles of automatic dynamic decisions, so that on subsequent runs we can directly make something automatically dynamic. Unlike the previous implementation, this cache is never enabled by default; instead, you have to specify a "job id" that says it's OK to share results. We will be able to automatically populate this id for internal MAST jobs, but generic OSS users will have to explicitly opt into it.
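
For a concrete sense of the opt-in flow, here is a minimal sketch modeled on the new test/dynamo/test_pgo.py further down (the job id string is a placeholder, and this assumes the configs added in this PR):

import torch
import torch._dynamo.config
import torch.compiler.config

# Opt in: name the job so profiles may be shared across runs of the same
# workload, and enable the local filesystem profile.
with torch.compiler.config.patch(job_id="my_training_job"), torch._dynamo.config.patch(
    automatic_dynamic_local_pgo=True
):

    @torch.compile(fullgraph=True)
    def f(x):
        return x * 2

    # On the first run under this job id, the varying size triggers a
    # recompile that marks the dimension dynamic; that decision is saved to
    # the profile, so a later run with the same job id compiles only once.
    f(torch.randn(2, 3))
    f(torch.randn(2, 4))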

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Differential Revision: [D65065497](https://our.internmc.facebook.com/intern/diff/D65065497)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139001
Approved by: https://github.com/oulgen
Edward Z. Yang 2024-11-01 21:06:34 -07:00 committed by PyTorch MergeBot
parent 55038aa661
commit f6be44c74e
14 changed files with 662 additions and 28 deletions

View File

@@ -0,0 +1,9 @@
.. currentmodule:: torch.compiler.config
torch.compiler.config
=====================
.. automodule:: torch.compiler.config
.. autodata:: torch.compiler.config.job_id

View File

@@ -85,6 +85,7 @@ Read More
torch.compiler_get_started
torch.compiler_api
torch.compiler.config
torch.compiler_fine_grain_apis
torch.compiler_aot_inductor
torch.compiler_inductor_profiling

test/dynamo/test_pgo.py (new file, 126 lines)
View File

@@ -0,0 +1,126 @@
# Owner(s): ["module: dynamo"]
import contextlib
import importlib
import os
import sys
import torch._dynamo.config
import torch._dynamo.test_case
import torch.compiler.config
from torch._dynamo.testing import CompileCounter
from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
# LOL. https://github.com/pytorch/pytorch/issues/139252
spec = importlib.util.spec_from_file_location(
"mock_cache", os.path.join(os.path.dirname(__file__), "../inductor/mock_cache.py")
)
mock_cache = importlib.util.module_from_spec(spec)
sys.modules["mock_cache"] = mock_cache
spec.loader.exec_module(mock_cache)
class PgoTest(torch._dynamo.test_case.TestCase):
def setUp(self):
super().setUp()
self._test_stack = contextlib.ExitStack()
self._test_stack.enter_context(torch.compiler.config.patch(job_id=self.id()))
self._test_stack.enter_context(
torch._dynamo.config.patch(automatic_dynamic_local_pgo=True)
)
if os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1":
self._test_stack.enter_context(fresh_inductor_cache())
mock_cache.PatchCaches.setUp()
def tearDown(self):
super().tearDown()
torch._dynamo.reset()
self._test_stack.close()
mock_cache.PatchCaches.tearDown()
def reset(self):
torch._dynamo.reset()
clear_inductor_caches()
def test_basic(self):
cnts = CompileCounter()
@torch.compile(backend=cnts, fullgraph=True)
def f(x):
return x * 2
f(torch.randn(2, 3))
f(torch.randn(2, 4))
self.assertEqual(cnts.frame_count, 2)
self.reset()
cnts.clear()
f(torch.randn(2, 5))
f(torch.randn(2, 6))
self.assertEqual(cnts.frame_count, 1)
def test_distinct_compile_id(self):
cnts = CompileCounter()
@torch.compile(backend=cnts, fullgraph=True)
def f(x):
return x * 2
with torch.compiler.config.patch(job_id="foo"):
f(torch.randn(2, 3))
f(torch.randn(2, 4))
self.assertEqual(cnts.frame_count, 2)
self.reset()
cnts.clear()
with torch.compiler.config.patch(job_id="bar"):
f(torch.randn(2, 5))
f(torch.randn(2, 6))
self.assertEqual(cnts.frame_count, 2)
torch._dynamo.reset()
clear_inductor_caches()
cnts.clear()
with torch.compiler.config.patch(job_id="foo"):
f(torch.randn(2, 7))
f(torch.randn(2, 8))
self.assertEqual(cnts.frame_count, 1)
# TODO: to test the local cache, we need to ensure the local filesystem gets cleared out
@torch._dynamo.config.patch(
automatic_dynamic_remote_pgo=True, automatic_dynamic_local_pgo=False
)
def test_remote_basic(self):
cnts = CompileCounter()
@torch.compile(backend=cnts, fullgraph=True)
def f(x):
return x * 2
with mock_cache.PatchCaches():
f(torch.randn(2, 3))
f(torch.randn(2, 4))
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(
mock_cache.global_stats.dynamo_pgo, mock_cache.Stats(2, 0, 1)
)
self.reset()
cnts.clear()
f(torch.randn(2, 5))
f(torch.randn(2, 6))
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(
mock_cache.global_stats.dynamo_pgo, mock_cache.Stats(2, 1, 1)
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@@ -80,6 +80,7 @@ class _GlobalStats(threading.local):
self.fx_graph = _GlobalItemStats()
self.triton = _GlobalItemStats()
self.aot_autograd = _GlobalItemStats()
self.dynamo_pgo = _GlobalItemStats()
def reset(self) -> None:
self.autotune_local.reset()
@@ -88,6 +89,7 @@ class _GlobalStats(threading.local):
self.fx_graph.reset()
self.triton.reset()
self.aot_autograd.reset()
self.dynamo_pgo.reset()
def get_stat(self, name: str) -> _GlobalItemStats:
return getattr(self, name)
@@ -100,6 +102,7 @@ class _GlobalStats(threading.local):
("fx_graph", self.fx_graph),
("triton", self.triton),
("aot_autograd", self.aot_autograd),
("dynamo_pgo", self.dynamo_pgo),
)
print("Cache Stats:", file=sys.stderr)
@@ -215,6 +218,12 @@ class PatchCaches(contextlib.AbstractContextManager):
)
self._stack.enter_context(ctx)
ctx = patch(
"torch._inductor.remote_cache.RemoteDynamoPGOCache.backend_override_cls",
MockBackend.with_name("dynamo_pgo"),
)
self._stack.enter_context(ctx)
if config.is_fbcode():
ctx = patch(
"torch._inductor.fb.remote_cache.FbRemoteAutotuneCache.backend_override_cls",
@@ -246,6 +255,12 @@ class PatchCaches(contextlib.AbstractContextManager):
)
self._stack.enter_context(ctx)
ctx = patch(
"torch._inductor.fb.remote_cache.FbRemoteDynamoPGOCache.backend_override_cls",
MockBackend.with_name("dynamo_pgo"),
)
self._stack.enter_context(ctx)
return self
def __exit__(

View File

@@ -33,7 +33,7 @@ from .eval_frame import (
)
from .external_utils import is_compiling
from .mutation_guard import GenerationTracker
from .pgo import CODE_STATE
from .pgo import reset_code_state
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
@@ -79,11 +79,24 @@ if torch.manual_seed is torch.random.manual_seed:
def reset() -> None:
"""Clear all compile caches and restore initial state"""
"""
Clear all compile caches and restore initial state. This function is intended
to reset Dynamo's state *as if* you had started a fresh process invocation, which
makes it good for testing scenarios where you want to behave as if you started
a new process. It does NOT affect any file system caches.
NB: this does NOT reset logging state. Don't use this to test logging
initialization/reinitialization.
"""
# TODO: https://github.com/pytorch/pytorch/issues/139200
import logging
log = logging.getLogger(__name__)
log.info("torch._dynamo.reset")
with convert_frame.compile_lock:
reset_code_caches()
convert_frame.input_codes.clear()
CODE_STATE.clear()
reset_code_state()
convert_frame.output_codes.clear()
orig_code_map.clear()
guard_failures.clear()
@@ -102,9 +115,19 @@ def reset() -> None:
def reset_code_caches() -> None:
"""
Clears in-memory code cache, which is what stores compiled products. This
resets less state than :func:`reset` and is mostly only used for testing
purposes.
"""
# TODO: https://github.com/pytorch/pytorch/issues/139200
import logging
log = logging.getLogger(__name__)
log.info("torch._dynamo.reset_code_caches")
"""Clear compile caches that are keyed by code objects"""
with convert_frame.compile_lock:
CODE_STATE.clear()
reset_code_state()
for weak_code in (
convert_frame.input_codes.seen + convert_frame.output_codes.seen
):

View File

@@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union
import torch
from torch._environment import is_fbcode
from torch.utils._config_module import get_tristate_env, install_config_module
# to configure logging for dynamo, aot, and inductor
@@ -485,6 +486,36 @@ compiled_autograd_kwargs_override: Dict[str, Any] = {}
# NCCL timeout.
enable_compiler_collectives = os.environ.get("TORCH_COMPILER_COLLECTIVES", "0") == "1"
# Enables a local, filesystem "profile" which can be used for automatic
# dynamic decisions, analogous to profile-guided optimization. This config
# ONLY has an effect if torch.compiler.config.job_id is specified,
# which names the profile we will save/load.
#
# The idea is that if we observe that a particular input is dynamic over
# multiple iterations on one run, we can save a profile with this information
# so the next time we run we can just make it dynamic the first time around,
# skipping an unnecessary static compilation. The profile can be stale
# without affecting soundness; if it is wrong, it just means we may make more
# things dynamic than was actually necessary (NB: this /can/ cause a failure
# if making something dynamic causes the compiler to stop working because you
# tickled a latent bug.)
#
# The profile is ONLY guaranteed to apply if the user source code is 100%
# unchanged; applying the profile when there are user code changes is best
# effort only. In particular, we identify code objects by the filename, line
# number and name of their function, so adding/removing newlines will
# typically cause cache misses. We continuously update the profile, so if we
# only discover something is dynamic on the second run, we will update the
# profile for subsequent runs.
automatic_dynamic_local_pgo: bool = (
os.environ.get("TORCH_DYNAMO_AUTOMATIC_DYNAMIC_LOCAL_PGO", "0") == "1"
)
# Like above, but using remote cache
automatic_dynamic_remote_pgo: Optional[bool] = get_tristate_env(
"TORCH_DYNAMO_AUTOMATIC_DYNAMIC_REMOTE_PGO"
)
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None
@@ -495,7 +526,4 @@ if TYPE_CHECKING:
...
from torch.utils._config_module import install_config_module
install_config_module(sys.modules[__name__])
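
For reference, the same opt-in can be expressed through the environment variables read above and in torch/compiler/config.py; a sketch (values are illustrative, and they must be set before torch is imported so the config modules pick them up):

import os

os.environ["TORCH_COMPILE_JOB_ID"] = "my_training_job"       # -> torch.compiler.config.job_id
os.environ["TORCH_DYNAMO_AUTOMATIC_DYNAMIC_LOCAL_PGO"] = "1"  # enable the local profile
# The remote flag is tristate: "1" forces the remote PGO cache on, "0" forces
# it off, and leaving it unset defers to should_use_remote_dynamo_pgo_cache().
os.environ["TORCH_DYNAMO_AUTOMATIC_DYNAMIC_REMOTE_PGO"] = "0"

import torch  # noqa: E402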

View File

@@ -93,6 +93,7 @@ from .guards import (
GuardedCode,
)
from .hooks import Hooks
from .pgo import put_code_state
from .replay_record import ExecutionRecord
from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
from .symbolic_convert import (
@@ -1019,6 +1020,8 @@ def _compile(
f"{type(e).__qualname__}: {str(e)}"
).with_traceback(e.__traceback__) from None
finally:
put_code_state()
if tracer:
tracer.output.local_scope = {}

View File

@@ -1,25 +1,42 @@
from __future__ import annotations
import base64
import copy
import dataclasses
import enum
import json
import logging
import os
import pickle
import time
from collections import defaultdict
from typing import DefaultDict, Optional, Tuple, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Self
from torch._dynamo.utils import get_chromium_event_logger
import torch._dynamo.config
import torch._utils_internal
import torch.compiler.config
import torch.distributed as dist
from torch._dynamo.utils import dynamo_timed, get_chromium_event_logger, warn_once
from torch._environment import is_fbcode
from torch._logging._internal import trace_structured_artifact
if TYPE_CHECKING:
import types
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._inductor.remote_cache import JsonDataTy, RemoteCache
class ReservedWorkflowIdUserError(ValueError):
pass
log = logging.getLogger(__name__)
LOCK_TIMEOUT = 10
# How does in memory representation work? Concretely, this module is
# responsible for holding GLOBAL state representing the state it holds, no
# other copies permitted. So we retire frame_state entirely and store it
@@ -29,6 +46,48 @@ log = logging.getLogger(__name__)
# don't mind leaking it.
# How exactly did we design the cache key? Here are some of the questions:
#
# - JOB_ID: Do we have a unique identifier for the "training run" (such that
# it stays the same if we're running the same code, and changes if we're
# running something different)?
#
# - RANK: Are we sharing the cache across ranks, or does each rank get
# an individual cache?
#
# We choose to require job_id for the PGO cache. This is to prevent
# situations where unrelated invocations of PyTorch unpredictably cause
# changes to each other's behavior. With a job_id, at least you know there
# is some "state" associated with it. (State dict might be another way to
# tell if a run is related or not.) You can opt in to YOLO mode, where
# everything aliases everything, by passing a shared job_id for all your
# invocations.
#
# We choose to NOT share the PGO cache across ranks. With no RANK_SHARING,
# there is never contention between runs, so we can leisurely update a bundle
# with the information we need. Because we are grouped by job_id, we can have
# a single consolidated bundle for everything (or not; maybe worry about
# O(n^2) IO if we updated on every compile; let's just instrument this.) We
# can even take a filelock for extra safety (we expect no contention); expect
# 50ns overhead from an uncontended filelock.
#
# If we did share across ranks, everyone would be storming to modify the same
# cache files. We could handle this by having each rank atomically write to a
# CAS-store and then having readers do on-the-fly merging (this can be
# implemented in the remote cache using prefix iteration). As an optional
# optimization, one rank can be elected to handle bundling post facto
# (ideally, this is done async, after quiescence, without a compiler
# collective needing to wait for everyone to finish writing their bits.) It's
# not clear how you can avoid a listdir, because if some rank shows up with
# new entries we need to pull them in ASAP (unless you want to delay
# bundling).
#
# But compiler collectives fill a similar niche: compilers chat with each
# other so rank 0 has collected everything. So elect rank 0 as the only rank
# that writes the bundle. We don't even need CAS-store atomic writes; just
# one rank writing and updating bundles. The point is: use compiler
# collectives to share profiles across ranks, but use the PGO cache to
# persist profiles per rank across attempts. No need for one mechanism to do
# everything.
@dataclasses.dataclass(frozen=True)
class CodeId:
filename: str
@@ -47,7 +106,8 @@ class CodeState:
)
CODE_STATE: DefaultDict[CodeId, CodeState] = defaultdict(CodeState)
_INIT_CODE_STATE: Optional[DefaultDict[CodeId, CodeState]] = None
_CODE_STATE: Optional[DefaultDict[CodeId, CodeState]] = None
@dataclasses.dataclass(frozen=True)
@@ -195,7 +255,7 @@ def update_automatic_dynamic(
is_unspecialized_nn_module: bool = False,
) -> FrameStateSizeEntry:
code_id = CodeId.make(tx.f_code)
frame_state = CODE_STATE[code_id]
frame_state = get_code_state()[code_id]
is_update = name in frame_state.automatic_dynamic
mut_entry = frame_state.automatic_dynamic[name]
old_entry = copy.copy(mut_entry)
@@ -332,3 +392,286 @@ def process_automatic_dynamic(
)
assert res is not None
return res
def get_cache_key() -> Optional[str]:
# TODO: info versions of these logs that log only once
if torch._inductor.config.force_disable_caches:
warn_once(
"dynamo_pgo force disabled by torch._inductor.config.force_disable_caches"
)
return None
# NB: We always use the global rank for keys, even though it is overkill
# for the local-only cache
rank = None
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
# NB: We namespace the cache keys so that only user-specified job ids
# can alias with each other.
if (r := torch.compiler.config.job_id) is not None:
if r.startswith("mast:"):
raise ReservedWorkflowIdUserError(
"torch.compiler.config.job_id with prefix 'mast:' is reserved for "
"automatically generated job id associated with a specific MAST job "
"name and version."
)
return f"{r}:{rank}"
if (name_version := torch._utils_internal.get_mast_job_name_version()) is not None:
mast_job_name, mast_job_version = name_version
return f"mast:{mast_job_name}:{mast_job_version}:{rank}"
return None
# This solely controls local PGO
def code_state_path(cache_key: str) -> Optional[str]:
if not torch._dynamo.config.automatic_dynamic_local_pgo:
log.debug("automatic_dynamic_local_pgo not enabled")
return None
from torch._inductor.runtime.runtime_utils import cache_dir
return os.path.join(cache_dir(), "dynamo", f"code_state_{cache_key}.pkl")
def should_use_remote_dynamo_pgo_cache() -> bool:
if torch._inductor.config.force_disable_caches:
return False
if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None:
return r
if not is_fbcode():
return False
if torch._utils_internal.is_fb_unit_test():
return False
try:
from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
except ModuleNotFoundError:
return False
return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
"pytorch/remote_cache:dynamo_pgo_version"
)
def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
from torch._inductor.remote_cache import create_cache
if not should_use_remote_dynamo_pgo_cache():
return None
return create_cache(
"dynamo-pgo",
is_fbcode(),
"FbRemoteDynamoPGOCache",
"RemoteDynamoPGOCache",
)
# TODO: this dump format sucks but apparently it's very difficult to json.dumps
# while not indenting inner lists SIGH
def _key_asdict(x: object) -> object:
if isinstance(x, CodeId):
return f"{x.filename}:{x.firstlineno}:{x.name}"
else:
return x
def _asdict(x: object) -> object:
if isinstance(x, (dict, defaultdict)):
return {_key_asdict(k): _asdict(v) for k, v in x.items()}
elif isinstance(x, (list, tuple)):
return [_asdict(v) for v in x]
elif dataclasses.is_dataclass(x):
return {
field.name: _asdict(getattr(x, field.name))
for field in dataclasses.fields(x)
}
elif x is auto_unset:
return "auto_unset"
elif x is auto_dynamic:
return "auto_dynamic"
else:
return x
def get_code_state() -> DefaultDict[CodeId, CodeState]:
global _CODE_STATE, _INIT_CODE_STATE
if _CODE_STATE is not None:
return _CODE_STATE
chromium_log = get_chromium_event_logger()
# Initialize it (even if we don't look up profile)
_CODE_STATE = defaultdict(CodeState)
cache_key = get_cache_key()
if cache_key is None:
return _CODE_STATE
def hit(ty: str) -> DefaultDict[CodeId, CodeState]:
global _INIT_CODE_STATE
assert isinstance(_CODE_STATE, defaultdict)
log.info("get_code_state %s hit %s, %d entries", path, ty, len(_CODE_STATE))
trace_structured_artifact(
f"get_{ty}_code_state",
"string",
lambda: json.dumps(_asdict(_CODE_STATE), indent=1),
)
_INIT_CODE_STATE = copy.deepcopy(_CODE_STATE)
return _CODE_STATE
# Attempt local
path = code_state_path(cache_key)
if path is not None and os.path.exists(path):
with dynamo_timed(
name := "pgo.get_local_code_state", log_pt2_compile_event=True
):
chromium_log.add_event_data(name, cache_key=cache_key)
# A read lock is not necessary as we always atomically write to
# the actual location
with open(path, "rb") as f:
try:
_CODE_STATE = pickle.load(f)
except Exception:
log.warning(
"get_code_state failed while reading %s", path, exc_info=True
)
else:
return hit("local")
# Attempt remote
remote_cache = get_remote_cache()
if remote_cache is not None:
with dynamo_timed(
name := "pgo.get_remote_code_state", log_pt2_compile_event=True
):
chromium_log.add_event_data(name, cache_key=cache_key)
# TODO: I don't really understand why there's a JSON container format
try:
cache_data = remote_cache.get(cache_key)
except Exception:
log.warning(
"get_code_state failed remote read on %s", cache_key, exc_info=True
)
else:
if cache_data is not None:
try:
assert isinstance(cache_data, dict)
data = cache_data["data"]
assert isinstance(data, str)
payload = base64.b64decode(data)
_CODE_STATE = pickle.loads(payload)
except Exception:
log.warning(
"get_code_state failed parsing remote result on %s",
cache_key,
exc_info=True,
)
else:
return hit("remote")
else:
log.info("get_code_state remote miss on %s", cache_key)
log.info("get_code_state using default")
assert _CODE_STATE is not None
return _CODE_STATE
def put_code_state() -> None:
if _CODE_STATE is None:
log.info("put_code_state: never initialized, will not write")
return
if _CODE_STATE == _INIT_CODE_STATE:
log.info("put_code_state: no change, skipping")
return
cache_key = get_cache_key()
if cache_key is None:
log.info("put_code_state: no cache key, skipping")
return
put_local_code_state(cache_key)
put_remote_code_state(cache_key)
def put_local_code_state(cache_key: str) -> None:
with dynamo_timed(name := "pgo.put_local_code_state", log_pt2_compile_event=True):
chromium_log = get_chromium_event_logger()
chromium_log.add_event_data(name, cache_key=cache_key)
assert _CODE_STATE is not None
path = code_state_path(cache_key)
if path is None:
log.info("put_code_state: local cache disabled")
return
# If the user isn't misusing our API, we should have exclusive access to
# this directory. But it's not too hard to be defensive anyway.
tmp_path = path + ".tmp"
lock_path = path + ".lock"
# We /mostly/ don't need the lock but the tmp file could be clobbered
# TODO: use a safe tempfile create to eliminate lock
from filelock import FileLock
os.makedirs(os.path.dirname(path), exist_ok=True)
with FileLock(lock_path, timeout=LOCK_TIMEOUT):
with open(tmp_path, "wb") as f:
pickle.dump(_CODE_STATE, f)
os.rename(tmp_path, path)
log.info(
"put_code_state: wrote local %s, %d entries", path, len(_CODE_STATE)
)
trace_structured_artifact(
"put_local_code_state",
"string",
lambda: json.dumps(_asdict(_CODE_STATE), indent=1),
)
def put_remote_code_state(cache_key: str) -> None:
with dynamo_timed(name := "pgo.put_remote_code_state", log_pt2_compile_event=True):
chromium_log = get_chromium_event_logger()
chromium_log.add_event_data(name, cache_key=cache_key)
assert _CODE_STATE is not None
remote_cache = get_remote_cache()
if remote_cache is None:
log.info("put_code_state: remote cache disabled")
return
content = pickle.dumps(_CODE_STATE)
cache_data: JsonDataTy = {
"data": base64.b64encode(content).decode("ascii"),
}
remote_cache.put(cache_key, cache_data)
log.info(
"put_code_state: wrote remote %s, %d entries", cache_key, len(_CODE_STATE)
)
# TODO: don't log this multiple times
trace_structured_artifact(
"put_remote_code_state",
"string",
lambda: json.dumps(_asdict(_CODE_STATE), indent=1),
)
# NB: this does NOT reset the cached code state on disk
def reset_code_state() -> None:
global _CODE_STATE, _INIT_CODE_STATE
_CODE_STATE = None
_INIT_CODE_STATE = None
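
To make the resulting on-disk layout concrete, here is a sketch of how the local profile path is derived from the functions above (the job id is a placeholder, and torch.distributed is assumed uninitialized, so the rank component is None):

import os

from torch._inductor.runtime.runtime_utils import cache_dir

# get_cache_key() with job_id="my_training_job" and no initialized process
# group yields "my_training_job:None"; code_state_path() then places the
# pickled code state under the inductor cache directory:
path = os.path.join(cache_dir(), "dynamo", "code_state_my_training_job:None.pkl")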

View File

@@ -5,31 +5,23 @@ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
import torch
import torch._inductor.custom_graph_pass
from torch._environment import is_fbcode
def _get_tristate_env(name: str) -> Optional[bool]:
value = os.environ.get(name)
if value == "1":
return True
if value == "0":
return False
return None
from torch.utils._config_module import get_tristate_env, install_config_module
def fx_graph_remote_cache_default() -> Optional[bool]:
return _get_tristate_env("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE")
return get_tristate_env("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE")
def autotune_remote_cache_default() -> Optional[bool]:
return _get_tristate_env("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE")
return get_tristate_env("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE")
def bundled_autotune_remote_cache_default() -> Optional[bool]:
return _get_tristate_env("TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE")
return get_tristate_env("TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE")
def bundle_triton_into_fx_graph_cache_default() -> Optional[bool]:
return _get_tristate_env("TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE")
return get_tristate_env("TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE")
# Enable auto_functionalized_v2 (enabled by default)
@@ -1326,8 +1318,6 @@ class test_configs:
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
from torch.utils._config_module import install_config_module
# adds patch, save_config, etc
install_config_module(sys.modules[__name__])

View File

@@ -306,6 +306,10 @@ class RemoteAOTAutogradCache(RedisRemoteCache):
pass
class RemoteDynamoPGOCache(RedisRemoteCache):
pass
def create_cache(
key: str,
is_fbcode: bool,

View File

@@ -1130,6 +1130,21 @@ def get_structured_logging_overhead() -> Optional[float]:
return None
def trace_structured_artifact(
name: str, # this will go in metadata
encoding: str,
payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
) -> None:
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": name,
"encoding": encoding,
},
payload_fn=payload_fn,
)
def trace_structured(
name: str,
# NB: metadata expected to be dict so adding more info is forward compatible
@@ -1140,7 +1155,7 @@ def trace_structured(
suppress_context: bool = False,
expect_trace_id: bool = True, # Whether or not we expect to have a current trace id
record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging
):
) -> None:
"""
metadata is an arbitrary JSON compatible struct, but it's expected to not be
too long (e.g., less than 1MB)

View File

@@ -4,7 +4,7 @@ import logging
import os
import sys
import tempfile
from typing import Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from typing_extensions import ParamSpec
import torch
@@ -344,6 +344,10 @@ def max_clock_rate():
return 1100
def get_mast_job_name_version() -> Optional[Tuple[str, int]]:
return None
TEST_MASTER_ADDR = "127.0.0.1"
TEST_MASTER_PORT = 29500
# USE_GLOBAL_DEPS controls whether __init__.py tries to load

torch/compiler/config.py (new file, 62 lines)
View File

@@ -0,0 +1,62 @@
"""
This is the top-level configuration module for the compiler, containing
cross-cutting configuration options that affect all parts of the compiler
stack.
You may also be interested in the per-component configuration modules, which
contain configuration options that affect only a specific part of the compiler:
* :mod:`torch._dynamo.config`
* :mod:`torch._inductor.config`
* :mod:`torch._functorch.config`
* :mod:`torch.fx.experimental.config`
"""
import os
import sys
from typing import Optional
__all__ = [
"job_id",
]
# NB: Docblocks go UNDER variable definitions! Use spacing to make the
# grouping clear.
# FB-internal note: you do NOT have to explicitly specify this if
# you run on MAST; we will automatically default this to
# mast:MAST_JOB_NAME:MAST_JOB_VERSION.
job_id: Optional[str] = os.environ.get("TORCH_COMPILE_JOB_ID", None)
"""
Semantically, this should be an identifier that uniquely identifies, e.g., a
training job. You might have multiple attempts of the same job, e.g., if it was
preempted or needed to be restarted, but each attempt should be running
substantially the same workload with the same distributed topology. You can
set this by environment variable with :envvar:`TORCH_COMPILE_JOB_ID`.
Operationally, this controls persistent state related to profile-guided
optimization (PGO). PGO state can affect how we perform compilation across
multiple invocations of PyTorch, e.g., the first time you run your program we
may compile twice as we discover which inputs are dynamic, and then PGO will
save this state so subsequent invocations only need to compile once, because
they remember those inputs are dynamic. This profile information, however, is
sensitive to the workload you are running, so we require you to tell us that
two jobs are *related* (i.e., are the same workload) before we are willing to
reuse this information. Notably, PGO does nothing (even if explicitly enabled)
unless a valid ``job_id`` is available. In some situations, PyTorch can be
configured to automatically compute a ``job_id`` based on the environment it
is running in.
Profiles are always collected on a per-rank basis, so different ranks may have
different profiles. If you know your workload is truly SPMD, you can run with
:data:`torch._dynamo.config.enable_compiler_collectives` to ensure nodes get
consistent profiles across all ranks.
"""
from torch.utils._config_module import install_config_module
install_config_module(sys.modules[__name__])
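
As a small usage note (a sketch; the id is a placeholder): job_id can be assigned directly or patched like any other config module entry, and the "mast:" prefix is rejected by get_cache_key() in torch/_dynamo/pgo.py because it is reserved for automatically derived MAST ids.

import torch.compiler.config

# Process-wide assignment:
torch.compiler.config.job_id = "resnet50-trainer"

# Scoped override, e.g. in tests:
with torch.compiler.config.patch(job_id="resnet50-trainer"):
    pass

# torch.compiler.config.job_id = "mast:foo:1"  # reserved; PGO raises
# ReservedWorkflowIdUserError when it computes the cache key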

View File

@@ -3,6 +3,7 @@ import copy
import hashlib
import inspect
import io
import os
import pickle
import tokenize
import unittest
@@ -60,7 +61,8 @@ def install_config_module(module: ModuleType) -> None:
"""
class ConfigModuleInstance(ConfigModule):
_bypass_keys = set({"_is_dirty", "_hash_digest"})
# __annotations__ is written to by Sphinx autodoc
_bypass_keys = set({"_is_dirty", "_hash_digest", "__annotations__"})
def visit(
source: Union[ModuleType, type],
@@ -496,3 +498,12 @@ def patch_object(obj: object, name: str, value: object) -> object:
if isinstance(obj, ConfigModule):
return obj.patch(name, value)
return mock.patch.object(obj, name, value)
def get_tristate_env(name: str) -> Optional[bool]:
value = os.environ.get(name)
if value == "1":
return True
if value == "0":
return False
return None