From cb15c1515778499ae801dcf67d55c8bdab4724ef Mon Sep 17 00:00:00 2001
From: Sam Larsen
Date: Sat, 9 Nov 2024 10:30:47 -0800
Subject: [PATCH] [logging] Overhaul dynamo_timed and CompilationMetrics logging. (#139849)

Here's the overview:

There's a new contextmanager singleton called MetricsContext. Entering the MetricsContext is how we demarcate the boundary on which we'll create a single CompilationMetrics object, and therefore, a single dynamo_compile log entry. While we're inside the MetricsContext, we can update/set many different metrics. Most importantly: `dynamo_timed` can also update the in-progress MetricsContext. In the proposal here, we tell `dynamo_timed` that we want it to do so by providing the name of the MetricsContext field to increment. There can be many `dynamo_timed` calls in different parts of the code updating different fields. Then when the MetricsContext exits, that's when the logging of everything gathered finally happens.

One potential footgun is trying to use `dynamo_timed` when we haven't entered the MetricsContext, but we assert on that problem. Another problem is that we re-enter the context recursively, but we watch for that and do the logging only when the outermost exits.

Some specifics:
* Introduce MetricsContext - a context manager that on exit, records the CompilationMetrics (which also logs to dynamo_compile).
* Completely remove the concept of frame_phase_timing. Instead, update the MetricsContext during compilation, either directly or via dynamo_timed.
* Remove some globals we previously used to accumulate counters to later populate a CompilationMetrics. We use CompilationMetrics set/update/increment APIs instead.
* `record_compilation_metrics` is now called on exit from MetricsContext.
* Populate legacy CompilationMetrics fields right before logging, inside `record_compilation_metrics`.
* Remove the one-off `add_remote_cache_time_saved` helper; capture that timing directly into the MetricsContext.

And specifically, several changes to dynamo_timed:
* "Modernize" the parameters and update all callsites accordingly.
* Move the backwards logging of the CompilationMetrics to the backwards compile location.
* Add a parameter for which CompilationMetrics field to update Pull Request resolved: https://github.com/pytorch/pytorch/pull/139849 Approved by: https://github.com/ezyang ghstack dependencies: #140094 --- test/dynamo/test_metrics_context.py | 78 ++++ test/dynamo/test_utils.py | 14 +- torch/_dynamo/convert_frame.py | 160 +++----- torch/_dynamo/metrics_context.py | 95 +++++ torch/_dynamo/output_graph.py | 1 + torch/_dynamo/utils.py | 352 ++++++++---------- .../_aot_autograd/autograd_cache.py | 2 +- .../_aot_autograd/runtime_wrappers.py | 54 ++- torch/_inductor/codecache.py | 9 +- torch/_inductor/compile_fx.py | 2 +- torch/_inductor/graph.py | 2 +- torch/_inductor/remote_cache.py | 4 +- torch/_inductor/runtime/runtime_utils.py | 8 +- torch/_logging/_internal.py | 1 - 14 files changed, 456 insertions(+), 326 deletions(-) create mode 100644 test/dynamo/test_metrics_context.py create mode 100644 torch/_dynamo/metrics_context.py diff --git a/test/dynamo/test_metrics_context.py b/test/dynamo/test_metrics_context.py new file mode 100644 index 00000000000..27cf6325246 --- /dev/null +++ b/test/dynamo/test_metrics_context.py @@ -0,0 +1,78 @@ +# Owner(s): ["module: dynamo"] + +from torch._dynamo.metrics_context import MetricsContext +from torch._dynamo.test_case import run_tests, TestCase + + +class TestMetricsContext(TestCase): + def setUp(self): + super().setUp() + self.metrics = {} + + def _on_exit(self, metrics): + # Save away the metrics to be validated in the test. + self.metrics = metrics.copy() + + def test_context_exists(self): + """ + Setting a value without entering the context should raise. + """ + context = MetricsContext(self._on_exit) + with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"): + context.increment("m", 1) + + with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"): + context.set("m", 1) + + with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"): + context.update({"m", 1}) + + def test_nested_context(self): + """ + Only the outermost context should get an on_exit call, and it should + include everything. + """ + context = MetricsContext(self._on_exit) + with context: + with context: + context.set("m1", 1) + self.assertEqual(self.metrics, {}) + context.set("m2", 2) + self.assertEqual(self.metrics, {"m1": 1, "m2": 2}) + + def test_set(self): + """ + Validate various ways to set metrics. + """ + with MetricsContext(self._on_exit) as context: + context.set("m1", 1) + context.set("m2", 2) + context.update({"m3": 3, "m4": 4}) + + self.assertEqual(self.metrics, {"m1": 1, "m2": 2, "m3": 3, "m4": 4}) + + def test_set_disallow_overwrite(self): + """ + Validate set won't overwrite. + """ + with MetricsContext(self._on_exit) as context: + context.set("m1", 1) + with self.assertRaisesRegex(RuntimeError, "already been set"): + context.set("m1", 2) + + self.assertEqual(self.metrics, {"m1": 1}) + + def test_update_disallow_overwrite(self): + """ + Validate update won't overwite. 
+ """ + with MetricsContext(self._on_exit) as context: + context.update({"m1": 1, "m2": 2}) + with self.assertRaisesRegex(RuntimeError, "already been set"): + context.update({"m1": 7, "m3": 3}) + + self.assertEqual(self.metrics, {"m1": 1, "m2": 2}) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index c69cd02cf09..8a8abea173e 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -159,6 +159,7 @@ class TestDynamoTimed(TestCase): '_recursive_post_grad_passes': [0.0, 0.0], '_recursive_pre_grad_passes': [0.0], 'async_compile.wait': [0.0, 0.0], + 'backward._backward_impl': [0.0], 'compile_file': [0.0, 0.0], 'compile_fx..bw_compiler': [0.0], 'compile_fx..fw_compiler_base': [0.0], @@ -174,6 +175,7 @@ class TestDynamoTimed(TestCase): """\ {'backend_compile': 0.0, 'code_gen': 0.0, + 'entire_backward_compile': 0.0, 'entire_frame_compile': 0.0, 'inductor_compile': 0.0, 'total_wall_time': 0.0}""", # noqa: B950 @@ -200,6 +202,7 @@ class TestDynamoTimed(TestCase): {'accumulated_cache_size': 0, 'aot_autograd_cumulative_compile_time_us': 0, 'backend_compile_time_s': 0.0, + 'backward_cumulative_compile_time_us': None, 'cache_size': 0, 'co_filename': None, 'co_firstlineno': None, @@ -210,12 +213,13 @@ class TestDynamoTimed(TestCase): 'config_inline_inbuilt_nn_modules': False, 'config_suppress_errors': False, 'cuda_synchronize_time_us': None, - 'distributed_ephemeral_timeout_us': 0, + 'distributed_ephemeral_timeout_us': None, 'duration_us': 0, 'dynamo_compile_time_before_restart_us': 0, 'dynamo_config': None, 'dynamo_cumulative_compile_time_us': 0, 'dynamo_time_before_restart_s': 0.0, + 'end_time_us': 100, 'entire_frame_compile_time_s': 0.0, 'fail_reason': None, 'fail_type': None, @@ -231,9 +235,10 @@ class TestDynamoTimed(TestCase): 'inductor_compile_time_s': 0.0, 'inductor_cumulative_compile_time_us': 0, 'is_forward': True, + 'log_format_version': 2, 'non_compliant_ops': set(), 'num_triton_bundles': None, - 'remote_cache_time_saved_s': 0, + 'remote_cache_time_saved_s': None, 'remote_fx_graph_cache_get_time_ms': None, 'remote_fx_graph_cache_get_time_us': None, 'remote_fx_graph_cache_put_time_ms': None, @@ -257,6 +262,7 @@ class TestDynamoTimed(TestCase): {'accumulated_cache_size': None, 'aot_autograd_cumulative_compile_time_us': None, 'backend_compile_time_s': None, + 'backward_cumulative_compile_time_us': 0, 'cache_size': None, 'co_filename': None, 'co_firstlineno': None, @@ -273,6 +279,7 @@ class TestDynamoTimed(TestCase): 'dynamo_config': None, 'dynamo_cumulative_compile_time_us': None, 'dynamo_time_before_restart_s': None, + 'end_time_us': 100, 'entire_frame_compile_time_s': None, 'fail_reason': None, 'fail_type': None, @@ -288,6 +295,7 @@ class TestDynamoTimed(TestCase): 'inductor_compile_time_s': 0.0, 'inductor_cumulative_compile_time_us': 0, 'is_forward': False, + 'log_format_version': 2, 'non_compliant_ops': None, 'num_triton_bundles': None, 'remote_cache_time_saved_s': None, @@ -302,7 +310,7 @@ class TestDynamoTimed(TestCase): 'specialize_float': None, 'start_time': None, 'start_time_us': 100, - 'structured_logging_overhead_s': 0.0, + 'structured_logging_overhead_s': None, 'structured_logging_overhead_us': 0, 'triton_compile_time_us': None}""", # noqa: B950 ) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 093b53865d8..55a4f73c2e6 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -30,7 +30,7 @@ import torch import torch._logging from 
torch._C._dynamo.guards import GlobalStateGuard from torch._dynamo.distributed import get_compile_pg -from torch._dynamo.utils import CompileTimeInstructionCounter +from torch._dynamo.utils import CompileTimeInstructionCounter, get_metrics_context from torch._guards import compile_context, CompileContext, CompileId, tracing from torch._logging import structured from torch._utils_internal import ( @@ -105,12 +105,9 @@ from .symbolic_convert import ( from .trace_rules import is_numpy from .utils import ( CleanupManager, - codecache_metrics, - CompilationMetrics, counters, dynamo_timed, format_bytecode, - frame_phase_timing, gen_record_file_name, get_chromium_event_logger, increment_frame, @@ -118,10 +115,8 @@ from .utils import ( istype, LazyString, orig_code_map, - record_compilation_metrics, reset_graph_break_dup_checker, setup_compile_debug, - to_int_ms, to_int_us, troubleshooting_url, write_record_to_file, @@ -699,7 +694,9 @@ def _compile( with contextlib.ExitStack() as stack: stack.enter_context( dynamo_timed( - "_compile.compile_inner", phase_name="entire_frame_compile" + "_compile.compile_inner", + phase_name="entire_frame_compile", + dynamo_compile_column_us="dynamo_cumulative_compile_time_us", ) ) stack.enter_context( @@ -864,9 +861,11 @@ def _compile( chromium_event_log.reset() chromium_start_time = time.time_ns() chromium_event_log.log_event_start("dynamo", chromium_start_time, {}) + + metrics_context = get_metrics_context() with _use_lazy_graph_module(config.use_lazy_graph_module), compile_context( CompileContext(compile_id) - ): + ), metrics_context: restart_reasons: set[str] = set() # This is shared across restarts mutated_closure_cell_ids: Set[int] = set() @@ -974,7 +973,6 @@ def _compile( fail_user_frame_lineno: Optional[int] = None torch._dynamo.utils.ReinplaceCounters.clear() guarded_code = None - codecache_metrics.clear() try: guarded_code = compile_inner(code, one_graph, hooks, transform) @@ -991,6 +989,7 @@ def _compile( return guarded_code except Exception as e: + # TODO(masnesral): Populating the exception info should be automatic fail_type = type(e).__qualname__ fail_reason = str(e) # NB: e's msg is mutated here to add user stack, but we DON'T want @@ -1040,66 +1039,34 @@ def _compile( if tracer: tracer.output.local_scope = {} - duration_ns = time.time_ns() - start_time_ns + end_time_ns = time.time_ns() + duration_ns = end_time_ns - start_time_ns from .utils import curr_frame frame_key = str(curr_frame) - if ( - fail_reason is None - and output is not None - and frame_key in frame_phase_timing - ): + if fail_reason is None and output is not None: guard_count = len(output.guards) shape_env_guard_count = len(output.shape_env.guards) graph_op_count = output.count_calls() graph_node_count = len(output.graph.nodes) graph_input_count = len(output.placeholders) - entire_frame_compile_time = frame_phase_timing[frame_key].get( - "entire_frame_compile", None - ) - backend_compile_time = frame_phase_timing[frame_key].get( - "backend_compile", None - ) - inductor_compile_time = frame_phase_timing[frame_key].get( - "inductor_compile", None - ) - code_gen_time = frame_phase_timing[frame_key].get("code_gen", None) non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops} compliant_custom_ops = { op.__qualname__ for op in output.compliant_custom_ops } - remote_cache_time_saved = frame_phase_timing[frame_key].get( - "remote_cache_time_saved", 0 - ) - remote_fx_graph_cache_get_time = frame_phase_timing[frame_key].get( - "remote_fx_graph_cache_get", None - ) - 
remote_fx_graph_cache_put_time = frame_phase_timing[frame_key].get( - "remote_fx_graph_cache_put", None - ) - num_triton_bundles = codecache_metrics.get("num_triton_bundles", None) torch._dynamo.utils.ReinplaceCounters.log() - else: guard_count = None shape_env_guard_count = None graph_op_count = None graph_node_count = None graph_input_count = None - entire_frame_compile_time = None - backend_compile_time = None - inductor_compile_time = None - code_gen_time = None non_compliant_ops = set({}) compliant_custom_ops = set({}) restart_reasons = set() # If compilation failed, the entire time is wasted dynamo_time_before_restart = duration_ns / 1e9 - remote_cache_time_saved = None - remote_fx_graph_cache_get_time = None - remote_fx_graph_cache_put_time = None - num_triton_bundles = None structured_logging_overhead_s = ( torch._logging.get_structured_logging_overhead() @@ -1134,74 +1101,55 @@ def _compile( } config_dict = clean_for_json(config.get_config_copy()) - metrics = CompilationMetrics( - str(compile_id), - frame_key, - code.co_name, - code.co_filename, - code.co_firstlineno, - cache_size.num_cache_entries_with_same_id_matched_objs, - cache_size.num_cache_entries, - guard_count, - shape_env_guard_count, - graph_op_count, - graph_node_count, - graph_input_count, - start_time_ns / 1e9, - entire_frame_compile_time, - backend_compile_time, - inductor_compile_time, - code_gen_time, - fail_type, - fail_reason, - fail_user_frame_filename, - fail_user_frame_lineno, - non_compliant_ops, - compliant_custom_ops, - restart_reasons, - dynamo_time_before_restart, - guarded_code is not None, - remote_cache_time_saved, - structured_logging_overhead_s, - config.suppress_errors, - config.inline_inbuilt_nn_modules, - config.specialize_float, - json.dumps(config_dict), - True, # is_forward - num_triton_bundles, - to_int_ms(remote_fx_graph_cache_get_time), - to_int_ms(remote_fx_graph_cache_put_time), - start_time_us=start_time_ns // 1000, - duration_us=duration_ns // 1000, - dynamo_cumulative_compile_time_us=to_int_us(entire_frame_compile_time), - aot_autograd_cumulative_compile_time_us=to_int_us(backend_compile_time), - inductor_cumulative_compile_time_us=to_int_us(inductor_compile_time), - inductor_code_gen_cumulative_compile_time_us=to_int_us(code_gen_time), - triton_compile_time_us=None, # TODO: instrument - runtime_cudagraphify_time_us=None, # TODO: instrument in separate event - runtime_triton_autotune_time_us=None, # TODO: instrument in separate event - dynamo_compile_time_before_restart_us=to_int_us( + metrics = { + "compile_id": str(compile_id), + "frame_key": frame_key, + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs, + "accumulated_cache_size": cache_size.num_cache_entries, + "guard_count": guard_count, + "shape_env_guard_count": shape_env_guard_count, + "graph_op_count": graph_op_count, + "graph_node_count": graph_node_count, + "graph_input_count": graph_input_count, + # TODO(masnesral): start_time and end_time shouldn't need to be + # populated manually. 
+ "start_time": start_time_ns / 1e9, + "fail_type": fail_type, + "fail_reason": fail_reason, + "fail_user_frame_filename": fail_user_frame_filename, + "fail_user_frame_lineno": fail_user_frame_lineno, + "non_compliant_ops": non_compliant_ops, + "compliant_custom_ops": compliant_custom_ops, + "restart_reasons": restart_reasons, + "dynamo_time_before_restart_s": dynamo_time_before_restart, + "has_guarded_code": guarded_code is not None, + "structured_logging_overhead_s": structured_logging_overhead_s, + "config_suppress_errors": config.suppress_errors, + "config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules, + "specialize_float": config.specialize_float, + "dynamo_config": json.dumps(config_dict), + "is_forward": True, + "start_time_us": start_time_ns // 1000, + "end_time_us": end_time_ns // 1000, + "duration_us": duration_ns // 1000, + "dynamo_compile_time_before_restart_us": to_int_us( dynamo_time_before_restart ), - cuda_synchronize_time_us=None, # TODO: instrument - distributed_ephemeral_timeout_us=to_int_us( - remote_cache_time_saved - ), # TODO: instrument more accurately - structured_logging_overhead_us=to_int_us(structured_logging_overhead_s), - remote_fx_graph_cache_get_time_us=to_int_us( - remote_fx_graph_cache_get_time + "structured_logging_overhead_us": to_int_us( + structured_logging_overhead_s ), - remote_fx_graph_cache_put_time_us=to_int_us( - remote_fx_graph_cache_put_time - ), - ) - record_compilation_metrics(metrics) + } + metrics_context.update_outer(metrics) torch._dynamo.callback_handler.run_end_callbacks() - chromium_event_log.log_event_end( - "dynamo", time.time_ns(), {}, chromium_start_time, True - ) # === END WARNING WARNING WARNING === + chromium_event_log.log_event_end( + "dynamo", time.time_ns(), {}, chromium_start_time, True + ) + class ConvertFrame: def __init__(self, compiler_fn: CompilerFn, hooks: Hooks) -> None: diff --git a/torch/_dynamo/metrics_context.py b/torch/_dynamo/metrics_context.py new file mode 100644 index 00000000000..a51ad52fe9b --- /dev/null +++ b/torch/_dynamo/metrics_context.py @@ -0,0 +1,95 @@ +from typing import Any, Callable, Dict, Optional, Type +from typing_extensions import TypeAlias + + +OnExitType: TypeAlias = Callable[[Dict[str, Any]], None] + + +class MetricsContext: + def __init__(self, on_exit: OnExitType): + """ + Use this class as a contextmanager to create a context under which to accumulate + a set of metrics, e.g., metrics gathered during a compilation. On exit of the + contextmanager, call the provided 'on_exit' function and pass a dictionary of + all metrics set during the lifetime of the contextmanager. + """ + self._on_exit = on_exit + self._metrics: Dict[str, Any] = {} + self._level = 0 + + def __enter__(self) -> "MetricsContext": + """ + Initialize metrics recording. + """ + if self._level == 0: + # In case of recursion, track at the outermost context. + self._metrics = {} + + self._level += 1 + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + _traceback: Any, + ) -> None: + """ + At exit, call the provided on_exit function. + """ + self._level -= 1 + assert self._level >= 0 + if self._level == 0: + self._on_exit(self._metrics) + + def in_progress(self) -> bool: + """ + True if we've entered the context. + """ + return self._level > 0 + + def increment(self, metric: str, value: int) -> None: + """ + Increment a metric by a given amount. 
+ """ + if self._level == 0: + raise RuntimeError(f"Cannot increment {metric} outside of a MetricsContext") + if metric not in self._metrics: + self._metrics[metric] = 0 + self._metrics[metric] += value + + def set(self, metric: str, value: Any) -> None: + """ + Set a metric to a given value. Raises if the metric has been assigned previously + in the current context. + """ + if self._level == 0: + raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext") + if metric in self._metrics: + raise RuntimeError( + f"Metric '{metric}' has already been set in the current context" + ) + self._metrics[metric] = value + + def update(self, values: Dict[str, Any]) -> None: + """ + Set multiple metrics directly. This method does NOT increment. Raises if any + metric has been assigned previously in the current context. + """ + if self._level == 0: + raise RuntimeError("Cannot update metrics outside of a MetricsContext") + existing = self._metrics.keys() & values.keys() + if existing: + raise RuntimeError( + f"Metric(s) {existing} have already been set in the current context" + ) + self._metrics.update(values) + + def update_outer(self, values: Dict[str, Any]) -> None: + """ + Update, but only when at the outermost context. + """ + if self._level == 0: + raise RuntimeError("Cannot update metrics outside of a MetricsContext") + if self._level == 1: + self.update(values) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 67792433b9c..a667538cc35 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1391,6 +1391,7 @@ class OutputGraph: "OutputGraph.call_user_compiler", phase_name="backend_compile", log_pt2_compile_event=True, + dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us", ): return self._call_user_compiler(gm) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 65095e7daa9..9319ddfebe7 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -43,6 +43,7 @@ from typing import ( DefaultDict, Deque, Dict, + Generator, Iterable, Iterator, KeysView, @@ -71,6 +72,7 @@ from torch._C import ( _push_on_torch_function_stack, ) from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.metrics_context import MetricsContext from torch._guards import Source, TracingContext from torch._subclasses.meta_utils import is_sparse_compressed from torch._utils_internal import ( @@ -139,12 +141,10 @@ log = logging.getLogger(__name__) # profiling compilation time by function compilation_time_metrics: Dict[str, List[float]] = {} -# profiling compilation time by frame phase -frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict( - lambda: collections.defaultdict(float) -) - -codecache_metrics: Counter[str] = collections.Counter() +# This supports calculate_time_spent(), which reports cumulative times +# across the process for any "phase" populated by dynamo_timed. Reset if +# reset_frame_count() is called. +cumulative_time_spent_ns: Dict[str, float] = collections.defaultdict(float) timer_counter = itertools.count() @@ -220,7 +220,7 @@ def increment_frame() -> None: # Note: Called for you by dynamo - you almost never ever want to invoke this yourself. 
def reset_frame_count() -> None: global curr_frame - frame_phase_timing.clear() + cumulative_time_spent_ns.clear() compilation_time_metrics.clear() curr_frame = 0 @@ -233,25 +233,16 @@ def increment_op_count(cnt: int) -> None: op_count += cnt -# Calculate total time spent so far for each phase +# Get the total time in seconds for each "phase" # For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806} def calculate_time_spent() -> Dict[str, float]: - total_wall_time = 0.0 total_by_key = {} - for timings in frame_phase_timing.values(): - total_wall_time += timings.get( - "entire_frame_compile", timings.get("inductor_compile", 0) - ) - - for key, timing in timings.items(): - if key not in total_by_key: - total_by_key[key] = timing - else: - total_by_key[key] += timing - - if total_by_key: - total_by_key["total_wall_time"] = total_wall_time + for phase, timing in cumulative_time_spent_ns.items(): + total_by_key[phase] = timing / 1e9 + total_by_key["total_wall_time"] = total_by_key.get( + "entire_frame_compile", 0 + ) + total_by_key.get("entire_backward_compile", 0) return total_by_key @@ -270,188 +261,124 @@ def print_time_report() -> None: print(out) -def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None: - frame_phase_timing[key][phase_name] += time_spent +# Use the following singleton to capture and log CompilationMetrics. Entering the context +# manager allocates a new record to be logged when it exits. (You should not need to use +# this directly unless you introduce a new code path where compilation metrics would be +# gathered). While compiling, use the setters or timer in MetricsContext to update fields +# in the current context. For example: +# +# To set a single field once (use overwrite=True to overwrite): +# get_metrics_context().set("metric_name", value) +# +# To set multiple fields at once (use overwrite=True to overwrite): +# get_metrics_context().update({"name1": val1, "name2": val2}) +# +# To increment an integer field: +# get_metrics_context().increment("metric_name", value) +# +# To record execution time, MetricsContext works with dynamo_timed: +# def foo(...): +# # Updates the "metric_us" field. +# with dynamo_timed("metric", dynamo_compile_column_us="metric_us") +# ... +# +_METRICS_CONTEXT: MetricsContext -# Use frame_phase_timing to record remote_cache_time_saved -# This follows the same principles of key as the other frame phase timings, -# but is incremented by FxGraphCache (and later AOTAutogradCache) directly -def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -> None: - key = None - if is_backward: - # Use compile id as the frame key for backwards compilation - key = str(torch._guards.CompileContext.current_compile_id()) - else: - key = str(curr_frame) - # Convert to seconds (as a float) - time_saved = time_saved_ns / 1e9 - _add_time_spent(key, "remote_cache_time_saved", time_saved) - - -# dynamo_timed is a context manager -# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics -# where the key is the functions name. -# For example: -# -# def _foo(...): -# with dynamo_timed("_foo"): -# ... -# -# Would show up as an entry in our timing dict: -# OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])]) -# This is extremely useful for granular debugging. -# -# Although it is tempting to use dynamo_timed as a decorator, please do not. 
-# In its decorator form it makes cProfile traces less useful as dynamo_timed -# suddenly becomes a bottleneck for lots of function calls (as only one parent -# pointer is recorded). -# -# For a higher-level mode, pass a phase_name into dynamo_timed -# phase_names record an extra record into a separate compilation timing structure, -# one keyed on frame+name rather than function. -# The frame is incremented outside of this function, in def increment_frame() above. -# `fwd_only` is used to identify if this phase or function is only called -# during compiling fwd graphs, e.g, `entire_frame_compile` and `backend_compile`. -# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs. +def get_metrics_context() -> MetricsContext: + return _METRICS_CONTEXT @contextmanager def dynamo_timed( key: str, + # TODO(masneral): Deprecate this param. phase_name: Optional[str] = None, - log_pt2_compile_event: bool = False, # Whether or not to log it to internal pt2 compile event + log_pt2_compile_event: bool = False, + # TODO(masnesral): fwd_only is ignored. Remove it. fwd_only: bool = True, -): - chromium_log: ChromiumEventLogger = get_chromium_event_logger() + metadata: Optional[Dict[str, object]] = None, + dynamo_compile_column_us: Optional[str] = None, +) -> Generator[Any, None, None]: + """ + dynamo_timed is a context manager + By wrapping a function in dynamo_timed, we can get a few things: + + 1) Log timings to pt2_compile_events. + 2) Log timings to CompilationMetrics (dynamo_compile). + 3) Chromium events. + 4) Storing a record in compilation_time_metrics + For example: + + def _foo(...): + with dynamo_timed("_foo"): + ... + + Would show up as an entry in our timing dict: + OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])]) + This is extremely useful for granular debugging. + + Although it is tempting to use dynamo_timed as a decorator, please do not. + In its decorator form it makes cProfile traces less useful as dynamo_timed + suddenly becomes a bottleneck for lots of function calls (as only one parent + pointer is recorded). + + Params: + - key: key into compile_time_metrics. If phase_name is not provided, this is + also the event name used for pt2_compile_events logs and chromium events. + - phase_name: Optional override for the event name. + - log_pt2_compile_event: Whether to log a pt2 compile event internally. + - metadata: Extra metadata to put in pt2_compile_events. + - dynamo_compile_column_us: If provided, updates the specified CompilationMetrics + field to be logged to dyname_compile column. We expect all columns to be _us; + therefore, the field name must end with "_us". + """ + # We're standardizing on microseconds for dynamo_compile timings. 
+ if dynamo_compile_column_us is not None: + assert dynamo_compile_column_us.endswith("_us") + + if phase_name: + event_name = phase_name + fn_name = key + else: + event_name = key + fn_name = None + if key not in compilation_time_metrics: compilation_time_metrics[key] = [] - fail_type: Optional[str] = None - fail_reason: Optional[str] = None - time_spent = float("-inf") + event_metadata = {} + if metadata: + event_metadata.update(metadata) + if fn_name: + event_metadata.update({"fn_name": fn_name}) + + chromium_log: ChromiumEventLogger = get_chromium_event_logger() start_ns = time.time_ns() + chromium_log.log_event_start(event_name, start_ns, event_metadata) + try: with torch.profiler.record_function(f"{key} (dynamo_timed)"): - t0 = time.time() - if phase_name: - chromium_log.log_event_start(phase_name, start_ns, {"fn_name": key}) - else: - chromium_log.log_event_start(key, start_ns, {}) yield - time_spent = time.time() - t0 - compilation_time_metrics[key].append(time_spent) - except Exception as e: - fail_type = str(type(e)) - fail_reason = str(e) - raise finally: end_ns = time.time_ns() - # Always log the end event even on exception - if phase_name: - chromium_log.log_event_end( - phase_name, - end_ns, - {}, - start_ns, - log_pt2_compile_event, - ) - else: - chromium_log.log_event_end(key, end_ns, {}, start_ns, log_pt2_compile_event) - # Only record backward compilation metrics if phase_name is not None! - if phase_name: - frame_key = str(curr_frame) - # fwd only compilation stages: entire_frame_compile, backend_compile, aotdispatch. - # use frame_key as time aggregation key. - if fwd_only and fail_type is None: - _add_time_spent(frame_key, phase_name, time_spent) - else: - # fwd + bwd compilation stages: inductor_compile, code_gen. - # use frame_key as time aggregation key for fwd graphs; - # use compile_id as time aggregation key for bwd graphs. - if torch._guards.TracingContext.try_get() is not None: - aot_graph_name = str( - torch._guards.TracingContext.get().aot_graph_name - ) - if ( - "forward" in aot_graph_name or "inference" in aot_graph_name - ) and fail_type is None: - _add_time_spent(frame_key, phase_name, time_spent) - elif "backward" in aot_graph_name: - compile_id = str( - torch._guards.CompileContext.current_compile_id() - ) - if fail_type is None: - _add_time_spent(compile_id, phase_name, time_spent) - - # log backward compilation metrics at the end of `inductor_compile` of bwd graph, - # one record for one bwd graph. 
- if phase_name == "inductor_compile": - if fail_type is None: - inductor_compile_time = frame_phase_timing[ - compile_id - ].get("inductor_compile", None) - code_gen_time = frame_phase_timing[compile_id].get( - "code_gen", None - ) - remote_cache_time_saved = frame_phase_timing[ - compile_id - ].get("remote_cache_time_saved", None) - remote_fx_graph_cache_get_time = frame_phase_timing[ - compile_id - ].get("remote_fx_graph_cache_get", None) - remote_fx_graph_cache_put_time = frame_phase_timing[ - compile_id - ].get("remote_fx_graph_cache_put", None) - else: - inductor_compile_time = None - code_gen_time = None - remote_cache_time_saved = None - remote_fx_graph_cache_get_time = None - remote_fx_graph_cache_put_time = None - structured_logging_overhead_s = ( - torch._logging.get_structured_logging_overhead() - ) - metrics = CompilationMetrics( - compile_id=compile_id, - inductor_compile_time_s=inductor_compile_time, - code_gen_time_s=code_gen_time, - fail_type=fail_type, - fail_reason=fail_reason, - remote_cache_time_saved_s=remote_cache_time_saved, - structured_logging_overhead_s=structured_logging_overhead_s, - is_forward=False, # is_forward - num_triton_bundles=codecache_metrics.get( - "num_triton_bundles", None - ), - remote_fx_graph_cache_get_time_ms=to_int_ms( - remote_fx_graph_cache_get_time - ), - remote_fx_graph_cache_put_time_ms=to_int_ms( - remote_fx_graph_cache_put_time - ), - start_time_us=start_ns // 1000, - duration_us=(end_ns - start_ns) // 1000, - inductor_cumulative_compile_time_us=to_int_us( - inductor_compile_time - ), - inductor_code_gen_cumulative_compile_time_us=to_int_us( - code_gen_time - ), - distributed_ephemeral_timeout_us=to_int_us( - remote_cache_time_saved - ), # TODO: instrument more accurately - structured_logging_overhead_us=to_int_us( - structured_logging_overhead_s - ), - remote_fx_graph_cache_get_time_us=to_int_us( - remote_fx_graph_cache_get_time - ), - remote_fx_graph_cache_put_time_us=to_int_us( - remote_fx_graph_cache_put_time - ), - ) - record_compilation_metrics(metrics) + time_spent_ns = end_ns - start_ns + compilation_time_metrics[key].append(time_spent_ns / 1e9) + chromium_log.log_event_end( + event_name, end_ns, {}, start_ns, log_pt2_compile_event + ) + if dynamo_compile_column_us: + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + metrics_context.increment( + dynamo_compile_column_us, time_spent_ns // 1000 + ) + # TODO: the events that we capture in calculate_time_spent() seem a little + # arbitrary. Currently, it's only those fields that are present in + # CompilationMetrics (but note that we accumulate by the associated event + # name, not the field name in CompilationMetrics). Do we want to keep it + # this way? + cumulative_time_spent_ns[event_name] += time_spent_ns @overload @@ -866,6 +793,11 @@ def to_int_us(v: Optional[float]) -> Optional[int]: return None if v is None else int(v * 1_000_000) +# Version field added to every log. Increment to make it easier to distinguish new +# vs. old entries when you make a substantive change to how the logs are populated. 
+LOG_FORMAT_VERSION = 2 + + @dataclasses.dataclass class CompilationMetrics: compile_id: Optional[str] = None @@ -913,15 +845,18 @@ class CompilationMetrics: aot_autograd_cumulative_compile_time_us: Optional[int] = None inductor_cumulative_compile_time_us: Optional[int] = None inductor_code_gen_cumulative_compile_time_us: Optional[int] = None - triton_compile_time_us: Optional[int] = None - runtime_cudagraphify_time_us: Optional[int] = None - runtime_triton_autotune_time_us: Optional[int] = None + triton_compile_time_us: Optional[int] = None # TODO: instrument + runtime_cudagraphify_time_us: Optional[int] = None # TODO: instrument + runtime_triton_autotune_time_us: Optional[int] = None # TODO: instrument dynamo_compile_time_before_restart_us: Optional[int] = None - cuda_synchronize_time_us: Optional[int] = None + cuda_synchronize_time_us: Optional[int] = None # TODO: instrument distributed_ephemeral_timeout_us: Optional[int] = None structured_logging_overhead_us: Optional[int] = None remote_fx_graph_cache_get_time_us: Optional[int] = None remote_fx_graph_cache_put_time_us: Optional[int] = None + backward_cumulative_compile_time_us: Optional[int] = None + end_time_us: Optional[int] = None + log_format_version: int = LOG_FORMAT_VERSION DEFAULT_COMPILATION_METRICS_LIMIT = 64 @@ -969,8 +904,32 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics): ) -def record_compilation_metrics(compilation_metrics: CompilationMetrics): - global _compilation_metrics +def record_compilation_metrics(metrics: Dict[str, Any]): + # TODO: Temporary; populate legacy fields from their replacements. + # Remove when we decide we can really deprecate them. + def us_to_s(field): + metric = metrics.get(field, None) + return metric / 1e6 if metric is not None else None + + def us_to_ms(field): + metric = metrics.get(field, None) + return metric // 1000 if metric is not None else None + + legacy_metrics = { + "entire_frame_compile_time_s": us_to_s("dynamo_cumulative_compile_time_us"), + "backend_compile_time_s": us_to_s("aot_autograd_cumulative_compile_time_us"), + "inductor_compile_time_s": us_to_s("inductor_cumulative_compile_time_us"), + "code_gen_time_s": us_to_s("inductor_code_gen_cumulative_compile_time_us"), + "remote_cache_time_saved_s": us_to_s("distributed_ephemeral_timeout_us"), + "remote_fx_graph_cache_get_time_ms": us_to_ms( + "remote_fx_graph_cache_get_time_us" + ), + "remote_fx_graph_cache_put_time_ms": us_to_ms( + "remote_fx_graph_cache_put_time_us" + ), + } + + compilation_metrics = CompilationMetrics(**{**metrics, **legacy_metrics}) _compilation_metrics.append(compilation_metrics) if compilation_metrics.is_forward: name = "compilation_metrics" @@ -979,10 +938,7 @@ def record_compilation_metrics(compilation_metrics: CompilationMetrics): name = "bwd_compilation_metrics" torch._logging.trace_structured( name, - lambda: { - k: list(v) if isinstance(v, set) else v - for k, v in dataclasses.asdict(compilation_metrics).items() - }, + lambda: {k: list(v) if isinstance(v, set) else v for k, v in metrics.items()}, # NB: Because compilation metrics *includes* the logging overhead time, # we can't both *measure* the logging overhead of compilation metrics # without making it inconsistent with compilation metrics itself, so @@ -993,6 +949,10 @@ def record_compilation_metrics(compilation_metrics: CompilationMetrics): log_compilation_event(compilation_metrics) +# record_compilation_metrics is called by the singleton MetricsContext exit handler. 
+_METRICS_CONTEXT = MetricsContext(on_exit=record_compilation_metrics) + + def set_compilation_metrics_limit(new_size: int) -> None: global _compilation_metrics while len(_compilation_metrics) > new_size: diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index bb674088293..13d7a1460d5 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -632,7 +632,7 @@ class AOTAutogradCache: ) # TODO: should we use the same field for remote cache time saved for both # FXGraphCache and AOTAutogradCache? - # add_remote_cache_time_saved(time_saved_ns, is_backward=False) + # get_metrics_context().increment(...) if ( ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( time_saved_ns diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 206bab56924..f421f30cb44 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -10,6 +10,7 @@ import builtins import collections import itertools import pprint +import time from contextlib import nullcontext from dataclasses import dataclass, field from functools import wraps @@ -18,6 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.utils.dlpack from torch import Tensor +from torch._dynamo.utils import dynamo_timed, get_metrics_context, to_int_us from torch._guards import ( compile_context, CompileContext, @@ -2002,17 +2004,49 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa context = torch._C._DisableAutocast if disable_amp else nullcontext with tracing(saved_context), compile_context( saved_compile_context - ), context(), track_graph_compiling(aot_config, "backward"): - CompiledFunction.compiled_bw = aot_config.bw_compiler( - bw_module, placeholder_list - ) - # Maybe save cache entry - if try_save_cache_entry is not None: - try_save_cache_entry( - CompiledFunction.compiled_bw, - fw_metadata, - aot_config, + ), context(), track_graph_compiling( + aot_config, "backward" + ), get_metrics_context(), dynamo_timed( + "backward._backward_impl", + phase_name="entire_backward_compile", + dynamo_compile_column_us="backward_cumulative_compile_time_us", + ): + fail_type: Optional[str] = None + fail_reason: Optional[str] = None + start_ns = time.time_ns() + try: + CompiledFunction.compiled_bw = aot_config.bw_compiler( + bw_module, placeholder_list ) + # Maybe save cache entry + if try_save_cache_entry is not None: + try_save_cache_entry( + CompiledFunction.compiled_bw, + fw_metadata, + aot_config, + ) + except Exception as e: + # TODO(masnesral): Populating the exception info should be automatic. + fail_type = type(e).__qualname__ + fail_reason = str(e) + finally: + # TODO(masnesral): Populating time fields should be automatic. 
+ end_ns = time.time_ns() + metrics = { + "compile_id": str( + torch._guards.CompileContext.current_compile_id() + ), + "fail_type": fail_type, + "fail_reason": fail_reason, + "is_forward": False, + "start_time_us": start_ns // 1000, + "end_time_us": end_ns // 1000, + "duration_us": (end_ns - start_ns) // 1000, + "structured_logging_overhead_us": to_int_us( + torch._logging.get_structured_logging_overhead(), + ), + } + get_metrics_context().update_outer(metrics) if ( torch._functorch.config.donated_buffer diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 2371ae38bfb..821d67514a2 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -54,11 +54,10 @@ import torch import torch.distributed as dist from torch import SymInt, Tensor from torch._dynamo.utils import ( - add_remote_cache_time_saved, - codecache_metrics, counters, dynamo_timed, get_chromium_event_logger, + get_metrics_context, ) from torch._inductor import config, exc, metrics from torch._inductor.codegen.cuda import cuda_env @@ -1152,7 +1151,7 @@ class FxGraphCache: "inductor_compile", cached_kernel_names=meta.cached_kernel_names ) if len(meta.cached_kernel_names) > 0: - codecache_metrics["num_triton_bundles"] += 1 + get_metrics_context().increment("num_triton_bundles", 1) inductor_meta = autotune_cache.inductor_meta_from_config() AutotuneCacheBundler.begin_compile(inductor_meta, code=code) @@ -1449,7 +1448,9 @@ class FxGraphCache: if (time_saved_ns := compiled_graph._time_taken_ns) is not None: cache_info["time_saved_ns"] = time_saved_ns - add_remote_cache_time_saved(time_saved_ns, is_backward) + get_metrics_context().increment( + "distributed_ephemeral_timeout_us", time_saved_ns // 1000 + ) if ( ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( time_saved_ns diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 1cd8d09043a..cd997c30e91 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -567,7 +567,7 @@ def compile_fx_inner( "compile_fx_inner", phase_name="inductor_compile", log_pt2_compile_event=True, - fwd_only=False, + dynamo_compile_column_us="inductor_cumulative_compile_time_us", ) ) # NB: Why is this the dynamo_compile counter? 
The rule here is that diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index fb96fb7c69b..e89897b87ad 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1949,7 +1949,7 @@ class GraphLowering(torch.fx.Interpreter): "GraphLowering.compile_to_module", phase_name="code_gen", log_pt2_compile_event=True, - fwd_only=False, + dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us", ): return self._compile_to_module() diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index d0359950064..4e53b920a42 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -45,14 +45,14 @@ remote_fx_cache_get_timed = functools.partial( "FbRemoteFxGraphCache.get", phase_name="remote_fx_graph_cache_get", log_pt2_compile_event=False, - fwd_only=False, + dynamo_compile_column_us="remote_fx_graph_cache_get_time_us", ) remote_fx_cache_put_timed = functools.partial( dynamo_timed, "FbRemoteFxGraphCache.put", phase_name="remote_fx_graph_cache_put", log_pt2_compile_event=False, - fwd_only=False, + dynamo_compile_column_us="remote_fx_graph_cache_put_time_us", ) diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 0f4ed04ebe5..a2a39e31f15 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -135,5 +135,11 @@ try: except AttributeError: # Compile workers only have a mock version of torch @contextlib.contextmanager - def dynamo_timed(key, phase_name=None, fwd_only=True): + def dynamo_timed( + key, + phase_name=None, + fwd_only=True, + metadata=None, + dynamo_compile_column_us=None, + ): yield diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 70bbb27bfa2..a31ea0c198c 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1099,7 +1099,6 @@ class LazyString: structured_logging_overhead: Dict[str, float] = defaultdict(float) -# Same principle as add_remote_cache_time_saved, but do it for structured logging def add_structured_logging_overhead(time_spent: float) -> None: global structured_logging_overhead key = None
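
Reviewer note (not part of the patch): a minimal sketch of the MetricsContext behavior that the new test_metrics_context.py tests above exercise -- a standalone context with a custom on_exit callback that receives everything recorded while the outermost context was active. The helper names here ("captured", the standalone on_exit function) are illustrative only.

    from torch._dynamo.metrics_context import MetricsContext

    captured = {}

    def on_exit(metrics):
        # Receives every metric recorded while the context was active.
        captured.update(metrics)

    context = MetricsContext(on_exit)
    with context:              # outermost entry starts a fresh record
        with context:          # re-entrant; only the outermost exit calls on_exit
            context.set("frame_key", "1")               # raises if set twice
            context.increment("num_triton_bundles", 1)  # creates the key if missing
        context.update({"co_name": "foo", "co_firstlineno": 1})
    # on_exit has run exactly once:
    # captured == {"frame_key": "1", "num_triton_bundles": 1,
    #              "co_name": "foo", "co_firstlineno": 1}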
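A second sketch, assuming the wiring described in the comment block added to torch/_dynamo/utils.py: dynamo_timed accumulates its elapsed time into the in-progress context via dynamo_compile_column_us, and record_compilation_metrics runs when the outermost context exits. In the real flow, _compile() (and the backward path in runtime_wrappers.py) enters the singleton context, so entering it manually here is purely for illustration.

    from torch._dynamo.utils import dynamo_timed, get_metrics_context

    def compile_step():
        # Elapsed time is accumulated, in microseconds, into the named
        # CompilationMetrics column (which must end with "_us").
        with dynamo_timed(
            "compile_step",
            dynamo_compile_column_us="inductor_cumulative_compile_time_us",
        ):
            ...  # timed work

    with get_metrics_context():  # normally entered by _compile()
        compile_step()
    # On exit, record_compilation_metrics() builds a single CompilationMetrics
    # from everything gathered and logs one dynamo_compile entry.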