TL;DR: Cuts vLLM cudagraph capture time from 80s -> 24s
Stop garbage collecting by default on every cudagraph recording. The old behavior can be re-enabled by setting `TORCH_CUDAGRAPH_GC=1` or the config `force_cudagraph_gc`.
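As a quick usage sketch (assuming the config module shown at the bottom of this page is exposed as `torch.compiler.config` and is consulted at capture time), the old behavior can be restored from the environment or in code:

```python
# Option 1: set the environment variable before the process starts, e.g.
#   TORCH_CUDAGRAPH_GC=1 python serve.py
# Option 2: flip the config knob programmatically (module path assumed).
import torch.compiler.config

torch.compiler.config.force_cudagraph_gc = True
```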
We were previously garbage collecting at the beginning of each cudagraph
capture. vLLM captures 5427 graphs, and most of those garbage collections weren't
actually collecting any memory (CPU or GPU). This changes it to collect at most
once every 10s, so if we're capturing in a loop we don't burn all our cycles
looking for garbage.
(These numbers have a lot of variance from run to run but give the correct
general scale.)
```
       | calls | total | synchronize | gcs  | collect | empty cache | sys freed | cuda freed |
-------+-------+-------+-------------+------+---------+-------------+-----------+------------+
before | 5427  | 78s   | 1.48s       | 5427 | 53.22s  | 1.21s       | 145855    | 1539309568 |
-------+-------+-------+-------------+------+---------+-------------+-----------+------------+
after  | 5427  | 24s   | 0s          | 3    | 1.53s   | 0.84s       | 592       | 1539309568 |
-------+-------+-------+-------------+------+---------+-------------+-----------+------------+
```
total - this is the total time reported by vLLM's "Graph capturing finished" log.
The rest of these are measured in torch.cuda.graphs.graph.__enter__():
calls - number of times torch.cuda.graphs.graph.__enter__ was called
synchronize - this is the duration taken by the cuda.synchronize call
gcs - number of times gc.collect was called
collect - this is the duration taken by the gc.collect call
empty cache - this is the duration taken by the torch.cuda.empty_cache call
sys freed - the number of bytes reported freed by gc.collect
cuda freed - the number of bytes reported freed by torch.cuda.memory_reserved
So it seems like the heavy lifting is done by torch.cuda.empty_cache(), which is
fairly quick.
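For illustration, here is a minimal sketch of the rate-limiting idea described above; the helper and variable names are hypothetical, not the actual torch.cuda.graphs code:

```python
import gc
import time

_last_gc = 0.0  # hypothetical module-level timestamp of the last forced GC


def maybe_collect(interval_s: float = 10.0) -> None:
    """Run gc.collect() at most once every `interval_s` seconds."""
    global _last_gc
    now = time.monotonic()
    if now - _last_gc >= interval_s:
        gc.collect()
        _last_gc = now


# Called at the start of each cudagraph capture: in a tight capture loop only
# a handful of the thousands of calls actually pay for a full collection.
```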
Cudagraph results from the TorchInductor Performance Dashboard (this is from the original version using the GC clock, so the real results will be slightly better than this):
<img width="1494" height="382" alt="image" src="https://github.com/user-attachments/assets/69b705ef-47ce-4b6e-9733-1ec941cad93d" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158193
Approved by: https://github.com/ngimel
100 lines · 3.8 KiB · Python
"""
|
|
This is the top-level configuration module for the compiler, containing
|
|
cross-cutting configuration options that affect all parts of the compiler
|
|
stack.
|
|
|
|
You may also be interested in the per-component configuration modules, which
|
|
contain configuration options that affect only a specific part of the compiler:
|
|
|
|
* :mod:`torch._dynamo.config`
|
|
* :mod:`torch._inductor.config`
|
|
* :mod:`torch._functorch.config`
|
|
* :mod:`torch.fx.experimental.config`
|
|
"""
|
|
|
|
import sys
|
|
from typing import Optional
|
|
|
|
from torch.utils._config_module import Config, install_config_module
|
|
|
|
|
|
__all__ = [
|
|
"job_id",
|
|
]
|
|
|
|
|
|
# NB: Docblocks go UNDER variable definitions! Use spacing to make the
|
|
# grouping clear.
|
|
|
|
# FB-internal note: you do NOT have to specify this explicitly specify this if
|
|
# you run on MAST, we will automatically default this to
|
|
# mast:MAST_JOB_NAME:MAST_JOB_VERSION.
|
|
job_id: Optional[str] = Config(
|
|
env_name_default=["TORCH_COMPILE_JOB_ID", "TORCH_COMPILE_STICKY_PGO_KEY"],
|
|
default=None,
|
|
)
|
|
"""
|
|
Semantically, this should be an identifier that uniquely identifies, e.g., a
|
|
training job. You might have multiple attempts of the same job, e.g., if it was
|
|
preempted or needed to be restarted, but each attempt should be running
|
|
substantially the same workload with the same distributed topology. You can
|
|
set this by environment variable with :envvar:`TORCH_COMPILE_JOB_ID`.
|
|
|
|
Operationally, this controls the effect of profile-guided optimization related
|
|
persistent state. PGO state can affect how we perform compilation across
|
|
multiple invocations of PyTorch, e.g., the first time you run your program we
|
|
may compile twice as we discover what inputs are dynamic, and then PGO will
|
|
save this state so subsequent invocations only need to compile once, because
|
|
they remember it is dynamic. This profile information, however, is sensitive
|
|
to what workload you are running, so we require you to tell us that two jobs
|
|
are *related* (i.e., are the same workload) before we are willing to reuse
|
|
this information. Notably, PGO does nothing (even if explicitly enabled)
|
|
unless a valid ``job_id`` is available. In some situations, PyTorch can
|
|
configured to automatically compute a ``job_id`` based on the environment it
|
|
is running in.
|
|
|
|
Profiles are always collected on a per rank basis, so different ranks may have
|
|
different profiles. If you know your workload is truly SPMD, you can run with
|
|
:data:`torch._dynamo.config.enable_compiler_collectives` to ensure nodes get
|
|
consistent profiles across all ranks.
|
|
"""
|
|
|
|
|
|
cache_key_tag: str = Config(env_name_default="TORCH_COMPILE_CACHE_KEY_TAG", default="")
|
|
"""
|
|
Tag to be included in the cache key generation for all torch compile caching.
|
|
A common use case for such a tag is to break caches.
|
|
"""
|
|
|
|
dynamic_sources: str = Config(
|
|
env_name_default="TORCH_COMPILE_DYNAMIC_SOURCES", default=""
|
|
)
|
|
"""
|
|
Comma delimited list of sources that should be marked as dynamic. Primarily useful for large
|
|
models with graph breaks where you need intermediate tensors and ints to be marked dynamic.
|
|
|
|
This whitelist is dominant over all other flags dynamic=False, force_nn_module_property_static_shapes
|
|
and force_parameter_static_shapes.
|
|
"""
|
|
|
|
unbacked_sources: str = Config(
|
|
env_name_default="TORCH_COMPILE_UNBACKED_SOURCES", default=""
|
|
)
|
|
"""
|
|
Comma delimited list of sources that should be marked as unbacked. Primarily useful for large
|
|
models with graph breaks where you need intermediate tensors marked unbacked.
|
|
|
|
This whitelist is dominant over all other flags dynamic=False, force_nn_module_property_static_shapes
|
|
and force_parameter_static_shapes.
|
|
"""
|
|
|
|
# force a python GC before recording cudagraphs
|
|
force_cudagraph_gc: bool = Config(env_name_default="TORCH_CUDAGRAPH_GC", default=True)
|
|
"""
|
|
If True (the backward-compatible behavior) then gc.collect() before recording
|
|
any cudagraph.
|
|
"""
|
|
|
|
|
|
install_config_module(sys.modules[__name__])
|
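The docstrings above describe these knobs mostly in the abstract, so here is a short usage sketch; the `torch.compiler.config` import path, the `L['...']` source syntax, and every name and value below are illustrative assumptions rather than anything stated on this page:

```python
# Usage sketch only; all names and values are illustrative.
import torch.compiler.config as cfg

# Identify related attempts of the same training job so PGO profiles can be
# reused across restarts (can also be set via TORCH_COMPILE_JOB_ID).
cfg.job_id = "my_training_job"

# Tag all torch.compile cache keys, e.g. to deliberately break caches.
cfg.cache_key_tag = "experiment_2025_01"

# Comma delimited whitelists of sources to force dynamic/unbacked; the
# L['...'] source syntax is an assumption, not documented above.
cfg.dynamic_sources = "L['x'],L['seq_len']"
cfg.unbacked_sources = "L['num_tokens']"
```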