mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enable strobelight profiling of specific compile frame ids using COMPILE_STROBELIGHT_FRAME_FILTER (#147549)
Running `python test/strobelight/examples/compile_time_profile_example.py` produces:

```
strobelight_compile_time_profiler, line 123, 2025-02-20 14:08:08,409, INFO: compile time strobelight profiling enabled
strobelight_compile_time_profiler, line 159, 2025-02-20 14:08:08,409, INFO: Unique sample tag for this run is: 2025-02-20-14:08:081656673devgpu005.nha1.facebook.com
strobelight_compile_time_profiler, line 160, 2025-02-20 14:08:09,124, INFO: URL to access the strobelight profile at the end of the run: https://fburl.com/scuba/pyperf_experimental/on_demand/9felqj0i
strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:12,436, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.*
strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:15,553, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.*
strobelight_compile_time_profiler, line 205, 2025-02-20 14:08:16,170, INFO: profiling frame 0/0 is skipped due to frame_id_filter 1/.*
strobelight_compile_time_profiler, line 214, 2025-02-20 14:08:16,877, INFO: profiling frame 1/0
strobelight_function_profiler, line 247, 2025-02-20 14:08:19,416, INFO: strobelight run id is: 4015948658689996
strobelight_function_profiler, line 249, 2025-02-20 14:08:21,546, INFO: strobelight profiling running
strobelight_function_profiler, line 289, 2025-02-20 14:08:25,964, INFO: work function took 4.417063233006047 seconds
strobelight_function_profiler, line 230, 2025-02-20 14:08:28,310, INFO: strobelight profiling stopped
strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: Total samples: 119
strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: GraphProfiler (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/73h2f7ur
strobelight_function_profiler, line 221, 2025-02-20 14:08:44,308, INFO: Icicle view (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/zs06fi9e
strobelight_compile_time_profiler, line 167, 2025-02-20 14:08:44,308, INFO: 1 strobelight success runs out of 1 non-recursive compilation events.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147549
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #147547
This commit is contained in:
parent
fc095a885c
commit
77d2780657
|
|
@ -7,6 +7,12 @@ if __name__ == "__main__":
|
||||||
# You can pass TORCH_COMPILE_STROBELIGHT=True instead.
|
# You can pass TORCH_COMPILE_STROBELIGHT=True instead.
|
||||||
StrobelightCompileTimeProfiler.enable()
|
StrobelightCompileTimeProfiler.enable()
|
||||||
|
|
||||||
|
# You can use the code below to filter which frames are profiled.
|
||||||
|
StrobelightCompileTimeProfiler.frame_id_filter = "1/.*"
|
||||||
|
# StrobelightCompileTimeProfiler.frame_id_filter='0/.*'
|
||||||
|
# StrobelightCompileTimeProfiler.frame_id_filter='.*'
|
||||||
|
# You can also set the COMPILE_STROBELIGHT_FRAME_FILTER environment variable to set the filter.
|
||||||
|
|
||||||
def fn(x, y, z):
|
def fn(x, y, z):
|
||||||
return x * y + z
|
return x * y + z
|
||||||
|
|
||||||
|
|
@ -18,6 +24,14 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# Strobelight will be called only 3 times because dynamo will be disabled after
|
# Strobelight will be called only 3 times because dynamo will be disabled after
|
||||||
# 3rd iteration.
|
# 3rd iteration.
|
||||||
|
# Frame 0/0
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
work(i)
|
work(i)
|
||||||
|
|
||||||
|
@torch.compile(fullgraph=True)
|
||||||
|
def func4(x):
|
||||||
|
return x * x
|
||||||
|
|
||||||
|
# Frame 1/0
|
||||||
|
func4(torch.rand(10))
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from socket import gethostname
|
from socket import gethostname
|
||||||
|
|
@ -84,6 +85,10 @@ class StrobelightCompileTimeProfiler:
|
||||||
ignored_profile_runs: int = 0
|
ignored_profile_runs: int = 0
|
||||||
inside_profile_compile_time: bool = False
|
inside_profile_compile_time: bool = False
|
||||||
enabled: bool = False
|
enabled: bool = False
|
||||||
|
|
||||||
|
# A regex used to select which frames to profile; frames whose id does not match are skipped. ex: "1/.*"
|
||||||
|
frame_id_filter: Optional[str] = os.environ.get("COMPILE_STROBELIGHT_FRAME_FILTER")
|
||||||
|
|
||||||
# A unique identifier that is used as the run_user_name in the strobelight profile to
|
# A unique identifier that is used as the run_user_name in the strobelight profile to
|
||||||
# associate all compile time profiles together.
|
# associate all compile time profiles together.
|
||||||
identifier: Optional[str] = None
|
identifier: Optional[str] = None
|
||||||
|
|
@ -103,6 +108,12 @@ class StrobelightCompileTimeProfiler:
|
||||||
float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7))
|
float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_frame(cls) -> str:
|
||||||
|
from torch._guards import CompileContext
|
||||||
|
|
||||||
|
return (str)(CompileContext.current_trace_id())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None:
|
def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None:
|
||||||
if cls.enabled:
|
if cls.enabled:
|
||||||
|
|
@ -164,25 +175,43 @@ class StrobelightCompileTimeProfiler:
|
||||||
def profile_compile_time(
|
def profile_compile_time(
|
||||||
cls, func: Any, phase_name: str, *args: Any, **kwargs: Any
|
cls, func: Any, phase_name: str, *args: Any, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if not cls.enabled:
|
def skip() -> Any:
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
if not cls.enabled:
|
||||||
|
return skip()
|
||||||
|
|
||||||
if cls.profiler is None:
|
if cls.profiler is None:
|
||||||
logger.error("profiler is not set")
|
logger.error("profiler is not set")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
frame_id = cls.get_frame()
|
||||||
|
|
||||||
if cls.inside_profile_compile_time:
|
if cls.inside_profile_compile_time:
|
||||||
cls.ignored_profile_runs += 1
|
cls.ignored_profile_runs += 1
|
||||||
logger.info(
|
logger.info(
|
||||||
"profile_compile_time is requested for phase: %s while already in running phase: %s, recursive call ignored",
|
"profile_compile_time is requested for phase: %s, frame %s, while already in running phase: %s,"
|
||||||
|
"frame %s, recursive call ignored",
|
||||||
phase_name,
|
phase_name,
|
||||||
|
frame_id,
|
||||||
cls.current_phase,
|
cls.current_phase,
|
||||||
|
frame_id,
|
||||||
)
|
)
|
||||||
return func(*args, **kwargs)
|
return skip()
|
||||||
|
|
||||||
|
if cls.frame_id_filter is not None:
|
||||||
|
should_run = re.match(cls.frame_id_filter, frame_id) is not None
|
||||||
|
if not should_run:
|
||||||
|
logger.info(
|
||||||
|
"profiling frame %s is skipped due to frame_id_filter %s",
|
||||||
|
frame_id,
|
||||||
|
cls.frame_id_filter,
|
||||||
|
)
|
||||||
|
return skip()
|
||||||
|
|
||||||
cls.inside_profile_compile_time = True
|
cls.inside_profile_compile_time = True
|
||||||
cls.current_phase = phase_name
|
cls.current_phase = phase_name
|
||||||
|
logger.info("profiling frame %s", frame_id)
|
||||||
work_result = cls.profiler.profile(func, *args, **kwargs)
|
work_result = cls.profiler.profile(func, *args, **kwargs)
|
||||||
|
|
||||||
if cls.profiler.profile_result is not None:
|
if cls.profiler.profile_result is not None:
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,8 @@ def compile_time_strobelight_meta(
|
||||||
if "skip" in kwargs and isinstance(skip := kwargs["skip"], int):
|
if "skip" in kwargs and isinstance(skip := kwargs["skip"], int):
|
||||||
kwargs["skip"] = skip + 1
|
kwargs["skip"] = skip + 1
|
||||||
|
|
||||||
|
# This is not needed but we have it here to avoid having profile_compile_time
|
||||||
|
# in stack traces when profiling is not enabled.
|
||||||
if not StrobelightCompileTimeProfiler.enabled:
|
if not StrobelightCompileTimeProfiler.enabled:
|
||||||
return function(*args, **kwargs)
|
return function(*args, **kwargs)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user