Log AOTAutogradCache state to PT2 Compile Events (#138604)
Same as the previous diff for inductor, but for autograd instead.

Differential Revision: [D64765199](https://our.internmc.facebook.com/intern/diff/D64765199/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138604
Approved by: https://github.com/oulgen
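For context before the hunks: the change attaches AOTAutogradCache metadata (cache state, timing, cache key, and bypass reason) to the already-running "backend_compile" PT2 compile event. Below is a minimal sketch of that reporting pattern with placeholder field values; it only uses calls that appear in the diff itself (get_chromium_event_logger, get_stack, add_event_data), and the guard makes it a no-op when no compile is in progress.

import time

from torch._dynamo.utils import get_chromium_event_logger

chromium_log = get_chromium_event_logger()
# "backend_compile" is only on the event stack while dynamo is driving a
# compile, so outside of compilation this block is skipped entirely.
if "backend_compile" in chromium_log.get_stack():
    chromium_log.add_event_data(
        "backend_compile",
        cache_state="bypass",             # placeholder; see the cache hunks below
        cache_event_time=time.time_ns(),
        key=None,                         # placeholder for cache_info.get("key")
        components=None,                  # placeholder for cache_info.get("components")
        cache_bypass_reason=None,
        remote_cache_enabled=False,
        local_cache_enabled=True,
    )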
parent f1a677cba5
commit eb6c7b93a7
@@ -33,6 +33,7 @@ from torch._inductor.codecache import (
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.utils import should_use_remote_fx_graph_cache
 from torch._logging import LazyString
+from torch._utils_internal import log_cache_bypass
 
 from .runtime_wrappers import (
     AOTDispatchAutograd,
@@ -439,11 +440,14 @@ class AOTAutogradCacheEntry:
 
         compiled_fw_func = self.compiled_fw.load(args, fx_config)
         compiled_bw_func = None
+        chromium_log = get_chromium_event_logger()
         if self.compiled_bw is not None:
             compiled_bw_func = self.compiled_bw.load(args, fx_config)
             needs_autograd = True
+            chromium_log.add_event_data("backend_compile", dispatch_mode="autograd")
         else:
             needs_autograd = False
+            chromium_log.add_event_data("backend_compile", dispatch_mode="inference")
 
         # Wrap the forward function in post compile wrappers
         compiled_fw_func = AOTDispatchSubclassWrapper(
@@ -455,6 +459,11 @@ class AOTAutogradCacheEntry:
             compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
         )
 
+        req_subclass_dispatch = self.maybe_subclass_meta is not None
+        chromium_log.add_event_data(
+            "backend_compile", requires_subclass_dispatch=req_subclass_dispatch
+        )
+
         # In autograd case, functionalizedRngWrapper should not modify outs
         return_new_outs = not needs_autograd
         compiled_fw_func = FunctionalizedRngRuntimeWrapper(
@@ -619,6 +628,9 @@ class AOTAutogradCache:
             counters["aot_autograd"]["autograd_cache_bypass"] += 1
             cache_state = "bypass"
             cache_event_time = time.time_ns()
+            cache_info["cache_bypass_reason"] = str(e)
+            if remote:
+                log_cache_bypass("bypass_aot_autograd", str(e))
             if config.strict_autograd_cache:
                 raise e
         if compiled_fn is None:
@@ -638,6 +650,18 @@ class AOTAutogradCache:
         chromium_log.log_instant_event(
             f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info
         )
+
+        chromium_log.add_event_data(
+            "backend_compile",
+            cache_state=cache_state,
+            cache_event_time=cache_event_time,
+            key=cache_info.get("key"),
+            components=cache_info.get("components"),
+            cache_bypass_reason=cache_info.get("cache_bypass_reason"),
+            remote_cache_enabled=remote,
+            local_cache_enabled=local,
+        )
+
         torch._logging.trace_structured(
             "artifact",
             metadata_fn=lambda: {
(The remaining hunks are from a second file in the diff.)
@@ -15,7 +15,11 @@ from torch import Tensor
 from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo import compiled_autograd
-from torch._dynamo.utils import dynamo_timed, preserve_rng_state
+from torch._dynamo.utils import (
+    dynamo_timed,
+    get_chromium_event_logger,
+    preserve_rng_state,
+)
 from torch._guards import detect_fake_mode
 from torch._inductor.utils import BoxedBool
 from torch._subclasses import FakeTensor, FakeTensorMode
@@ -581,6 +585,13 @@ def _create_aot_dispatcher_function(
         enable_python_dispatcher() if shape_env is not None else nullcontext()
     )
 
+    def try_record_chromium_data(**kwargs):
+        # `backend_compile` only exists as an event if we are compiling with dynamo
+        # In some unit tests we don't use dynamo, so we ignore those cases
+        chromium_log = get_chromium_event_logger()
+        if "backend_compile" in chromium_log.get_stack():
+            chromium_log.add_event_data("backend_compile", **kwargs)
+
     # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
     # If any saved tensor hooks are active, we **don't** want to trace them.
     # Instead, we'll let them run at runtime, around the custom autograd.Function
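The two comments inside the helper above carry the key reasoning: the "backend_compile" event only exists when dynamo drives the compile, so the helper has to degrade to a no-op otherwise. A self-contained sketch (restating the helper so it can run on its own) that exercises exactly that path:

from torch._dynamo.utils import get_chromium_event_logger

def try_record_chromium_data(**kwargs):
    # Mirrors the helper in the hunk above: only record metadata when a
    # "backend_compile" event is actually in progress on the logger's stack.
    chromium_log = get_chromium_event_logger()
    if "backend_compile" in chromium_log.get_stack():
        chromium_log.add_event_data("backend_compile", **kwargs)

# Called outside of a dynamo compile (e.g. a bare unit test), this does nothing
# instead of raising, which is the point of the guard.
try_record_chromium_data(dispatch_mode="inference")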
@@ -634,6 +645,9 @@ def _create_aot_dispatcher_function(
         req_subclass_dispatch = requires_subclass_dispatch(
             fake_flat_args, fw_metadata
         )
+        try_record_chromium_data(
+            requires_subclass_dispatch=req_subclass_dispatch
+        )
 
         output_and_mutation_safe = not any(
             x.requires_grad
@@ -752,10 +766,13 @@ or otherwise set torch._functorch.config.functionalize_rng_ops = False."""
         if aot_config.is_export:
             # export uses just the "graph bits", whereas the other
             # two dispatchers include some extra work around handling a runtime epilogue
+            try_record_chromium_data(dispatch_mode="export")
             return partial(aot_dispatch_export, needs_autograd=needs_autograd)
         elif needs_autograd and not aot_config.pre_dispatch:
+            try_record_chromium_data(dispatch_mode="autograd")
             return aot_dispatch_autograd
         else:
+            try_record_chromium_data(dispatch_mode="inference")
             return aot_dispatch_base
 
     compiler_fn = choose_dispatcher(needs_autograd, aot_config)