pytorch/torch/distributed/fsdp/_limiter_utils.py
Andrew Gu 9308cefbdf [FSDP()][8/N] Refactor limiter's _FreeEventQueue (#87922)
This PR is easy. It just moves `_FreeEventQueue` into its own file `_limiter_utils.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87922
Approved by: https://github.com/rohan-varma, https://github.com/mrshenli
2022-10-31 16:45:24 +00:00

34 lines
1.1 KiB
Python

import collections
from typing import Deque, Optional
import torch
class _FreeEventQueue:
"""
This tracks all pending frees corresponding to inflight all-gathers. The
queueing pattern is iterative enqueues with a single dequeue per iteration
once the limit ``_max_num_inflight_all_gathers`` is reached.
"""
def __init__(self) -> None:
self._queue: Deque[torch.cuda.Event] = collections.deque()
self._max_num_inflight_all_gathers = 2 # empirically chosen
def enqueue(self, free_event: torch.cuda.Event) -> None:
"""Enqueues a free event."""
self._queue.append(free_event)
def dequeue_if_needed(self) -> Optional[torch.cuda.Event]:
"""Dequeues a single event if the limit is reached."""
if len(self._queue) >= self._max_num_inflight_all_gathers:
return self._dequeue()
return None
def _dequeue(self) -> Optional[torch.cuda.Event]:
"""Dequeues a free event if possible."""
if self._queue:
event = self._queue.popleft()
return event
return None