# pytorch/torch/utils/_appending_byte_serializer.py
#
# History note (PyTorch MergeBot, commit bdf7cb9d9c):
# Revert "[torch/utils][Code Clean] Clean asserts in torch/utils/*.py (#165410)"
# This reverts commit e20c9bf288.
#
# Reverted https://github.com/pytorch/pytorch/pull/165410 on behalf of
# https://github.com/clee2000: the revert backs out other conflicting changes;
# nothing was wrong with the PR itself — rebasing and resolving the merge
# conflicts should be enough.
# (comment: https://github.com/pytorch/pytorch/pull/165410#issuecomment-3427532373)
# 2025-10-21 16:27:54 +00:00
#
# 129 lines, 3.6 KiB, Python
import base64
import zlib
from collections.abc import Callable, Iterable
from typing import Generic, TypeVar
T = TypeVar("T")
_ENCODING_VERSION: int = 1
__all__ = ["AppendingByteSerializer"]
#######################################
# Helper classes
#######################################
# Width in bytes of the CRC32 checksum stored at the front of every buffer.
CHECKSUM_DIGEST_SIZE = 4


class BytesWriter:
    """Accumulates binary data behind a reserved CRC32 checksum header.

    The first ``CHECKSUM_DIGEST_SIZE`` bytes of the internal buffer are
    reserved for a CRC32 digest of everything written after them; the digest
    is filled in when :meth:`to_bytes` is called.
    """

    def __init__(self) -> None:
        # Reserve CHECKSUM_DIGEST_SIZE zero bytes up front; to_bytes()
        # overwrites them with the checksum of the payload that follows.
        self._data = bytearray(CHECKSUM_DIGEST_SIZE)

    def write_uint64(self, i: int) -> None:
        """Append *i* as an 8-byte big-endian unsigned integer."""
        self._data.extend(i.to_bytes(8, byteorder="big", signed=False))

    def write_str(self, s: str) -> None:
        """Append *s* as length-prefixed base64-encoded UTF-8 bytes."""
        payload = base64.b64encode(s.encode("utf-8"))
        self.write_bytes(payload)

    def write_bytes(self, b: bytes) -> None:
        """Append *b* prefixed with its length as a uint64."""
        self.write_uint64(len(b))
        self._data.extend(b)

    def to_bytes(self) -> bytes:
        """Return the buffer with the CRC32 digest written into the header."""
        # Use CHECKSUM_DIGEST_SIZE instead of a hard-coded 4 so the digest
        # width always matches the reserved header slot (the original code
        # used a literal 4 plus a tautological assert to paper over this).
        digest = zlib.crc32(self._data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
            CHECKSUM_DIGEST_SIZE, byteorder="big", signed=False
        )
        self._data[0:CHECKSUM_DIGEST_SIZE] = digest
        return bytes(self._data)
class BytesReader:
    """Reads values back from a checksummed buffer produced by ``BytesWriter``.

    The CRC32 checksum in the first ``CHECKSUM_DIGEST_SIZE`` bytes is
    validated on construction; a mismatch (or a buffer too short to even
    hold the checksum) raises ``RuntimeError``.
    """

    def __init__(self, data: bytes) -> None:
        # Validate with explicit raises rather than assert: asserts are
        # stripped under `python -O`, and this is external-input validation.
        if len(data) < CHECKSUM_DIGEST_SIZE:
            raise RuntimeError(
                "Bytes object is corrupted, buffer too short to contain a "
                f"checksum: expected at least {CHECKSUM_DIGEST_SIZE} bytes, "
                f"got {len(data)}"
            )
        digest = zlib.crc32(data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
            CHECKSUM_DIGEST_SIZE, byteorder="big", signed=False
        )
        if data[0:CHECKSUM_DIGEST_SIZE] != digest:
            raise RuntimeError(
                "Bytes object is corrupted, checksum does not match. "
                f"Expected: {data[0:CHECKSUM_DIGEST_SIZE]!r}, Got: {digest!r}"
            )
        self._data = data
        # Read cursor starts just past the checksum header.
        self._i = CHECKSUM_DIGEST_SIZE

    def _take(self, size: int) -> bytes:
        """Consume and return the next *size* bytes, raising on truncation.

        The original code sliced unconditionally, silently returning short
        (or empty) data past the end of the buffer.
        """
        end = self._i + size
        if end > len(self._data):
            raise RuntimeError(
                f"Bytes object is truncated: needed {size} bytes at offset "
                f"{self._i}, only {len(self._data) - self._i} available"
            )
        result = self._data[self._i : end]
        self._i = end
        return result

    def is_finished(self) -> bool:
        """Return True once the cursor has consumed the entire buffer."""
        return len(self._data) == self._i

    def read_uint64(self) -> int:
        """Read an 8-byte big-endian unsigned integer."""
        return int.from_bytes(self._take(8), byteorder="big", signed=False)

    def read_str(self) -> str:
        """Read a length-prefixed base64 payload and decode it to str."""
        return base64.b64decode(self.read_bytes()).decode("utf-8")

    def read_bytes(self) -> bytes:
        """Read a uint64 length prefix followed by that many raw bytes."""
        size = self.read_uint64()
        return self._take(size)
#######################################
# AppendingByteSerializer
#######################################
class AppendingByteSerializer(Generic[T]):
    """
    Provides efficient serialization and deserialization of list of bytes
    Note that this does not provide any guarantees around byte order
    """

    # Callback that knows how to serialize one element onto a BytesWriter.
    _serialize_fn: Callable[[BytesWriter, T], None]
    _writer: BytesWriter

    def __init__(
        self,
        *,
        serialize_fn: Callable[[BytesWriter, T], None],
    ) -> None:
        self._serialize_fn = serialize_fn
        self.clear()

    def clear(self) -> None:
        """Reset the buffer to an empty, version-stamped state."""
        self._writer = BytesWriter()
        # First 8-bytes (after the checksum header) are for the version,
        # checked by to_list() on the way back in.
        self._writer.write_uint64(_ENCODING_VERSION)

    def append(self, data: T) -> None:
        """Serialize one element onto the end of the buffer."""
        self._serialize_fn(self._writer, data)

    def extend(self, elems: Iterable[T]) -> None:
        """Append each element of *elems* in order."""
        for elem in elems:
            self.append(elem)

    def to_bytes(self) -> bytes:
        """Return the checksummed serialized buffer."""
        return self._writer.to_bytes()

    @staticmethod
    def to_list(data: bytes, *, deserialize_fn: Callable[[BytesReader], T]) -> list[T]:
        """Deserialize *data* back into a list of elements.

        Raises:
            RuntimeError: if the encoding version does not match
                ``_ENCODING_VERSION``. This replaces a bare ``assert``,
                which would be stripped under ``python -O`` and gave no
                diagnostic; RuntimeError matches BytesReader's corruption
                handling.
        """
        reader = BytesReader(data)
        version = reader.read_uint64()
        if version != _ENCODING_VERSION:
            raise RuntimeError(
                f"Unsupported encoding version: expected {_ENCODING_VERSION}, "
                f"got {version}"
            )
        result: list[T] = []
        while not reader.is_finished():
            result.append(deserialize_fn(reader))
        return result