mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Precompile] Various small bugfixes, add CachingPrecompile to torchbench (#158847)
This PR addresses a few small bugfixes needed to make NanoGPT inference work, and also adds a new `--caching-precompile` argument to torchbench. With `--caching-precompile`, after every benchmark we save precompile artifacts to DynamoCache, allowing us to test caching precompile on all existing benchmarks. The following bugfixes are in this PR to make all of this work: - Fix global variables being pruned with DUPLICATE_INPUT guards. DUPLICATE_INPUT guards have additional vars from the second input, which we track with additional_local_vars, but we never tracked additional global variables. This fixes the issue. (See torch/_dynamo/guards.py changes) - Return None from PRecompileContext.serialize() if no new dynamo compiles occurred. There's no reason to save artifacts (i.e. autotuning artifacts, etc) if no dynamo_compile occurred, so we return None early. We may later want to support editing existing dynamo artifacts as a TODO, but that's upcoming. - log `dynamo_start` on CompilePackage.load: This is only needed so that tlparse doesn't ignore TORCH_TRACE logs generated when caching precompile hits. If there are no actual compiles, we never log a "dynamo_start" entry, which makes internal tlparse ignore the TORCH_TRACE file. ## Test Plan After this PR, the following now works: ``` TORCH_LOGS=dynamo tlp python benchmarks/dynamo/torchbench.py --only nanogpt --performance --inference --backend inductor --caching-precompile --warm-start-latency ``` tlparse result (internal): Cold Start (6 seconds): https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpAWe0zD/dedicated_log_torch_trace_vk9nkp4m.log/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 Warm Start (~1 s): https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpAWe0zD/dedicated_log_torch_trace_5l4iwrpm.log/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 The 1 second of warm start here can be improved: the costs here are mostly in starting up workers and triton and initializing CUDA, a lot of which should not be included in the compile time cost in real world scenarios where these are already loaded before training begins. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158847 Approved by: https://github.com/zhxchen17
This commit is contained in:
parent
5998cd4eaa
commit
d898d0d437
|
|
@ -3264,6 +3264,12 @@ def parse_args(args=None):
|
|||
instead of deleting it and creating a new one.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--caching-precompile",
|
||||
action="store_true",
|
||||
help="Enables caching precompile, serializing artifacts to DynamoCache between runs",
|
||||
)
|
||||
|
||||
group_latency = parser.add_mutually_exclusive_group()
|
||||
group_latency.add_argument(
|
||||
"--cold-start-latency",
|
||||
|
|
@ -3414,6 +3420,29 @@ def parse_args(args=None):
|
|||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def process_caching_precompile():
|
||||
"""
|
||||
After every process_entry, save precompile artifacts to DynamoCache
|
||||
"""
|
||||
assert torch._dynamo.config.caching_precompile, (
|
||||
"Caching precompile should be enabled with --caching-precompile"
|
||||
)
|
||||
from torch._dynamo.precompile_context import PrecompileContext
|
||||
|
||||
# Serialize all callables, clear PrecompileContext
|
||||
# TODO: put this under torch.compiler API once ready
|
||||
serialized = PrecompileContext.serialize()
|
||||
PrecompileContext.clear()
|
||||
if serialized is not None:
|
||||
artifacts, info = serialized
|
||||
print(
|
||||
f"Saving {len(info.precompile_dynamo_artifacts)} Precompile Artifact(s)..."
|
||||
)
|
||||
results = PrecompileContext.deserialize(artifacts)
|
||||
assert results is not None
|
||||
PrecompileContext.populate_caches(results)
|
||||
|
||||
|
||||
def process_entry(rank, runner, original_dir, args):
|
||||
args.rank = rank
|
||||
with maybe_init_distributed(
|
||||
|
|
@ -3422,7 +3451,10 @@ def process_entry(rank, runner, original_dir, args):
|
|||
world_size=args.world_size,
|
||||
port=args.distributed_master_port,
|
||||
):
|
||||
return run(runner, args, original_dir)
|
||||
result = run(runner, args, original_dir)
|
||||
if args.caching_precompile:
|
||||
process_caching_precompile()
|
||||
return result
|
||||
|
||||
|
||||
def maybe_fresh_cache(args):
|
||||
|
|
@ -3458,6 +3490,10 @@ def main(runner, original_dir=None, args=None):
|
|||
)
|
||||
|
||||
with maybe_fresh_cache(args):
|
||||
if args.caching_precompile:
|
||||
os.environ["TORCH_CACHING_PRECOMPILE"] = "1"
|
||||
torch._dynamo.config.caching_precompile = True
|
||||
|
||||
args.init_distributed = args.only and args.multiprocess
|
||||
if args.init_distributed:
|
||||
# NB: Do NOT query device count before CUDA initialization; we're
|
||||
|
|
|
|||
|
|
@ -549,7 +549,7 @@ fake_tensor_disable_inference_mode = True
|
|||
|
||||
# Experimental feature for running automatic caching precompile.
|
||||
# Enables automatic DynamoCache save/load
|
||||
caching_precompile = False
|
||||
caching_precompile = os.environ.get("TORCH_CACHING_PRECOMPILE", "0") == "1"
|
||||
|
||||
# Enables the Compiled Autograd engine to trace autograd calls made under torch.compile().
|
||||
# Note: AOTAutograd will still trace and partition an AOT backward graph local to that
|
||||
|
|
|
|||
|
|
@ -225,6 +225,31 @@ def fx_forward_from_src_skip_result(
|
|||
return result
|
||||
|
||||
|
||||
def log_dynamo_start(code: CodeType, skip: int = 0) -> None:
|
||||
convert_frame_intern = structured.intern_string(__file__)
|
||||
# Initialize the ChromiumEventLogger on start
|
||||
torch._logging.trace_structured(
|
||||
"dynamo_start",
|
||||
lambda: {
|
||||
"stack": list(
|
||||
itertools.takewhile(
|
||||
lambda f: f["filename"] != convert_frame_intern,
|
||||
structured.from_traceback(
|
||||
CapturedTraceback.extract(skip=4 + skip).summary()
|
||||
),
|
||||
)
|
||||
)
|
||||
+ [
|
||||
{
|
||||
"line": code.co_firstlineno,
|
||||
"name": code.co_name,
|
||||
"filename": structured.intern_string(code.co_filename),
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
"""
|
||||
Context manager to:
|
||||
|
|
@ -1135,28 +1160,7 @@ def _compile(
|
|||
# # 2 extra here
|
||||
# torch/_logging/_internal.py:1064 in trace_structured
|
||||
# torch/_dynamo/convert_frame.py:780 in <lambda>
|
||||
convert_frame_intern = structured.intern_string(__file__)
|
||||
# Initialize the ChromiumEventLogger on start
|
||||
torch._logging.trace_structured(
|
||||
"dynamo_start",
|
||||
lambda: {
|
||||
"stack": list(
|
||||
itertools.takewhile(
|
||||
lambda f: f["filename"] != convert_frame_intern,
|
||||
structured.from_traceback(
|
||||
CapturedTraceback.extract(skip=4 + skip).summary()
|
||||
),
|
||||
)
|
||||
)
|
||||
+ [
|
||||
{
|
||||
"line": code.co_firstlineno,
|
||||
"name": code.co_name,
|
||||
"filename": structured.intern_string(code.co_filename),
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
log_dynamo_start(code, skip)
|
||||
start_time_ns = time.time_ns()
|
||||
fail_type: Optional[str] = None
|
||||
fail_reason: Optional[str] = None
|
||||
|
|
@ -1588,9 +1592,10 @@ class CatchErrorsWrapper:
|
|||
|
||||
with compile_lock, _disable_current_modes():
|
||||
# skip=1: skip this frame
|
||||
return self._torchdynamo_orig_backend(
|
||||
result = self._torchdynamo_orig_backend(
|
||||
frame, cache_entry, self.hooks, frame_state, skip=1
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def catch_errors_wrapper(
|
||||
|
|
|
|||
|
|
@ -679,8 +679,7 @@ class _TorchDynamoContext:
|
|||
|
||||
# If self._package is lazily initialized, we should check the dynamo cache now
|
||||
if config.caching_precompile:
|
||||
assert self._package is not None
|
||||
if not self._package.is_initialized():
|
||||
if self._package is not None and not self._package.is_initialized():
|
||||
result = DynamoCache.load(fn)
|
||||
if result is None:
|
||||
# Create a fresh CompilePackage
|
||||
|
|
|
|||
|
|
@ -1969,6 +1969,8 @@ class GuardBuilder(GuardBuilderBase):
|
|||
if self.serialization_mode == "save":
|
||||
if name := get_local_source_name(source_b):
|
||||
self.check_fn_manager.additional_used_local_vars.add(name)
|
||||
if name := get_global_source_name(source_b):
|
||||
self.check_fn_manager.additional_used_global_vars.add(name)
|
||||
|
||||
ref_a = self.arg_ref(guard)
|
||||
ref_b = self.arg_ref(source_b.name())
|
||||
|
|
@ -2848,6 +2850,7 @@ class CheckFunctionManager:
|
|||
self.guards_serialization_mode = guards_serialization_mode
|
||||
self.used_builtin_vars: OrderedSet[str] = OrderedSet()
|
||||
self.additional_used_local_vars: OrderedSet[str] = OrderedSet()
|
||||
self.additional_used_global_vars: OrderedSet[str] = OrderedSet()
|
||||
if runtime_global_scope:
|
||||
assert self.guards_serialization_mode == "load"
|
||||
self.runtime_global_scope = runtime_global_scope
|
||||
|
|
@ -3038,7 +3041,7 @@ class CheckFunctionManager:
|
|||
global_scope_state = {
|
||||
k: v
|
||||
for k, v in output_graph_guards_state.global_scope.items()
|
||||
if k in used_global_vars
|
||||
if k in used_global_vars or k in self.additional_used_global_vars
|
||||
}
|
||||
global_scope_state[builtins_dict_name] = {
|
||||
k: v
|
||||
|
|
|
|||
|
|
@ -380,7 +380,7 @@ class CompilePackage:
|
|||
3. Install the precompiled cache entries to ExtraStates on the code object.
|
||||
"""
|
||||
from torch._C._dynamo.eval_frame import _load_precompile_entry
|
||||
from torch._dynamo.convert_frame import get_compile_id
|
||||
from torch._dynamo.convert_frame import get_compile_id, log_dynamo_start
|
||||
from torch._guards import compile_context, CompileContext
|
||||
|
||||
from .output_graph import get_builtins_dict
|
||||
|
|
@ -394,6 +394,7 @@ class CompilePackage:
|
|||
# collapsed into 0/0, 1/0 on warm.
|
||||
increment_frame()
|
||||
compile_id = get_compile_id(frame_state={})
|
||||
log_dynamo_start(code)
|
||||
with (
|
||||
compile_context(CompileContext(compile_id)),
|
||||
dynamo_timed(
|
||||
|
|
|
|||
|
|
@ -141,6 +141,9 @@ class PrecompileContext(CacheArtifactManager):
|
|||
@classmethod
|
||||
def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
|
||||
cls._save_artifacts_by_type()
|
||||
# No need to serialize if there are no new dynamo compiles
|
||||
if "precompile_dynamo" not in cls._new_cache_artifacts:
|
||||
return None
|
||||
return super().serialize()
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user