See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145162 Approved by: https://github.com/bobrenjc93
34 lines · 1.1 KiB · Python
import collections
from typing import Optional

import torch


class _FreeEventQueue:
    """
    This tracks all pending frees corresponding to inflight all-gathers. The
    queueing pattern is iterative enqueues with a single dequeue per iteration
    once the limit ``_max_num_inflight_all_gathers`` is reached.
    """

    def __init__(self) -> None:
        self._queue: collections.deque[torch.Event] = collections.deque()
        self._max_num_inflight_all_gathers = 2  # empirically chosen

    def enqueue(self, free_event: torch.Event) -> None:
        """Enqueues a free event."""
        self._queue.append(free_event)

    def dequeue_if_needed(self) -> Optional[torch.Event]:
        """Dequeues a single event if the limit is reached."""
        if len(self._queue) >= self._max_num_inflight_all_gathers:
            return self._dequeue()
        return None

    def _dequeue(self) -> Optional[torch.Event]:
        """Dequeues a free event if possible."""
        if self._queue:
            event = self._queue.popleft()
            return event
        return None
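

# ---------------------------------------------------------------------------
# Usage sketch (editorial addition, not part of the original module): a
# hypothetical driver loop illustrating the pattern described in the class
# docstring, iterative enqueues with a single dequeue per iteration once the
# limit of two inflight frees is reached. It assumes a CUDA device and uses
# torch.cuda.Event, which provides the record()/synchronize() interface the
# queue relies on; the all-gather and buffer free themselves are omitted.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    if torch.cuda.is_available():
        queue = _FreeEventQueue()
        for _ in range(4):  # stand-in for successive all-gather iterations
            # Rate-limit: once two frees are inflight, block on the oldest
            # one before issuing the next all-gather.
            pending = queue.dequeue_if_needed()
            if pending is not None:
                pending.synchronize()
            # ... a real caller would issue the next all-gather and free its
            # buffer here; this sketch only records the "free" event ...
            free_event = torch.cuda.Event()
            free_event.record()  # records on the current CUDA stream
            queue.enqueue(free_event)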