Fix type stubs for SymmetricMemory (#146310)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146310
Approved by: https://github.com/yifuwang
This commit is contained in:
Luca Wehrstedt 2025-02-21 16:43:59 +00:00 committed by PyTorch MergeBot
parent fd8ae1aa04
commit 5ed1e23e3a

View File

@ -689,17 +689,48 @@ class _SymmetricMemory:
dtype: torch.dtype,
storage_offset: int | None = 0,
) -> torch.Tensor: ...
def barrier(self, channel: int = 0) -> None: ...
def put_signal(self, dst_rank: int, channel: int = 0) -> None: ...
def wait_signal(self, src_rank: int, channel: int = 0) -> None: ...
def get_signal_pad(
self,
rank: int,
sizes: torch.types._size = [],
dtype: torch.dtype | None = None,
storage_offset: int | None = 0,
) -> torch.Tensor: ...
def barrier(self, channel: int = 0, timeout_ms: int = 0) -> None: ...
def put_signal(
self,
dst_rank: int,
channel: int = 0,
timeout_ms: int = 0,
) -> None: ...
def wait_signal(
self,
src_rank: int,
channel: int = 0,
timeout_ms: int = 0,
) -> None: ...
@staticmethod
def memset32(
tensor: torch.Tensor, offset: int, val: int, count: int
tensor: torch.Tensor, offset: int, val: int, count: int = 1
) -> torch.Tensor: ...
@staticmethod
def stream_write_value32(
tensor: torch.Tensor, offset: int, val: int
) -> torch.Tensor: ...
@property
def buffer_ptrs(self) -> list[int]: ...
@property
def buffer_ptrs_dev(self) -> int: ...
@property
def signal_pad_ptrs(self) -> list[int]: ...
@property
def signal_pad_ptrs_dev(self) -> int: ...
@property
def multicast_ptr(self) -> int: ...
@property
def buffer_size(self) -> int: ...
@property
def signal_pad_size(self) -> int: ...
class ProcessGroupXCCL(Backend):
def __init__(