pytorch/torch/mps/profiler.py
Bob Ren fd20889d0b Add type annotations to MPS profiler utilities (#163486)
## Summary
- drop the local mypy allow-untyped-defs escape hatch in the MPS profiler helpers
- annotate the context managers and bool helpers so they type-check cleanly

## Testing
- python -m mypy torch/mps/profiler.py --config-file mypy-strict.ini

------
https://chatgpt.com/codex/tasks/task_e_68d0ce4df2e483268d06673b65ef7745
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163486
Approved by: https://github.com/Skylion007
2025-09-27 23:00:53 +00:00

101 lines
3.5 KiB
Python

import contextlib
from collections.abc import Iterator
from typing import Literal
import torch
# Public names exported by a star-import of this module.
__all__ = [
    "start",
    "stop",
    "profile",
    "metal_capture",
    "is_metal_capture_enabled",
    "is_capturing_metal",
]

# Tracing modes accepted by the ``mode`` parameter of ``start`` and
# ``profile``: interval tracing, event tracing, or both at once.
ProfilerMode = Literal["interval", "event", "interval,event"]
def start(mode: ProfilerMode = "interval", wait_until_completed: bool = False) -> None:
    r"""Begin emitting OS Signpost traces from the MPS backend.

    The emitted signposts can be recorded and inspected with the Logging
    tool of Xcode Instruments.

    Args:
        mode (str): signpost tracing mode: ``"interval"``, ``"event"``, or
            both via ``"interval,event"``. Interval mode traces how long each
            operation takes to execute, whereas event mode marks the point at
            which an execution completes. The value is lower-cased and its
            spaces are removed before use. See `Recording Performance Data`_
            for more information.
        wait_until_completed (bool): when ``True``, the MPS stream waits until
            each encoded GPU operation has finished executing. This helps
            produce single dispatches on the trace's timeline, at a cost in
            performance.

    .. _Recording Performance Data:
       https://developer.apple.com/documentation/os/logging/recording_performance_data
    """
    # Canonicalize the requested mode so e.g. "Interval, Event" is sent to
    # the backend as "interval,event".
    lowered = mode.lower()
    canonical_mode = lowered.replace(" ", "")
    torch._C._mps_profilerStartTrace(  # type: ignore[attr-defined]
        canonical_mode,
        wait_until_completed,
    )
def stop() -> None:
    r"""End OS Signpost tracing from the MPS backend."""
    stop_trace = torch._C._mps_profilerStopTrace  # type: ignore[attr-defined]
    stop_trace()
@contextlib.contextmanager
def profile(
    mode: ProfilerMode = "interval", wait_until_completed: bool = False
) -> Iterator[None]:
    r"""Context manager that enables OS Signpost tracing from the MPS backend.

    Tracing is started with ``start`` on entry and stopped with ``stop`` on
    exit, including when the body raises. If starting fails, the exception
    propagates and ``stop`` is not called.

    Args:
        mode (str): signpost tracing mode: ``"interval"``, ``"event"``, or
            both via ``"interval,event"``. Interval mode traces the duration
            of each operation's execution, whereas event mode marks the
            completion of executions. See `Recording Performance Data`_ for
            more information.
        wait_until_completed (bool): when ``True``, the MPS stream waits until
            each encoded GPU operation has finished executing. This helps
            produce single dispatches on the trace's timeline, at a cost in
            performance.

    .. _Recording Performance Data:
       https://developer.apple.com/documentation/os/logging/recording_performance_data
    """
    # Acquire outside the ``try``: if ``start`` raises, this context never
    # began tracing, and ``stop`` (which takes no handle and stops tracing
    # unconditionally) could otherwise end a trace owned by an enclosing caller.
    start(mode, wait_until_completed)
    try:
        yield
    finally:
        stop()
def is_metal_capture_enabled() -> bool:
    """Report whether the ``metal_capture`` context manager is usable.

    Metal capture is enabled by setting the ``MTL_CAPTURE_ENABLED``
    environment variable.
    """
    enabled: bool = torch._C._mps_isCaptureEnabled()  # type: ignore[attr-defined]
    return enabled
def is_capturing_metal() -> bool:
    """Report whether a Metal capture is currently in progress."""
    capturing: bool = torch._C._mps_isCapturing()  # type: ignore[attr-defined]
    return capturing
@contextlib.contextmanager
def metal_capture(fname: str) -> Iterator[None]:
    """Context manager that enables capturing of Metal calls into a gputrace.

    The capture is started on entry and stopped on exit, including when the
    body raises. On a normal exit, ``torch.mps.synchronize()`` is called
    before the capture is stopped. If starting the capture fails, the
    exception propagates and no stop is attempted. Use
    ``is_metal_capture_enabled`` to check whether capture is available.

    Args:
        fname (str): forwarded to ``torch._C._mps_startCapture``; presumably
            used to name the resulting trace (TODO confirm against the
            native implementation).
    """
    # Acquire outside the ``try``: if starting fails, this context owns no
    # capture, and the unconditional ``_mps_stopCapture`` could otherwise end
    # a capture started elsewhere.
    torch._C._mps_startCapture(fname)  # type: ignore[attr-defined]
    try:
        yield
        # Drain all the work that was enqueued during the context call
        # before the capture is stopped.
        torch.mps.synchronize()
    finally:
        torch._C._mps_stopCapture()  # type: ignore[attr-defined]