Log AOTAutogradCache state to PT2 Compile Events (#138604)
Same as the previous diff, which did this for the Inductor cache, but for the AOTAutograd cache instead.

Differential Revision: [D64765199](https://our.internmc.facebook.com/intern/diff/D64765199/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138604
Approved by: https://github.com/oulgen
This commit is contained in:
parent f1a677cba5, commit eb6c7b93a7
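For context, a minimal repro sketch of when this metadata gets emitted (hedged: `f` and its input are arbitrary examples, not from this PR). Any `torch.compile` run that reaches AOTAutograd opens a `backend_compile` event, and with this change that event also carries the autograd cache state.

# Minimal sketch (assumes a build containing this PR). Compiling any function
# through dynamo reaches AOTAutograd, and the "backend_compile" event in the
# PT2 compile trace now records whether the AOTAutograd cache hit, missed,
# or was bypassed.
import torch

@torch.compile
def f(x):
    return torch.sin(x) + x

f(torch.randn(8))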
@@ -33,6 +33,7 @@ from torch._inductor.codecache import (
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.utils import should_use_remote_fx_graph_cache
 from torch._logging import LazyString
+from torch._utils_internal import log_cache_bypass
 
 from .runtime_wrappers import (
     AOTDispatchAutograd,
@@ -439,11 +440,14 @@ class AOTAutogradCacheEntry:
 
         compiled_fw_func = self.compiled_fw.load(args, fx_config)
         compiled_bw_func = None
+        chromium_log = get_chromium_event_logger()
         if self.compiled_bw is not None:
             compiled_bw_func = self.compiled_bw.load(args, fx_config)
             needs_autograd = True
+            chromium_log.add_event_data("backend_compile", dispatch_mode="autograd")
         else:
             needs_autograd = False
+            chromium_log.add_event_data("backend_compile", dispatch_mode="inference")
 
         # Wrap the forward function in post compile wrappers
         compiled_fw_func = AOTDispatchSubclassWrapper(
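On a cache hit the entry already knows whether a backward was compiled, so the hunk above derives the dispatch mode from `self.compiled_bw` instead of re-running dispatcher selection. A standalone sketch of the same pattern, using only the logger calls that appear in this diff (the guard mirrors the `try_record_chromium_data` helper added further down):

from torch._dynamo.utils import get_chromium_event_logger

def record_dispatch_mode(has_compiled_backward: bool) -> None:
    # Attach the mode to the in-flight "backend_compile" event; the event only
    # exists while dynamo is driving a compile, hence the stack check.
    chromium_log = get_chromium_event_logger()
    if "backend_compile" in chromium_log.get_stack():
        mode = "autograd" if has_compiled_backward else "inference"
        chromium_log.add_event_data("backend_compile", dispatch_mode=mode)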
@@ -455,6 +459,11 @@ class AOTAutogradCacheEntry:
             compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
         )
 
+        req_subclass_dispatch = self.maybe_subclass_meta is not None
+        chromium_log.add_event_data(
+            "backend_compile", requires_subclass_dispatch=req_subclass_dispatch
+        )
+
         # In autograd case, functionalizedRngWrapper should not modify outs
         return_new_outs = not needs_autograd
         compiled_fw_func = FunctionalizedRngRuntimeWrapper(
@@ -619,6 +628,9 @@ class AOTAutogradCache:
             counters["aot_autograd"]["autograd_cache_bypass"] += 1
             cache_state = "bypass"
             cache_event_time = time.time_ns()
+            cache_info["cache_bypass_reason"] = str(e)
+            if remote:
+                log_cache_bypass("bypass_aot_autograd", str(e))
             if config.strict_autograd_cache:
                 raise e
         if compiled_fn is None:
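The bypass path now records why the cache was skipped and, for remote caches, forwards the reason to `log_cache_bypass`. When debugging, the `strict_autograd_cache` flag referenced above turns silent bypasses into hard errors. A hedged sketch, assuming `config` in this file is `torch._functorch.config` (only the flag name is taken from the diff; the module path is an assumption):

# Hedged debugging sketch: surface cache bypasses as exceptions instead of
# silently falling back to a fresh compile.
import torch
import torch._functorch.config as functorch_config

functorch_config.strict_autograd_cache = True

@torch.compile
def h(x):
    return x.cos()

h(torch.randn(2))  # any BypassAOTAutogradCache raised internally now propagates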
@@ -638,6 +650,18 @@ class AOTAutogradCache:
         chromium_log.log_instant_event(
             f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info
         )
+
+        chromium_log.add_event_data(
+            "backend_compile",
+            cache_state=cache_state,
+            cache_event_time=cache_event_time,
+            key=cache_info.get("key"),
+            components=cache_info.get("components"),
+            cache_bypass_reason=cache_info.get("cache_bypass_reason"),
+            remote_cache_enabled=remote,
+            local_cache_enabled=local,
+        )
+
         torch._logging.trace_structured(
             "artifact",
             metadata_fn=lambda: {
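Taken together, the lookup path still emits the instant `autograd_cache_{cache_state}` event and now also attaches the same information to the long-lived `backend_compile` event, so trace viewers can filter compiles by cache outcome without correlating two event streams. Roughly, the added fields look like this (illustrative values only; the hunk above only reads `key`, `components`, and `cache_bypass_reason` from `cache_info`):

# Illustrative shape of the event data attached above (values are made up):
example_backend_compile_data = {
    "cache_state": "bypass",           # "bypass" is set above; other states are set elsewhere in this file
    "cache_event_time": 1729890000000000000,  # time.time_ns() captured at lookup
    "key": None,                       # cache key, when one was computed
    "components": None,                # optional breakdown of how the key was built
    "cache_bypass_reason": "example bypass reason",  # only populated on the bypass path
    "remote_cache_enabled": False,
    "local_cache_enabled": True,
}

The remaining hunks below touch the dispatcher entry point, `_create_aot_dispatcher_function`, which lives in a separate file from the cache code above.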
@@ -15,7 +15,11 @@ from torch import Tensor
 from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo import compiled_autograd
-from torch._dynamo.utils import dynamo_timed, preserve_rng_state
+from torch._dynamo.utils import (
+    dynamo_timed,
+    get_chromium_event_logger,
+    preserve_rng_state,
+)
 from torch._guards import detect_fake_mode
 from torch._inductor.utils import BoxedBool
 from torch._subclasses import FakeTensor, FakeTensorMode
@@ -581,6 +585,13 @@ def _create_aot_dispatcher_function(
         enable_python_dispatcher() if shape_env is not None else nullcontext()
     )
 
+    def try_record_chromium_data(**kwargs):
+        # `backend_compile` only exists as an event if we are compiling with dynamo
+        # In some unit tests we don't use dynamo, so we ignore those cases
+        chromium_log = get_chromium_event_logger()
+        if "backend_compile" in chromium_log.get_stack():
+            chromium_log.add_event_data("backend_compile", **kwargs)
+
     # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
     # If any saved tensor hooks are active, we **don't** want to trace them.
     # Instead, we'll let them run at runtime, around the custom autograd.Function
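The helper above is deliberately a guarded no-op: `backend_compile` is pushed onto the chromium logger's event stack only while dynamo drives a compile, so recording data for it in a bare unit test would have nothing to attach to. A standalone check of that precondition, using only calls shown in this diff:

# Outside of any torch.compile call, "backend_compile" should not be on the
# logger's stack, which is exactly the case the helper guards against.
from torch._dynamo.utils import get_chromium_event_logger

chromium_log = get_chromium_event_logger()
print("backend_compile" in chromium_log.get_stack())  # expected: False when not compiling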
@@ -634,6 +645,9 @@ def _create_aot_dispatcher_function(
         req_subclass_dispatch = requires_subclass_dispatch(
             fake_flat_args, fw_metadata
         )
+        try_record_chromium_data(
+            requires_subclass_dispatch=req_subclass_dispatch
+        )
 
         output_and_mutation_safe = not any(
             x.requires_grad
@@ -752,10 +766,13 @@ or otherwise set torch._functorch.config.functionalize_rng_ops = False."""
         if aot_config.is_export:
             # export uses just the "graph bits", whereas the other
             # two dispatchers include some extra work around handling a runtime epilogue
+            try_record_chromium_data(dispatch_mode="export")
             return partial(aot_dispatch_export, needs_autograd=needs_autograd)
         elif needs_autograd and not aot_config.pre_dispatch:
+            try_record_chromium_data(dispatch_mode="autograd")
             return aot_dispatch_autograd
         else:
+            try_record_chromium_data(dispatch_mode="inference")
             return aot_dispatch_base
 
     compiler_fn = choose_dispatcher(needs_autograd, aot_config)
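With the three `try_record_chromium_data` calls above, every compile records which dispatcher it took: export goes through `aot_dispatch_export`, compiles whose inputs need gradients go through the autograd dispatcher, and everything else is treated as inference. A hedged illustration (the exact recompile behavior depends on dynamo's guards):

# Hedged illustration of which dispatch_mode a compile should record.
import torch

@torch.compile
def g(x):
    return x * 2

g(torch.randn(4, requires_grad=True))  # needs_autograd=True  -> dispatch_mode="autograd"
g(torch.randn(4))                      # recompiles without grad -> dispatch_mode="inference"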