pytorch/torch/mps/event.py
Ramin Azarmehr cdfd0ea162 [MPS] Introduce torch.mps.Event() APIs (#102121)
- Implement `MPSEventPool` to recycle events.
- Implement Python bindings through a `torch.mps.Event` class backed by `MPSEventPool`. The Event class currently provides `record()`, `wait()`, `synchronize()`, `query()`, and `elapsed_time()`.
- Add API to measure elapsed time between two event recordings.
- Add documentation for the Event class to `mps.rst`.
- Add a test case to `test_mps.py`.
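
A minimal usage sketch of the member functions listed above, assuming a PyTorch build that ships `torch.mps.Event` and a machine with an MPS device. The helper names `mps_events_available` and `time_matmul_ms` are hypothetical and exist only for this illustration; the sketch returns `None` instead of failing when MPS (or PyTorch itself) is not present.

```python
# torch is imported defensively so the sketch also runs where PyTorch is absent.
try:
    import torch
except ImportError:
    torch = None


def mps_events_available():
    """Hypothetical helper: True only if torch.mps.Event and an MPS device exist."""
    if torch is None:
        return False
    if not hasattr(torch, "mps") or not hasattr(torch.mps, "Event"):
        return False
    return torch.backends.mps.is_available()


def time_matmul_ms(size=1024):
    """Hypothetical helper: time one matmul on MPS, or return None without MPS."""
    if not mps_events_available():
        return None
    x = torch.randn(size, size, device="mps")
    # Both events need enable_timing=True for elapsed_time() to be meaningful.
    start = torch.mps.Event(enable_timing=True)
    end = torch.mps.Event(enable_timing=True)
    start.record()      # marker before the work
    y = x @ x           # work submitted to the default MPS stream
    end.record()        # marker after the work
    end.synchronize()   # block the CPU thread until the end marker completes
    return start.elapsed_time(end)


elapsed = time_matmul_ms()
print("MPS unavailable" if elapsed is None else f"matmul took {elapsed:.3f} ms")
```

Calling `end.synchronize()` before `elapsed_time()` is a conservative choice here: it makes sure both markers have completed before the interval is read.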

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102121
Approved by: https://github.com/albanD, https://github.com/kulinseth
2023-08-08 03:45:45 +00:00

import torch


class Event:
    r"""Wrapper around an MPS event.

    MPS events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize MPS streams.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
    """

    def __init__(self, enable_timing=False):
        self.__eventId = torch._C._mps_acquireEvent(enable_timing)

    def __del__(self):
        # checks if torch._C is already destroyed
        if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0:
            torch._C._mps_releaseEvent(self.__eventId)

    def record(self):
        r"""Records the event in the default stream."""
        torch._C._mps_recordEvent(self.__eventId)

    def wait(self):
        r"""Makes all future work submitted to the default stream wait for this event."""
        torch._C._mps_waitForEvent(self.__eventId)

    def query(self):
        r"""Returns True if all work currently captured by event has completed."""
        return torch._C._mps_queryEvent(self.__eventId)

    def synchronize(self):
        r"""Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.
        """
        torch._C._mps_synchronizeEvent(self.__eventId)

    def elapsed_time(self, end_event):
        r"""Returns the time elapsed in milliseconds after the event was
        recorded and before the end_event was recorded.
        """
        return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId)
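
The `elapsed_time` method above reads `end_event.__eventId` on a different instance. That works because Python's private name mangling applies to every `__name` attribute access written inside the class body, not only to accesses on `self`. The following sketch shows the same pattern in isolation; `Marker`, `gap_to`, and `__id` are hypothetical names used only for illustration and are not part of `torch.mps`.

```python
class Marker:
    """Hypothetical stand-in for Event; it keeps an id in a double-underscore attribute."""

    def __init__(self, marker_id):
        self.__id = marker_id  # stored on the instance as _Marker__id

    def gap_to(self, other):
        # Inside the class body, other.__id is also rewritten to other._Marker__id,
        # so reading the private attribute of another Marker instance works.
        return other.__id - self.__id


first, second = Marker(3), Marker(10)
gap = first.gap_to(second)
print(gap)  # 7

# String arguments are not mangled, so only the mangled name is found.
print(hasattr(first, "__id"), hasattr(first, "_Marker__id"))  # False True
```

One consequence of this design is that code outside the class cannot reach the id through the plain `__eventId` spelling, which keeps the raw event handle out of casual use.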