Enable strobelight profiling of specific compile frame ids using COMPILE_STROBELIGHT_FRAME_FILTER (#147549)

Example output from running `python test/strobelight/examples/compile_time_profile_example.py`:
```
strobelight_compile_time_profiler, line 123, 2025-02-20 14:08:08,409, INFO: compile time strobelight profiling enabled
strobelight_compile_time_profiler, line 159, 2025-02-20 14:08:08,409, INFO: Unique sample tag for this run is: 2025-02-20-14:08:081656673devgpu005.nha1.facebook.com
strobelight_compile_time_profiler, line 160, 2025-02-20 14:08:09,124, INFO: URL to access the strobelight profile at the end of the run: https://fburl.com/scuba/pyperf_experimental/on_demand/9felqj0i

strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:12,436, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.*
strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:15,553, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.*
strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:16,170, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.*
strobelight_compile_time_profiler, line 214, 2025-02-20 14:08:16,877, INFO: profiling frame 1/0
strobelight_function_profiler, line 247, 2025-02-20 14:08:19,416, INFO: strobelight run id is: 4015948658689996
strobelight_function_profiler, line 249, 2025-02-20 14:08:21,546, INFO: strobelight profiling running
strobelight_function_profiler, line 289, 2025-02-20 14:08:25,964, INFO: work function took 4.417063233006047 seconds
strobelight_function_profiler, line 230, 2025-02-20 14:08:28,310, INFO: strobelight profiling stopped
strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: Total samples: 119
strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: GraphProfiler (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/73h2f7ur
strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: Icicle view (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/zs06fi9e
strobelight_compile_time_profiler, line 167, 2025-02-20 14:08:44,308, INFO: 1 strobelight success runs out of 1 non-recursive compilation events.

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147549
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #147547
This commit is contained in:
Laith Sakka 2025-02-20 14:33:18 -08:00 committed by PyTorch MergeBot
parent fc095a885c
commit 77d2780657
3 changed files with 49 additions and 4 deletions

View File

@ -7,6 +7,12 @@ if __name__ == "__main__":
# You can pass TORCH_COMPILE_STROBELIGHT=True instead. # You can pass TORCH_COMPILE_STROBELIGHT=True instead.
StrobelightCompileTimeProfiler.enable() StrobelightCompileTimeProfiler.enable()
# You can use the code below to select which frames are profiled.
StrobelightCompileTimeProfiler.frame_id_filter = "1/.*"
# StrobelightCompileTimeProfiler.frame_id_filter='0/.*'
# StrobelightCompileTimeProfiler.frame_id_filter='.*'
# You can also set the filter through the COMPILE_STROBELIGHT_FRAME_FILTER environment variable.
def fn(x, y, z): def fn(x, y, z):
return x * y + z return x * y + z
@ -18,6 +24,14 @@ if __name__ == "__main__":
# Strobelight will be called only 3 times because dynamo will be disabled after # Strobelight will be called only 3 times because dynamo will be disabled after
# 3rd iteration. # 3rd iteration.
# Frame 0/0
for i in range(3): for i in range(3):
torch._dynamo.reset() torch._dynamo.reset()
work(i) work(i)
@torch.compile(fullgraph=True)
def func4(x):
return x * x
# Frame 1/0
func4(torch.rand(10))

View File

@ -3,6 +3,7 @@
import json import json
import logging import logging
import os import os
import re
import subprocess import subprocess
from datetime import datetime from datetime import datetime
from socket import gethostname from socket import gethostname
@ -84,6 +85,10 @@ class StrobelightCompileTimeProfiler:
ignored_profile_runs: int = 0 ignored_profile_runs: int = 0
inside_profile_compile_time: bool = False inside_profile_compile_time: bool = False
enabled: bool = False enabled: bool = False
# A regex selecting which frames to profile: only frames whose id matches it (via re.match) are profiled. ex: "1/.*"
frame_id_filter: Optional[str] = os.environ.get("COMPILE_STROBELIGHT_FRAME_FILTER")
# A unique identifier that is used as the run_user_name in the strobelight profile to # A unique identifier that is used as the run_user_name in the strobelight profile to
# associate all compile time profiles together. # associate all compile time profiles together.
identifier: Optional[str] = None identifier: Optional[str] = None
@ -103,6 +108,12 @@ class StrobelightCompileTimeProfiler:
float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7)) float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7))
) )
@classmethod
def get_frame(cls) -> str:
from torch._guards import CompileContext
return (str)(CompileContext.current_trace_id())
@classmethod @classmethod
def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None: def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None:
if cls.enabled: if cls.enabled:
@ -164,25 +175,43 @@ class StrobelightCompileTimeProfiler:
def profile_compile_time( def profile_compile_time(
cls, func: Any, phase_name: str, *args: Any, **kwargs: Any cls, func: Any, phase_name: str, *args: Any, **kwargs: Any
) -> Any: ) -> Any:
if not cls.enabled: def skip() -> Any:
return func(*args, **kwargs) return func(*args, **kwargs)
if not cls.enabled:
return skip()
if cls.profiler is None: if cls.profiler is None:
logger.error("profiler is not set") logger.error("profiler is not set")
return return
frame_id = cls.get_frame()
if cls.inside_profile_compile_time: if cls.inside_profile_compile_time:
cls.ignored_profile_runs += 1 cls.ignored_profile_runs += 1
logger.info( logger.info(
"profile_compile_time is requested for phase: %s while already in running phase: %s, recursive call ignored", "profile_compile_time is requested for phase: %s, frame %s, while already in running phase: %s,"
"frame %s, recursive call ignored",
phase_name, phase_name,
frame_id,
cls.current_phase, cls.current_phase,
frame_id,
) )
return func(*args, **kwargs) return skip()
if cls.frame_id_filter is not None:
should_run = re.match(cls.frame_id_filter, frame_id) is not None
if not should_run:
logger.info(
"profiling frame %s is skipped due to frame_id_filter %s",
frame_id,
cls.frame_id_filter,
)
return skip()
cls.inside_profile_compile_time = True cls.inside_profile_compile_time = True
cls.current_phase = phase_name cls.current_phase = phase_name
logger.info("profiling frame %s", frame_id)
work_result = cls.profiler.profile(func, *args, **kwargs) work_result = cls.profiler.profile(func, *args, **kwargs)
if cls.profiler.profile_result is not None: if cls.profiler.profile_result is not None:

View File

@ -91,6 +91,8 @@ def compile_time_strobelight_meta(
if "skip" in kwargs and isinstance(skip := kwargs["skip"], int): if "skip" in kwargs and isinstance(skip := kwargs["skip"], int):
kwargs["skip"] = skip + 1 kwargs["skip"] = skip + 1
# This is not needed but we have it here to avoid having profile_compile_time
# in stack traces when profiling is not enabled.
if not StrobelightCompileTimeProfiler.enabled: if not StrobelightCompileTimeProfiler.enabled:
return function(*args, **kwargs) return function(*args, **kwargs)