diff --git a/test/dynamo/test_metrics_context.py b/test/dynamo/test_metrics_context.py new file mode 100644 index 00000000000..27cf6325246 --- /dev/null +++ b/test/dynamo/test_metrics_context.py @@ -0,0 +1,78 @@ +# Owner(s): ["module: dynamo"] + +from torch._dynamo.metrics_context import MetricsContext +from torch._dynamo.test_case import run_tests, TestCase + + +class TestMetricsContext(TestCase): + def setUp(self): + super().setUp() + self.metrics = {} + + def _on_exit(self, metrics): + # Save away the metrics to be validated in the test. + self.metrics = metrics.copy() + + def test_context_exists(self): + """ + Setting a value without entering the context should raise. + """ + context = MetricsContext(self._on_exit) + with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"): + context.increment("m", 1) + + with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"): + context.set("m", 1) + + with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"): + context.update({"m", 1}) + + def test_nested_context(self): + """ + Only the outermost context should get an on_exit call, and it should + include everything. + """ + context = MetricsContext(self._on_exit) + with context: + with context: + context.set("m1", 1) + self.assertEqual(self.metrics, {}) + context.set("m2", 2) + self.assertEqual(self.metrics, {"m1": 1, "m2": 2}) + + def test_set(self): + """ + Validate various ways to set metrics. + """ + with MetricsContext(self._on_exit) as context: + context.set("m1", 1) + context.set("m2", 2) + context.update({"m3": 3, "m4": 4}) + + self.assertEqual(self.metrics, {"m1": 1, "m2": 2, "m3": 3, "m4": 4}) + + def test_set_disallow_overwrite(self): + """ + Validate set won't overwrite. + """ + with MetricsContext(self._on_exit) as context: + context.set("m1", 1) + with self.assertRaisesRegex(RuntimeError, "already been set"): + context.set("m1", 2) + + self.assertEqual(self.metrics, {"m1": 1}) + + def test_update_disallow_overwrite(self): + """ + Validate update won't overwite. + """ + with MetricsContext(self._on_exit) as context: + context.update({"m1": 1, "m2": 2}) + with self.assertRaisesRegex(RuntimeError, "already been set"): + context.update({"m1": 7, "m3": 3}) + + self.assertEqual(self.metrics, {"m1": 1, "m2": 2}) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index c69cd02cf09..8a8abea173e 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -159,6 +159,7 @@ class TestDynamoTimed(TestCase): '_recursive_post_grad_passes': [0.0, 0.0], '_recursive_pre_grad_passes': [0.0], 'async_compile.wait': [0.0, 0.0], + 'backward._backward_impl': [0.0], 'compile_file': [0.0, 0.0], 'compile_fx..bw_compiler': [0.0], 'compile_fx..fw_compiler_base': [0.0], @@ -174,6 +175,7 @@ class TestDynamoTimed(TestCase): """\ {'backend_compile': 0.0, 'code_gen': 0.0, + 'entire_backward_compile': 0.0, 'entire_frame_compile': 0.0, 'inductor_compile': 0.0, 'total_wall_time': 0.0}""", # noqa: B950 @@ -200,6 +202,7 @@ class TestDynamoTimed(TestCase): {'accumulated_cache_size': 0, 'aot_autograd_cumulative_compile_time_us': 0, 'backend_compile_time_s': 0.0, + 'backward_cumulative_compile_time_us': None, 'cache_size': 0, 'co_filename': None, 'co_firstlineno': None, @@ -210,12 +213,13 @@ class TestDynamoTimed(TestCase): 'config_inline_inbuilt_nn_modules': False, 'config_suppress_errors': False, 'cuda_synchronize_time_us': None, - 'distributed_ephemeral_timeout_us': 0, + 'distributed_ephemeral_timeout_us': None, 'duration_us': 0, 'dynamo_compile_time_before_restart_us': 0, 'dynamo_config': None, 'dynamo_cumulative_compile_time_us': 0, 'dynamo_time_before_restart_s': 0.0, + 'end_time_us': 100, 'entire_frame_compile_time_s': 0.0, 'fail_reason': None, 'fail_type': None, @@ -231,9 +235,10 @@ class TestDynamoTimed(TestCase): 'inductor_compile_time_s': 0.0, 'inductor_cumulative_compile_time_us': 0, 'is_forward': True, + 'log_format_version': 2, 'non_compliant_ops': set(), 'num_triton_bundles': None, - 'remote_cache_time_saved_s': 0, + 'remote_cache_time_saved_s': None, 'remote_fx_graph_cache_get_time_ms': None, 'remote_fx_graph_cache_get_time_us': None, 'remote_fx_graph_cache_put_time_ms': None, @@ -257,6 +262,7 @@ class TestDynamoTimed(TestCase): {'accumulated_cache_size': None, 'aot_autograd_cumulative_compile_time_us': None, 'backend_compile_time_s': None, + 'backward_cumulative_compile_time_us': 0, 'cache_size': None, 'co_filename': None, 'co_firstlineno': None, @@ -273,6 +279,7 @@ class TestDynamoTimed(TestCase): 'dynamo_config': None, 'dynamo_cumulative_compile_time_us': None, 'dynamo_time_before_restart_s': None, + 'end_time_us': 100, 'entire_frame_compile_time_s': None, 'fail_reason': None, 'fail_type': None, @@ -288,6 +295,7 @@ class TestDynamoTimed(TestCase): 'inductor_compile_time_s': 0.0, 'inductor_cumulative_compile_time_us': 0, 'is_forward': False, + 'log_format_version': 2, 'non_compliant_ops': None, 'num_triton_bundles': None, 'remote_cache_time_saved_s': None, @@ -302,7 +310,7 @@ class TestDynamoTimed(TestCase): 'specialize_float': None, 'start_time': None, 'start_time_us': 100, - 'structured_logging_overhead_s': 0.0, + 'structured_logging_overhead_s': None, 'structured_logging_overhead_us': 0, 'triton_compile_time_us': None}""", # noqa: B950 ) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 093b53865d8..55a4f73c2e6 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -30,7 +30,7 @@ import torch import torch._logging from torch._C._dynamo.guards import GlobalStateGuard from torch._dynamo.distributed import get_compile_pg -from torch._dynamo.utils import CompileTimeInstructionCounter +from torch._dynamo.utils import CompileTimeInstructionCounter, get_metrics_context from torch._guards import compile_context, CompileContext, CompileId, tracing from torch._logging import structured from torch._utils_internal import ( @@ -105,12 +105,9 @@ from .symbolic_convert import ( from .trace_rules import is_numpy from .utils import ( CleanupManager, - codecache_metrics, - CompilationMetrics, counters, dynamo_timed, format_bytecode, - frame_phase_timing, gen_record_file_name, get_chromium_event_logger, increment_frame, @@ -118,10 +115,8 @@ from .utils import ( istype, LazyString, orig_code_map, - record_compilation_metrics, reset_graph_break_dup_checker, setup_compile_debug, - to_int_ms, to_int_us, troubleshooting_url, write_record_to_file, @@ -699,7 +694,9 @@ def _compile( with contextlib.ExitStack() as stack: stack.enter_context( dynamo_timed( - "_compile.compile_inner", phase_name="entire_frame_compile" + "_compile.compile_inner", + phase_name="entire_frame_compile", + dynamo_compile_column_us="dynamo_cumulative_compile_time_us", ) ) stack.enter_context( @@ -864,9 +861,11 @@ def _compile( chromium_event_log.reset() chromium_start_time = time.time_ns() chromium_event_log.log_event_start("dynamo", chromium_start_time, {}) + + metrics_context = get_metrics_context() with _use_lazy_graph_module(config.use_lazy_graph_module), compile_context( CompileContext(compile_id) - ): + ), metrics_context: restart_reasons: set[str] = set() # This is shared across restarts mutated_closure_cell_ids: Set[int] = set() @@ -974,7 +973,6 @@ def _compile( fail_user_frame_lineno: Optional[int] = None torch._dynamo.utils.ReinplaceCounters.clear() guarded_code = None - codecache_metrics.clear() try: guarded_code = compile_inner(code, one_graph, hooks, transform) @@ -991,6 +989,7 @@ def _compile( return guarded_code except Exception as e: + # TODO(masnesral): Populating the exception info should be automatic fail_type = type(e).__qualname__ fail_reason = str(e) # NB: e's msg is mutated here to add user stack, but we DON'T want @@ -1040,66 +1039,34 @@ def _compile( if tracer: tracer.output.local_scope = {} - duration_ns = time.time_ns() - start_time_ns + end_time_ns = time.time_ns() + duration_ns = end_time_ns - start_time_ns from .utils import curr_frame frame_key = str(curr_frame) - if ( - fail_reason is None - and output is not None - and frame_key in frame_phase_timing - ): + if fail_reason is None and output is not None: guard_count = len(output.guards) shape_env_guard_count = len(output.shape_env.guards) graph_op_count = output.count_calls() graph_node_count = len(output.graph.nodes) graph_input_count = len(output.placeholders) - entire_frame_compile_time = frame_phase_timing[frame_key].get( - "entire_frame_compile", None - ) - backend_compile_time = frame_phase_timing[frame_key].get( - "backend_compile", None - ) - inductor_compile_time = frame_phase_timing[frame_key].get( - "inductor_compile", None - ) - code_gen_time = frame_phase_timing[frame_key].get("code_gen", None) non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops} compliant_custom_ops = { op.__qualname__ for op in output.compliant_custom_ops } - remote_cache_time_saved = frame_phase_timing[frame_key].get( - "remote_cache_time_saved", 0 - ) - remote_fx_graph_cache_get_time = frame_phase_timing[frame_key].get( - "remote_fx_graph_cache_get", None - ) - remote_fx_graph_cache_put_time = frame_phase_timing[frame_key].get( - "remote_fx_graph_cache_put", None - ) - num_triton_bundles = codecache_metrics.get("num_triton_bundles", None) torch._dynamo.utils.ReinplaceCounters.log() - else: guard_count = None shape_env_guard_count = None graph_op_count = None graph_node_count = None graph_input_count = None - entire_frame_compile_time = None - backend_compile_time = None - inductor_compile_time = None - code_gen_time = None non_compliant_ops = set({}) compliant_custom_ops = set({}) restart_reasons = set() # If compilation failed, the entire time is wasted dynamo_time_before_restart = duration_ns / 1e9 - remote_cache_time_saved = None - remote_fx_graph_cache_get_time = None - remote_fx_graph_cache_put_time = None - num_triton_bundles = None structured_logging_overhead_s = ( torch._logging.get_structured_logging_overhead() @@ -1134,74 +1101,55 @@ def _compile( } config_dict = clean_for_json(config.get_config_copy()) - metrics = CompilationMetrics( - str(compile_id), - frame_key, - code.co_name, - code.co_filename, - code.co_firstlineno, - cache_size.num_cache_entries_with_same_id_matched_objs, - cache_size.num_cache_entries, - guard_count, - shape_env_guard_count, - graph_op_count, - graph_node_count, - graph_input_count, - start_time_ns / 1e9, - entire_frame_compile_time, - backend_compile_time, - inductor_compile_time, - code_gen_time, - fail_type, - fail_reason, - fail_user_frame_filename, - fail_user_frame_lineno, - non_compliant_ops, - compliant_custom_ops, - restart_reasons, - dynamo_time_before_restart, - guarded_code is not None, - remote_cache_time_saved, - structured_logging_overhead_s, - config.suppress_errors, - config.inline_inbuilt_nn_modules, - config.specialize_float, - json.dumps(config_dict), - True, # is_forward - num_triton_bundles, - to_int_ms(remote_fx_graph_cache_get_time), - to_int_ms(remote_fx_graph_cache_put_time), - start_time_us=start_time_ns // 1000, - duration_us=duration_ns // 1000, - dynamo_cumulative_compile_time_us=to_int_us(entire_frame_compile_time), - aot_autograd_cumulative_compile_time_us=to_int_us(backend_compile_time), - inductor_cumulative_compile_time_us=to_int_us(inductor_compile_time), - inductor_code_gen_cumulative_compile_time_us=to_int_us(code_gen_time), - triton_compile_time_us=None, # TODO: instrument - runtime_cudagraphify_time_us=None, # TODO: instrument in separate event - runtime_triton_autotune_time_us=None, # TODO: instrument in separate event - dynamo_compile_time_before_restart_us=to_int_us( + metrics = { + "compile_id": str(compile_id), + "frame_key": frame_key, + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs, + "accumulated_cache_size": cache_size.num_cache_entries, + "guard_count": guard_count, + "shape_env_guard_count": shape_env_guard_count, + "graph_op_count": graph_op_count, + "graph_node_count": graph_node_count, + "graph_input_count": graph_input_count, + # TODO(masnesral): start_time and end_time shouldn't need to be + # populated manually. + "start_time": start_time_ns / 1e9, + "fail_type": fail_type, + "fail_reason": fail_reason, + "fail_user_frame_filename": fail_user_frame_filename, + "fail_user_frame_lineno": fail_user_frame_lineno, + "non_compliant_ops": non_compliant_ops, + "compliant_custom_ops": compliant_custom_ops, + "restart_reasons": restart_reasons, + "dynamo_time_before_restart_s": dynamo_time_before_restart, + "has_guarded_code": guarded_code is not None, + "structured_logging_overhead_s": structured_logging_overhead_s, + "config_suppress_errors": config.suppress_errors, + "config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules, + "specialize_float": config.specialize_float, + "dynamo_config": json.dumps(config_dict), + "is_forward": True, + "start_time_us": start_time_ns // 1000, + "end_time_us": end_time_ns // 1000, + "duration_us": duration_ns // 1000, + "dynamo_compile_time_before_restart_us": to_int_us( dynamo_time_before_restart ), - cuda_synchronize_time_us=None, # TODO: instrument - distributed_ephemeral_timeout_us=to_int_us( - remote_cache_time_saved - ), # TODO: instrument more accurately - structured_logging_overhead_us=to_int_us(structured_logging_overhead_s), - remote_fx_graph_cache_get_time_us=to_int_us( - remote_fx_graph_cache_get_time + "structured_logging_overhead_us": to_int_us( + structured_logging_overhead_s ), - remote_fx_graph_cache_put_time_us=to_int_us( - remote_fx_graph_cache_put_time - ), - ) - record_compilation_metrics(metrics) + } + metrics_context.update_outer(metrics) torch._dynamo.callback_handler.run_end_callbacks() - chromium_event_log.log_event_end( - "dynamo", time.time_ns(), {}, chromium_start_time, True - ) # === END WARNING WARNING WARNING === + chromium_event_log.log_event_end( + "dynamo", time.time_ns(), {}, chromium_start_time, True + ) + class ConvertFrame: def __init__(self, compiler_fn: CompilerFn, hooks: Hooks) -> None: diff --git a/torch/_dynamo/metrics_context.py b/torch/_dynamo/metrics_context.py new file mode 100644 index 00000000000..a51ad52fe9b --- /dev/null +++ b/torch/_dynamo/metrics_context.py @@ -0,0 +1,95 @@ +from typing import Any, Callable, Dict, Optional, Type +from typing_extensions import TypeAlias + + +OnExitType: TypeAlias = Callable[[Dict[str, Any]], None] + + +class MetricsContext: + def __init__(self, on_exit: OnExitType): + """ + Use this class as a contextmanager to create a context under which to accumulate + a set of metrics, e.g., metrics gathered during a compilation. On exit of the + contextmanager, call the provided 'on_exit' function and pass a dictionary of + all metrics set during the lifetime of the contextmanager. + """ + self._on_exit = on_exit + self._metrics: Dict[str, Any] = {} + self._level = 0 + + def __enter__(self) -> "MetricsContext": + """ + Initialize metrics recording. + """ + if self._level == 0: + # In case of recursion, track at the outermost context. + self._metrics = {} + + self._level += 1 + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + _traceback: Any, + ) -> None: + """ + At exit, call the provided on_exit function. + """ + self._level -= 1 + assert self._level >= 0 + if self._level == 0: + self._on_exit(self._metrics) + + def in_progress(self) -> bool: + """ + True if we've entered the context. + """ + return self._level > 0 + + def increment(self, metric: str, value: int) -> None: + """ + Increment a metric by a given amount. + """ + if self._level == 0: + raise RuntimeError(f"Cannot increment {metric} outside of a MetricsContext") + if metric not in self._metrics: + self._metrics[metric] = 0 + self._metrics[metric] += value + + def set(self, metric: str, value: Any) -> None: + """ + Set a metric to a given value. Raises if the metric has been assigned previously + in the current context. + """ + if self._level == 0: + raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext") + if metric in self._metrics: + raise RuntimeError( + f"Metric '{metric}' has already been set in the current context" + ) + self._metrics[metric] = value + + def update(self, values: Dict[str, Any]) -> None: + """ + Set multiple metrics directly. This method does NOT increment. Raises if any + metric has been assigned previously in the current context. + """ + if self._level == 0: + raise RuntimeError("Cannot update metrics outside of a MetricsContext") + existing = self._metrics.keys() & values.keys() + if existing: + raise RuntimeError( + f"Metric(s) {existing} have already been set in the current context" + ) + self._metrics.update(values) + + def update_outer(self, values: Dict[str, Any]) -> None: + """ + Update, but only when at the outermost context. + """ + if self._level == 0: + raise RuntimeError("Cannot update metrics outside of a MetricsContext") + if self._level == 1: + self.update(values) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 67792433b9c..a667538cc35 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1391,6 +1391,7 @@ class OutputGraph: "OutputGraph.call_user_compiler", phase_name="backend_compile", log_pt2_compile_event=True, + dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us", ): return self._call_user_compiler(gm) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 65095e7daa9..9319ddfebe7 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -43,6 +43,7 @@ from typing import ( DefaultDict, Deque, Dict, + Generator, Iterable, Iterator, KeysView, @@ -71,6 +72,7 @@ from torch._C import ( _push_on_torch_function_stack, ) from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.metrics_context import MetricsContext from torch._guards import Source, TracingContext from torch._subclasses.meta_utils import is_sparse_compressed from torch._utils_internal import ( @@ -139,12 +141,10 @@ log = logging.getLogger(__name__) # profiling compilation time by function compilation_time_metrics: Dict[str, List[float]] = {} -# profiling compilation time by frame phase -frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict( - lambda: collections.defaultdict(float) -) - -codecache_metrics: Counter[str] = collections.Counter() +# This supports calculate_time_spent(), which reports cumulative times +# across the process for any "phase" populated by dynamo_timed. Reset if +# reset_frame_count() is called. +cumulative_time_spent_ns: Dict[str, float] = collections.defaultdict(float) timer_counter = itertools.count() @@ -220,7 +220,7 @@ def increment_frame() -> None: # Note: Called for you by dynamo - you almost never ever want to invoke this yourself. def reset_frame_count() -> None: global curr_frame - frame_phase_timing.clear() + cumulative_time_spent_ns.clear() compilation_time_metrics.clear() curr_frame = 0 @@ -233,25 +233,16 @@ def increment_op_count(cnt: int) -> None: op_count += cnt -# Calculate total time spent so far for each phase +# Get the total time in seconds for each "phase" # For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806} def calculate_time_spent() -> Dict[str, float]: - total_wall_time = 0.0 total_by_key = {} - for timings in frame_phase_timing.values(): - total_wall_time += timings.get( - "entire_frame_compile", timings.get("inductor_compile", 0) - ) - - for key, timing in timings.items(): - if key not in total_by_key: - total_by_key[key] = timing - else: - total_by_key[key] += timing - - if total_by_key: - total_by_key["total_wall_time"] = total_wall_time + for phase, timing in cumulative_time_spent_ns.items(): + total_by_key[phase] = timing / 1e9 + total_by_key["total_wall_time"] = total_by_key.get( + "entire_frame_compile", 0 + ) + total_by_key.get("entire_backward_compile", 0) return total_by_key @@ -270,188 +261,124 @@ def print_time_report() -> None: print(out) -def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None: - frame_phase_timing[key][phase_name] += time_spent +# Use the following singleton to capture and log CompilationMetrics. Entering the context +# manager allocates a new record to be logged when it exits. (You should not need to use +# this directly unless you introduce a new code path where compilation metrics would be +# gathered). While compiling, use the setters or timer in MetricsContext to update fields +# in the current context. For example: +# +# To set a single field once (use overwrite=True to overwrite): +# get_metrics_context().set("metric_name", value) +# +# To set multiple fields at once (use overwrite=True to overwrite): +# get_metrics_context().update({"name1": val1, "name2": val2}) +# +# To increment an integer field: +# get_metrics_context().increment("metric_name", value) +# +# To record execution time, MetricsContext works with dynamo_timed: +# def foo(...): +# # Updates the "metric_us" field. +# with dynamo_timed("metric", dynamo_compile_column_us="metric_us") +# ... +# +_METRICS_CONTEXT: MetricsContext -# Use frame_phase_timing to record remote_cache_time_saved -# This follows the same principles of key as the other frame phase timings, -# but is incremented by FxGraphCache (and later AOTAutogradCache) directly -def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -> None: - key = None - if is_backward: - # Use compile id as the frame key for backwards compilation - key = str(torch._guards.CompileContext.current_compile_id()) - else: - key = str(curr_frame) - # Convert to seconds (as a float) - time_saved = time_saved_ns / 1e9 - _add_time_spent(key, "remote_cache_time_saved", time_saved) - - -# dynamo_timed is a context manager -# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics -# where the key is the functions name. -# For example: -# -# def _foo(...): -# with dynamo_timed("_foo"): -# ... -# -# Would show up as an entry in our timing dict: -# OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])]) -# This is extremely useful for granular debugging. -# -# Although it is tempting to use dynamo_timed as a decorator, please do not. -# In its decorator form it makes cProfile traces less useful as dynamo_timed -# suddenly becomes a bottleneck for lots of function calls (as only one parent -# pointer is recorded). -# -# For a higher-level mode, pass a phase_name into dynamo_timed -# phase_names record an extra record into a separate compilation timing structure, -# one keyed on frame+name rather than function. -# The frame is incremented outside of this function, in def increment_frame() above. -# `fwd_only` is used to identify if this phase or function is only called -# during compiling fwd graphs, e.g, `entire_frame_compile` and `backend_compile`. -# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs. +def get_metrics_context() -> MetricsContext: + return _METRICS_CONTEXT @contextmanager def dynamo_timed( key: str, + # TODO(masneral): Deprecate this param. phase_name: Optional[str] = None, - log_pt2_compile_event: bool = False, # Whether or not to log it to internal pt2 compile event + log_pt2_compile_event: bool = False, + # TODO(masnesral): fwd_only is ignored. Remove it. fwd_only: bool = True, -): - chromium_log: ChromiumEventLogger = get_chromium_event_logger() + metadata: Optional[Dict[str, object]] = None, + dynamo_compile_column_us: Optional[str] = None, +) -> Generator[Any, None, None]: + """ + dynamo_timed is a context manager + By wrapping a function in dynamo_timed, we can get a few things: + + 1) Log timings to pt2_compile_events. + 2) Log timings to CompilationMetrics (dynamo_compile). + 3) Chromium events. + 4) Storing a record in compilation_time_metrics + For example: + + def _foo(...): + with dynamo_timed("_foo"): + ... + + Would show up as an entry in our timing dict: + OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])]) + This is extremely useful for granular debugging. + + Although it is tempting to use dynamo_timed as a decorator, please do not. + In its decorator form it makes cProfile traces less useful as dynamo_timed + suddenly becomes a bottleneck for lots of function calls (as only one parent + pointer is recorded). + + Params: + - key: key into compile_time_metrics. If phase_name is not provided, this is + also the event name used for pt2_compile_events logs and chromium events. + - phase_name: Optional override for the event name. + - log_pt2_compile_event: Whether to log a pt2 compile event internally. + - metadata: Extra metadata to put in pt2_compile_events. + - dynamo_compile_column_us: If provided, updates the specified CompilationMetrics + field to be logged to dyname_compile column. We expect all columns to be _us; + therefore, the field name must end with "_us". + """ + # We're standardizing on microseconds for dynamo_compile timings. + if dynamo_compile_column_us is not None: + assert dynamo_compile_column_us.endswith("_us") + + if phase_name: + event_name = phase_name + fn_name = key + else: + event_name = key + fn_name = None + if key not in compilation_time_metrics: compilation_time_metrics[key] = [] - fail_type: Optional[str] = None - fail_reason: Optional[str] = None - time_spent = float("-inf") + event_metadata = {} + if metadata: + event_metadata.update(metadata) + if fn_name: + event_metadata.update({"fn_name": fn_name}) + + chromium_log: ChromiumEventLogger = get_chromium_event_logger() start_ns = time.time_ns() + chromium_log.log_event_start(event_name, start_ns, event_metadata) + try: with torch.profiler.record_function(f"{key} (dynamo_timed)"): - t0 = time.time() - if phase_name: - chromium_log.log_event_start(phase_name, start_ns, {"fn_name": key}) - else: - chromium_log.log_event_start(key, start_ns, {}) yield - time_spent = time.time() - t0 - compilation_time_metrics[key].append(time_spent) - except Exception as e: - fail_type = str(type(e)) - fail_reason = str(e) - raise finally: end_ns = time.time_ns() - # Always log the end event even on exception - if phase_name: - chromium_log.log_event_end( - phase_name, - end_ns, - {}, - start_ns, - log_pt2_compile_event, - ) - else: - chromium_log.log_event_end(key, end_ns, {}, start_ns, log_pt2_compile_event) - # Only record backward compilation metrics if phase_name is not None! - if phase_name: - frame_key = str(curr_frame) - # fwd only compilation stages: entire_frame_compile, backend_compile, aotdispatch. - # use frame_key as time aggregation key. - if fwd_only and fail_type is None: - _add_time_spent(frame_key, phase_name, time_spent) - else: - # fwd + bwd compilation stages: inductor_compile, code_gen. - # use frame_key as time aggregation key for fwd graphs; - # use compile_id as time aggregation key for bwd graphs. - if torch._guards.TracingContext.try_get() is not None: - aot_graph_name = str( - torch._guards.TracingContext.get().aot_graph_name - ) - if ( - "forward" in aot_graph_name or "inference" in aot_graph_name - ) and fail_type is None: - _add_time_spent(frame_key, phase_name, time_spent) - elif "backward" in aot_graph_name: - compile_id = str( - torch._guards.CompileContext.current_compile_id() - ) - if fail_type is None: - _add_time_spent(compile_id, phase_name, time_spent) - - # log backward compilation metrics at the end of `inductor_compile` of bwd graph, - # one record for one bwd graph. - if phase_name == "inductor_compile": - if fail_type is None: - inductor_compile_time = frame_phase_timing[ - compile_id - ].get("inductor_compile", None) - code_gen_time = frame_phase_timing[compile_id].get( - "code_gen", None - ) - remote_cache_time_saved = frame_phase_timing[ - compile_id - ].get("remote_cache_time_saved", None) - remote_fx_graph_cache_get_time = frame_phase_timing[ - compile_id - ].get("remote_fx_graph_cache_get", None) - remote_fx_graph_cache_put_time = frame_phase_timing[ - compile_id - ].get("remote_fx_graph_cache_put", None) - else: - inductor_compile_time = None - code_gen_time = None - remote_cache_time_saved = None - remote_fx_graph_cache_get_time = None - remote_fx_graph_cache_put_time = None - structured_logging_overhead_s = ( - torch._logging.get_structured_logging_overhead() - ) - metrics = CompilationMetrics( - compile_id=compile_id, - inductor_compile_time_s=inductor_compile_time, - code_gen_time_s=code_gen_time, - fail_type=fail_type, - fail_reason=fail_reason, - remote_cache_time_saved_s=remote_cache_time_saved, - structured_logging_overhead_s=structured_logging_overhead_s, - is_forward=False, # is_forward - num_triton_bundles=codecache_metrics.get( - "num_triton_bundles", None - ), - remote_fx_graph_cache_get_time_ms=to_int_ms( - remote_fx_graph_cache_get_time - ), - remote_fx_graph_cache_put_time_ms=to_int_ms( - remote_fx_graph_cache_put_time - ), - start_time_us=start_ns // 1000, - duration_us=(end_ns - start_ns) // 1000, - inductor_cumulative_compile_time_us=to_int_us( - inductor_compile_time - ), - inductor_code_gen_cumulative_compile_time_us=to_int_us( - code_gen_time - ), - distributed_ephemeral_timeout_us=to_int_us( - remote_cache_time_saved - ), # TODO: instrument more accurately - structured_logging_overhead_us=to_int_us( - structured_logging_overhead_s - ), - remote_fx_graph_cache_get_time_us=to_int_us( - remote_fx_graph_cache_get_time - ), - remote_fx_graph_cache_put_time_us=to_int_us( - remote_fx_graph_cache_put_time - ), - ) - record_compilation_metrics(metrics) + time_spent_ns = end_ns - start_ns + compilation_time_metrics[key].append(time_spent_ns / 1e9) + chromium_log.log_event_end( + event_name, end_ns, {}, start_ns, log_pt2_compile_event + ) + if dynamo_compile_column_us: + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + metrics_context.increment( + dynamo_compile_column_us, time_spent_ns // 1000 + ) + # TODO: the events that we capture in calculate_time_spent() seem a little + # arbitrary. Currently, it's only those fields that are present in + # CompilationMetrics (but note that we accumulate by the associated event + # name, not the field name in CompilationMetrics). Do we want to keep it + # this way? + cumulative_time_spent_ns[event_name] += time_spent_ns @overload @@ -866,6 +793,11 @@ def to_int_us(v: Optional[float]) -> Optional[int]: return None if v is None else int(v * 1_000_000) +# Version field added to every log. Increment to make it easier to distinguish new +# vs. old entries when you make a substantive change to how the logs are populated. +LOG_FORMAT_VERSION = 2 + + @dataclasses.dataclass class CompilationMetrics: compile_id: Optional[str] = None @@ -913,15 +845,18 @@ class CompilationMetrics: aot_autograd_cumulative_compile_time_us: Optional[int] = None inductor_cumulative_compile_time_us: Optional[int] = None inductor_code_gen_cumulative_compile_time_us: Optional[int] = None - triton_compile_time_us: Optional[int] = None - runtime_cudagraphify_time_us: Optional[int] = None - runtime_triton_autotune_time_us: Optional[int] = None + triton_compile_time_us: Optional[int] = None # TODO: instrument + runtime_cudagraphify_time_us: Optional[int] = None # TODO: instrument + runtime_triton_autotune_time_us: Optional[int] = None # TODO: instrument dynamo_compile_time_before_restart_us: Optional[int] = None - cuda_synchronize_time_us: Optional[int] = None + cuda_synchronize_time_us: Optional[int] = None # TODO: instrument distributed_ephemeral_timeout_us: Optional[int] = None structured_logging_overhead_us: Optional[int] = None remote_fx_graph_cache_get_time_us: Optional[int] = None remote_fx_graph_cache_put_time_us: Optional[int] = None + backward_cumulative_compile_time_us: Optional[int] = None + end_time_us: Optional[int] = None + log_format_version: int = LOG_FORMAT_VERSION DEFAULT_COMPILATION_METRICS_LIMIT = 64 @@ -969,8 +904,32 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics): ) -def record_compilation_metrics(compilation_metrics: CompilationMetrics): - global _compilation_metrics +def record_compilation_metrics(metrics: Dict[str, Any]): + # TODO: Temporary; populate legacy fields from their replacements. + # Remove when we decide we can really deprecate them. + def us_to_s(field): + metric = metrics.get(field, None) + return metric / 1e6 if metric is not None else None + + def us_to_ms(field): + metric = metrics.get(field, None) + return metric // 1000 if metric is not None else None + + legacy_metrics = { + "entire_frame_compile_time_s": us_to_s("dynamo_cumulative_compile_time_us"), + "backend_compile_time_s": us_to_s("aot_autograd_cumulative_compile_time_us"), + "inductor_compile_time_s": us_to_s("inductor_cumulative_compile_time_us"), + "code_gen_time_s": us_to_s("inductor_code_gen_cumulative_compile_time_us"), + "remote_cache_time_saved_s": us_to_s("distributed_ephemeral_timeout_us"), + "remote_fx_graph_cache_get_time_ms": us_to_ms( + "remote_fx_graph_cache_get_time_us" + ), + "remote_fx_graph_cache_put_time_ms": us_to_ms( + "remote_fx_graph_cache_put_time_us" + ), + } + + compilation_metrics = CompilationMetrics(**{**metrics, **legacy_metrics}) _compilation_metrics.append(compilation_metrics) if compilation_metrics.is_forward: name = "compilation_metrics" @@ -979,10 +938,7 @@ def record_compilation_metrics(compilation_metrics: CompilationMetrics): name = "bwd_compilation_metrics" torch._logging.trace_structured( name, - lambda: { - k: list(v) if isinstance(v, set) else v - for k, v in dataclasses.asdict(compilation_metrics).items() - }, + lambda: {k: list(v) if isinstance(v, set) else v for k, v in metrics.items()}, # NB: Because compilation metrics *includes* the logging overhead time, # we can't both *measure* the logging overhead of compilation metrics # without making it inconsistent with compilation metrics itself, so @@ -993,6 +949,10 @@ def record_compilation_metrics(compilation_metrics: CompilationMetrics): log_compilation_event(compilation_metrics) +# record_compilation_metrics is called by the singleton MetricsContext exit handler. +_METRICS_CONTEXT = MetricsContext(on_exit=record_compilation_metrics) + + def set_compilation_metrics_limit(new_size: int) -> None: global _compilation_metrics while len(_compilation_metrics) > new_size: diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index bb674088293..13d7a1460d5 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -632,7 +632,7 @@ class AOTAutogradCache: ) # TODO: should we use the same field for remote cache time saved for both # FXGraphCache and AOTAutogradCache? - # add_remote_cache_time_saved(time_saved_ns, is_backward=False) + # get_metrics_context().increment(...) if ( ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( time_saved_ns diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 206bab56924..f421f30cb44 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -10,6 +10,7 @@ import builtins import collections import itertools import pprint +import time from contextlib import nullcontext from dataclasses import dataclass, field from functools import wraps @@ -18,6 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.utils.dlpack from torch import Tensor +from torch._dynamo.utils import dynamo_timed, get_metrics_context, to_int_us from torch._guards import ( compile_context, CompileContext, @@ -2002,17 +2004,49 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa context = torch._C._DisableAutocast if disable_amp else nullcontext with tracing(saved_context), compile_context( saved_compile_context - ), context(), track_graph_compiling(aot_config, "backward"): - CompiledFunction.compiled_bw = aot_config.bw_compiler( - bw_module, placeholder_list - ) - # Maybe save cache entry - if try_save_cache_entry is not None: - try_save_cache_entry( - CompiledFunction.compiled_bw, - fw_metadata, - aot_config, + ), context(), track_graph_compiling( + aot_config, "backward" + ), get_metrics_context(), dynamo_timed( + "backward._backward_impl", + phase_name="entire_backward_compile", + dynamo_compile_column_us="backward_cumulative_compile_time_us", + ): + fail_type: Optional[str] = None + fail_reason: Optional[str] = None + start_ns = time.time_ns() + try: + CompiledFunction.compiled_bw = aot_config.bw_compiler( + bw_module, placeholder_list ) + # Maybe save cache entry + if try_save_cache_entry is not None: + try_save_cache_entry( + CompiledFunction.compiled_bw, + fw_metadata, + aot_config, + ) + except Exception as e: + # TODO(masnesral): Populating the exception info should be automatic. + fail_type = type(e).__qualname__ + fail_reason = str(e) + finally: + # TODO(masnesral): Populating time fields should be automatic. + end_ns = time.time_ns() + metrics = { + "compile_id": str( + torch._guards.CompileContext.current_compile_id() + ), + "fail_type": fail_type, + "fail_reason": fail_reason, + "is_forward": False, + "start_time_us": start_ns // 1000, + "end_time_us": end_ns // 1000, + "duration_us": (end_ns - start_ns) // 1000, + "structured_logging_overhead_us": to_int_us( + torch._logging.get_structured_logging_overhead(), + ), + } + get_metrics_context().update_outer(metrics) if ( torch._functorch.config.donated_buffer diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 2371ae38bfb..821d67514a2 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -54,11 +54,10 @@ import torch import torch.distributed as dist from torch import SymInt, Tensor from torch._dynamo.utils import ( - add_remote_cache_time_saved, - codecache_metrics, counters, dynamo_timed, get_chromium_event_logger, + get_metrics_context, ) from torch._inductor import config, exc, metrics from torch._inductor.codegen.cuda import cuda_env @@ -1152,7 +1151,7 @@ class FxGraphCache: "inductor_compile", cached_kernel_names=meta.cached_kernel_names ) if len(meta.cached_kernel_names) > 0: - codecache_metrics["num_triton_bundles"] += 1 + get_metrics_context().increment("num_triton_bundles", 1) inductor_meta = autotune_cache.inductor_meta_from_config() AutotuneCacheBundler.begin_compile(inductor_meta, code=code) @@ -1449,7 +1448,9 @@ class FxGraphCache: if (time_saved_ns := compiled_graph._time_taken_ns) is not None: cache_info["time_saved_ns"] = time_saved_ns - add_remote_cache_time_saved(time_saved_ns, is_backward) + get_metrics_context().increment( + "distributed_ephemeral_timeout_us", time_saved_ns // 1000 + ) if ( ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( time_saved_ns diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 1cd8d09043a..cd997c30e91 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -567,7 +567,7 @@ def compile_fx_inner( "compile_fx_inner", phase_name="inductor_compile", log_pt2_compile_event=True, - fwd_only=False, + dynamo_compile_column_us="inductor_cumulative_compile_time_us", ) ) # NB: Why is this the dynamo_compile counter? The rule here is that diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index fb96fb7c69b..e89897b87ad 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1949,7 +1949,7 @@ class GraphLowering(torch.fx.Interpreter): "GraphLowering.compile_to_module", phase_name="code_gen", log_pt2_compile_event=True, - fwd_only=False, + dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us", ): return self._compile_to_module() diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index d0359950064..4e53b920a42 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -45,14 +45,14 @@ remote_fx_cache_get_timed = functools.partial( "FbRemoteFxGraphCache.get", phase_name="remote_fx_graph_cache_get", log_pt2_compile_event=False, - fwd_only=False, + dynamo_compile_column_us="remote_fx_graph_cache_get_time_us", ) remote_fx_cache_put_timed = functools.partial( dynamo_timed, "FbRemoteFxGraphCache.put", phase_name="remote_fx_graph_cache_put", log_pt2_compile_event=False, - fwd_only=False, + dynamo_compile_column_us="remote_fx_graph_cache_put_time_us", ) diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 0f4ed04ebe5..a2a39e31f15 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -135,5 +135,11 @@ try: except AttributeError: # Compile workers only have a mock version of torch @contextlib.contextmanager - def dynamo_timed(key, phase_name=None, fwd_only=True): + def dynamo_timed( + key, + phase_name=None, + fwd_only=True, + metadata=None, + dynamo_compile_column_us=None, + ): yield diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 70bbb27bfa2..a31ea0c198c 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1099,7 +1099,6 @@ class LazyString: structured_logging_overhead: Dict[str, float] = defaultdict(float) -# Same principle as add_remote_cache_time_saved, but do it for structured logging def add_structured_logging_overhead(time_spent: float) -> None: global structured_logging_overhead key = None