Log AOTAutogradCache state to PT2 Compile Events (#138604)

Same as the previous diff for Inductor, but for AOTAutograd instead: the AOTAutogradCache state and related metadata are attached to the "backend_compile" PT2 compile event.

Differential Revision: [D64765199](https://our.internmc.facebook.com/intern/diff/D64765199/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138604
Approved by: https://github.com/oulgen
Authored by James Wu on 2024-10-25 11:20:01 -07:00; committed by PyTorch MergeBot
parent f1a677cba5
commit eb6c7b93a7
2 changed files with 42 additions and 1 deletion
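
In short, the change attaches AOTAutogradCache information to the chromium "backend_compile" event that dynamo records for each compile. A minimal sketch of the pattern, assuming such an event is already on the logger's stack (the calls and field names are taken from the diff below; the concrete values are illustrative):

    from torch._dynamo.utils import get_chromium_event_logger

    chromium_log = get_chromium_event_logger()
    # "backend_compile" is only on the event stack when compiling through dynamo,
    # so skip logging when AOTAutograd is invoked directly (e.g. in unit tests).
    if "backend_compile" in chromium_log.get_stack():
        chromium_log.add_event_data(
            "backend_compile",
            cache_state="bypass",              # cache outcome for this compile
            cache_bypass_reason="<reason>",    # illustrative placeholder value
            dispatch_mode="autograd",          # or "inference" / "export"
        )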


@@ -33,6 +33,7 @@ from torch._inductor.codecache import (
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import should_use_remote_fx_graph_cache
from torch._logging import LazyString
from torch._utils_internal import log_cache_bypass

from .runtime_wrappers import (
    AOTDispatchAutograd,
@@ -439,11 +440,14 @@ class AOTAutogradCacheEntry:
        compiled_fw_func = self.compiled_fw.load(args, fx_config)
        compiled_bw_func = None

        chromium_log = get_chromium_event_logger()
        if self.compiled_bw is not None:
            compiled_bw_func = self.compiled_bw.load(args, fx_config)
            needs_autograd = True
            chromium_log.add_event_data("backend_compile", dispatch_mode="autograd")
        else:
            needs_autograd = False
            chromium_log.add_event_data("backend_compile", dispatch_mode="inference")

        # Wrap the forward function in post compile wrappers
        compiled_fw_func = AOTDispatchSubclassWrapper(
@@ -455,6 +459,11 @@ class AOTAutogradCacheEntry:
            compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
        )

        req_subclass_dispatch = self.maybe_subclass_meta is not None
        chromium_log.add_event_data(
            "backend_compile", requires_subclass_dispatch=req_subclass_dispatch
        )

        # In autograd case, functionalizedRngWrapper should not modify outs
        return_new_outs = not needs_autograd
        compiled_fw_func = FunctionalizedRngRuntimeWrapper(
@@ -619,6 +628,9 @@ class AOTAutogradCache:
            counters["aot_autograd"]["autograd_cache_bypass"] += 1
            cache_state = "bypass"
            cache_event_time = time.time_ns()
            cache_info["cache_bypass_reason"] = str(e)
            if remote:
                log_cache_bypass("bypass_aot_autograd", str(e))
            if config.strict_autograd_cache:
                raise e
        if compiled_fn is None:
@@ -638,6 +650,18 @@ class AOTAutogradCache:
        chromium_log.log_instant_event(
            f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info
        )

        chromium_log.add_event_data(
            "backend_compile",
            cache_state=cache_state,
            cache_event_time=cache_event_time,
            key=cache_info.get("key"),
            components=cache_info.get("components"),
            cache_bypass_reason=cache_info.get("cache_bypass_reason"),
            remote_cache_enabled=remote,
            local_cache_enabled=local,
        )

        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {


@@ -15,7 +15,11 @@ from torch import Tensor
from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo import compiled_autograd
from torch._dynamo.utils import dynamo_timed, preserve_rng_state
from torch._dynamo.utils import (
    dynamo_timed,
    get_chromium_event_logger,
    preserve_rng_state,
)
from torch._guards import detect_fake_mode
from torch._inductor.utils import BoxedBool
from torch._subclasses import FakeTensor, FakeTensorMode
@@ -581,6 +585,13 @@ def _create_aot_dispatcher_function(
        enable_python_dispatcher() if shape_env is not None else nullcontext()
    )

    def try_record_chromium_data(**kwargs):
        # `backend_compile` only exists as an event if we are compiling with dynamo
        # In some unit tests we don't use dynamo, so we ignore those cases
        chromium_log = get_chromium_event_logger()
        if "backend_compile" in chromium_log.get_stack():
            chromium_log.add_event_data("backend_compile", **kwargs)

    # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
    # If any saved tensor hooks are active, we **don't** want to trace them.
    # Instead, we'll let them run at runtime, around the custom autograd.Function
@@ -634,6 +645,9 @@ def _create_aot_dispatcher_function(
                req_subclass_dispatch = requires_subclass_dispatch(
                    fake_flat_args, fw_metadata
                )
                try_record_chromium_data(
                    requires_subclass_dispatch=req_subclass_dispatch
                )

                output_and_mutation_safe = not any(
                    x.requires_grad
@@ -752,10 +766,13 @@ or otherwise set torch._functorch.config.functionalize_rng_ops = False."""
            if aot_config.is_export:
                # export uses just the "graph bits", whereas the other
                # two dispatchers include some extra work around handling a runtime epilogue
                try_record_chromium_data(dispatch_mode="export")
                return partial(aot_dispatch_export, needs_autograd=needs_autograd)
            elif needs_autograd and not aot_config.pre_dispatch:
                try_record_chromium_data(dispatch_mode="autograd")
                return aot_dispatch_autograd
            else:
                try_record_chromium_data(dispatch_mode="inference")
                return aot_dispatch_base

        compiler_fn = choose_dispatcher(needs_autograd, aot_config)