[logging] Set compile_id in the CachingAutotuner during compilation so we have it for dynamo_timed logging (#148693)
Summary: This is a simpler alternative to https://github.com/pytorch/pytorch/pull/146455, where we can stick the compileId (and forward/backward bool) in the CachingAutotuner so that we have it for logging `benchmark_all_configs`. Recall that the first attempt put the compileId in the inductor_meta and that interfered with caching.

Test Plan:
`python benchmarks/dynamo/torchbench.py --performance --training --amp --backend inductor --device cuda --print-compilation-time --repeat 5 --cold-start-latency --only nanogpt`
* tlparse: https://fburl.com/e71yn6uc
* dynamo_compile: https://fburl.com/scuba/dynamo_compile/sandbox/4ageghhv
* pt2_compile_events: https://fburl.com/scuba/pt2_compile_events/4fgv1itq

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148693
Approved by: https://github.com/eellison
This commit is contained in:
parent
3646d4dbc8
commit
7cdbb913e7
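To make the mechanism concrete before the diff, here is a minimal, self-contained sketch of the pattern this commit adopts: compile-time identifiers are stored on the autotuner object (rather than in the cached inductor_meta) and handed to the timing context when runtime autotuning logs `benchmark_all_configs`. The names `dynamo_timed_sketch`, `CachingAutotunerSketch`, and the "0/0" compile id are illustrative stand-ins, not the actual PyTorch implementation.

import time
from contextlib import contextmanager
from typing import Optional


@contextmanager
def dynamo_timed_sketch(key, compile_id, is_backward):
    # Hypothetical stand-in for torch._dynamo.utils.dynamo_timed: all it needs
    # from the caller is the compile_id / is_backward pair so the runtime event
    # can be attributed to the compilation that produced the kernel.
    start_ns = time.time_ns()
    try:
        yield
    finally:
        elapsed_us = (time.time_ns() - start_ns) // 1000
        print(f"{key}: {elapsed_us}us compile_id={compile_id} is_backward={is_backward}")


class CachingAutotunerSketch:
    """Minimal model of the pattern: compile-time info lives on the autotuner
    object, not in inductor_meta (which participates in the cache key)."""

    def __init__(self) -> None:
        # Compile-time info included in runtime logging.
        self.compile_id: Optional[str] = None
        self.is_backward: bool = False

    def set_compile_info(self, compile_id: Optional[str], is_backward: bool) -> None:
        # Called at compile time, while the compile context is still available.
        self.compile_id = compile_id
        self.is_backward = is_backward

    def benchmark_all_configs(self) -> None:
        # At runtime autotuning time there is no active CompileContext, so the
        # stored values supply the attribution for the timing event.
        with dynamo_timed_sketch(
            "CachingAutotuner.benchmark_all_configs",
            compile_id=self.compile_id,
            is_backward=self.is_backward,
        ):
            pass  # benchmark each Triton config here


# Compile side records the id; the runtime side logs with it later.
tuner = CachingAutotunerSketch()
tuner.set_compile_info(compile_id="0/0", is_backward=False)
tuner.benchmark_all_configs()

In the actual change below, AsyncCompile captures the compile id and is_backward flag while the CompileContext is live and calls kernel.set_compile_info(...) before precompiling, and CachingAutotuner.benchmark_all_configs passes the stored values to dynamo_timed.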
@@ -260,6 +260,7 @@ class StructuredTraceTest(TestCase):
 {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
 {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
+{"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0}
 """, # noqa: B950
 )
@@ -300,6 +300,7 @@ class TestDynamoTimed(TestCase):
 'config_suppress_errors': False,
 'cuda_synchronize_time_us': None,
 'cuda_version': None,
+'cudagraph_skip_reason': None,
 'distributed_ephemeral_timeout_us': None,
 'duration_us': 0,
 'dynamo_compile_time_before_restart_us': 0,
@@ -385,6 +386,7 @@ class TestDynamoTimed(TestCase):
 'config_suppress_errors': None,
 'cuda_synchronize_time_us': None,
 'cuda_version': None,
+'cudagraph_skip_reason': None,
 'distributed_ephemeral_timeout_us': None,
 'duration_us': 0,
 'dynamo_compile_time_before_restart_us': None,
@@ -596,7 +596,7 @@ def dynamo_timed(
     dynamo_compile_column_us: Optional[str] = None,
     dynamo_compile_runtime_column_us: Optional[str] = None,
     compile_id: Optional[CompileId] = None,
-    is_forward: Optional[bool] = None,
+    is_backward: Optional[bool] = None,
     log_waitcounter: bool = False,
 ) -> Generator[Any, None, None]:
     """
@@ -638,7 +638,8 @@ def dynamo_timed(
     - compile_id: In the typical case, this parameter should not be needed. Use to
       supply the compile_id for those cases where we want to log a compile_id where
       it's not naturally available, e.g., for runtime autotuning.
-    - is_forward: Optionally set an is_forward field for those logging destinations
+    - is_backward: Specify forward/backward directly when not available in a
+      CompileContext, e.g., during runtime autotuning.
       that support it.
     - log_waitcounter: If set, we'll log a waitcounter of the form "pytorch.dynamo_timed.{key}"
     """
@@ -664,8 +665,8 @@ def dynamo_timed(
         event_metadata.update(metadata)
     if fn_name:
         event_metadata.update({"fn_name": fn_name})
-    if is_forward is not None:
-        event_metadata.update({"is_backward": not is_forward})
+    if is_backward is not None:
+        event_metadata.update({"is_backward": is_backward})

     chromium_log: ChromiumEventLogger = get_chromium_event_logger()
     start_ns = time.time_ns()
@@ -707,7 +708,7 @@ def dynamo_timed(
             extra={
                 "compile_id": compile_id,
                 "is_runtime": True,
-                "is_forward": is_forward,
+                "is_forward": not is_backward,
             },
         )
         cumulative_time_spent_ns[event_name] += time_spent_ns
@@ -41,6 +41,7 @@ from torch._inductor.runtime.compile_tasks import (
     _worker_compile_triton,
 )
 from torch._inductor.utils import clear_on_fresh_inductor_cache
+from torch._inductor.virtualized import V
 from torch.hub import _Faketqdm, tqdm
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._triton import has_triton_package
@@ -300,6 +301,10 @@ class AsyncCompile:
         )
         is_parallel = self.use_process_pool()
         set_feature_use("parallel_compile_post_warmup", is_parallel)
+
+        compile_id = torch._guards.CompileContext.current_compile_id()
+        is_backward = getattr(V.graph, "is_backward", False)
+
         if is_parallel:
             # We want to support changing these env vars after (and while) the
             # process pool is running, so pass them to the subprocess to reset.
@@ -322,6 +327,7 @@ class AsyncCompile:
             # Now that we've compiled, we should clear the future
             # so it can't be used again
             CompiledTritonKernels.remove_future(source_code)
+            kernel.set_compile_info(compile_id, is_backward)
             kernel.precompile(
                 warm_cache_only=False, reload_kernel=reload_kernel_in_parent
             )
@@ -343,6 +349,7 @@ class AsyncCompile:
             start_ns = time_ns()
             _set_triton_ptxas_path()
             kernel = load_kernel()
+            kernel.set_compile_info(compile_id, is_backward)
             kernel.precompile(warm_cache_only=False)
             elapsed_us = (time_ns() - start_ns) // 1000
             get_metrics_context().add_top_n(
@@ -410,7 +410,7 @@ def dynamo_timed_cudagraph(
         name,
         log_pt2_compile_event=True,
         compile_id=compile_id,
-        is_forward=mode != CompilationMode.BACKWARD,
+        is_backward=mode == CompilationMode.BACKWARD,
         dynamo_compile_runtime_column_us="runtime_cudagraphify_time_us"
         if dynamo_compile
         else None,
@@ -75,6 +75,8 @@ class NoTritonConfigsError(RuntimeError):
 if TYPE_CHECKING:
     from collections.abc import Container, Hashable, Sequence

+    from torch._guards import CompileId
+
 LauncherType = Any
@@ -258,6 +260,16 @@ class CachingAutotuner(KernelInterface):

         self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"

+        # Compile-time info included in runtime logging
+        self.compile_id: Optional[CompileId] = None
+        self.is_backward = False
+
+    def set_compile_info(
+        self, compile_id: Optional[CompileId], is_backward: bool
+    ) -> None:
+        self.compile_id = compile_id
+        self.is_backward = is_backward
+
     def precompile(
         self,
         warm_cache_only=False,
@@ -731,8 +743,9 @@ class CachingAutotuner(KernelInterface):
             "CachingAutotuner.benchmark_all_configs",
             log_pt2_compile_event=True,
             metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
-            # TODO(masnesral): Enable this when we figure out how to get the CompileId:
-            # dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
+            dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
+            compile_id=self.compile_id,
+            is_backward=self.is_backward,
         ):
             timings = {
                 launcher: self.bench(launcher, *args, **kwargs)