Previously: https://github.com/pytorch/pytorch/pull/138052, but the implementation is done from scratch, so I am opening a new PR.

This implements the ability to save and load profiles of automatic dynamic decisions, so on subsequent runs we can directly make something automatically dynamic. Unlike the previous implementation, this cache is never enabled by default; instead, you have to specify a "job id" that says it's OK to share results. We will be able to automatically populate this id for internal MAST jobs, but generic OSS users will have to explicitly opt into it.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139001
Approved by: https://github.com/oulgen
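A minimal opt-in sketch, assuming the environment-variable path described in the module below (the job name and script name are placeholders):

    # Tag related runs of the same workload with a common job id so their PGO
    # profiles can be shared, e.g. from the shell:
    #
    #   TORCH_COMPILE_JOB_ID=my_training_job python train.py
    #
    # or from Python, before torch (and hence this config module) is imported,
    # since the id is read from the environment at import time:
    import os

    os.environ["TORCH_COMPILE_JOB_ID"] = "my_training_job"

    import torch  # torch.compiler.config.job_id defaults to the id set above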
"""
|
|
This is the top-level configuration module for the compiler, containing
|
|
cross-cutting configuration options that affect all parts of the compiler
|
|
stack.
|
|
|
|
You may also be interested in the per-component configuration modules, which
|
|
contain configuration options that affect only a specific part of the compiler:
|
|
|
|
* :mod:`torch._dynamo.config`
|
|
* :mod:`torch._inductor.config`
|
|
* :mod:`torch._functorch.config`
|
|
* :mod:`torch.fx.experimental.config`
|
|
"""

import os
import sys
from typing import Optional


__all__ = [
    "job_id",
]

# NB: Docblocks go UNDER variable definitions! Use spacing to make the
# grouping clear.

# FB-internal note: you do NOT have to specify this explicitly if you run on
# MAST; we will automatically default this to
# mast:MAST_JOB_NAME:MAST_JOB_VERSION.
job_id: Optional[str] = os.environ.get("TORCH_COMPILE_JOB_ID", None)
"""
Semantically, this should be an identifier that uniquely identifies, e.g., a
training job. You might have multiple attempts of the same job, e.g., if it was
preempted or needed to be restarted, but each attempt should be running
substantially the same workload with the same distributed topology. You can
set this by environment variable with :envvar:`TORCH_COMPILE_JOB_ID`.

Operationally, this controls the effect of persistent state related to
profile-guided optimization (PGO). PGO state can affect how we perform
compilation across multiple invocations of PyTorch, e.g., the first time you
run your program we may compile twice as we discover what inputs are dynamic,
and then PGO will save this state so subsequent invocations only need to
compile once, because they remember it is dynamic. This profile information,
however, is sensitive to what workload you are running, so we require you to
tell us that two jobs are *related* (i.e., are the same workload) before we
are willing to reuse this information. Notably, PGO does nothing (even if
explicitly enabled) unless a valid ``job_id`` is available. In some
situations, PyTorch can be configured to automatically compute a ``job_id``
based on the environment it is running in.

Profiles are always collected on a per-rank basis, so different ranks may have
different profiles. If you know your workload is truly SPMD, you can run with
:data:`torch._dynamo.config.enable_compiler_collectives` to ensure nodes get
consistent profiles across all ranks.
"""

from torch.utils._config_module import install_config_module


# Hook this module into the shared config machinery so its options behave
# like the other compiler configs (readable, assignable, and patchable).
install_config_module(sys.modules[__name__])
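
# Usage sketch (assuming the standard ConfigModule API installed above): the
# job id can be overridden temporarily, e.g. for a one-off experiment:
#
#     with torch.compiler.config.patch(job_id="ablation_run"):
#         compiled = torch.compile(model)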