pytorch/torch/distributed/checkpoint/_extension.py
PyTorch MergeBot c6986ca2e1 Revert "[dcp] Add ZStandard transformer (#143360)"
This reverts commit 7b56b039af.

Reverted https://github.com/pytorch/pytorch/pull/143360 on behalf of https://github.com/atalman due to Broke 3.13t builds please test with ciflow/binaries label attached ([comment](https://github.com/pytorch/pytorch/pull/143360#issuecomment-2603433066))
2025-01-21 01:10:16 +00:00

106 lines
3.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates
import abc
from collections.abc import Sequence
from typing import IO, Type
# NOTE: everything in this file is experimental, and subject to
# change. Feedback and bug fixes are always welcome.
__all__ = ["Extension", "StreamTransformExtension", "ExtensionRegistry"]
class Extension(abc.ABC):
"""
Extensions provide modular additions to functionality within distributed checkpointing,
which affect the layout or format of the written artifacts. Extensions may be
built into pytorch, or provided externally.
When writing, the caller provides a list of extension instances of the appropriate
type. Each extension can output a descriptor which is used to reconstitute the
extension at read-time.
"""
@staticmethod
@abc.abstractmethod
def registry_name() -> str:
"""
See ExtensionRegistry.from_descriptor_list
"""
@staticmethod
@abc.abstractmethod
def from_descriptor(version: str) -> "Extension":
"""
See ExtensionRegistry.from_descriptor_list
"""
@abc.abstractmethod
def get_descriptor(self) -> str:
"""
Return descriptor name to be included in metadata. The form should be
"extension_name[@local-domain][/version]".
"""
class StreamTransformExtension(Extension):
"""
An extension which performs transformation on a byte stream, such as compression
or encryption.
Implementations should try to be memory friendly and performant. For example, don't
read the whole input, then transform it, and write it back. If at all possible, do it in
chunks. But, don't read/transform/write one byte at a time, either.
"""
@abc.abstractmethod
def transform_to(self, output: IO[bytes]) -> IO[bytes]:
"""
Takes a writeable output stream, and generates a new stream which implements the
output transform. Input data written to the returned stream will be transformed
and written to the `output` argument stream.
"""
@abc.abstractmethod
def transform_from(self, input: IO[bytes]) -> IO[bytes]:
"""
Takes a readable input stream, and generates a new stream which implements the
input transform. When the returned stream is read, data will be read from the
'input' stream, transformed, and returned.
"""
class ExtensionRegistry:
def __init__(self) -> None:
# Populate default registry contents
self.extensions: dict[str, Type[Extension]] = {}
def register(self, cls: Type[Extension]) -> None:
self.extensions[cls.registry_name()] = cls
def from_descriptor_list(self, descriptors: Sequence[str]) -> Sequence[Extension]:
"""
Given a seuquence of descriptor strings as returned by
Extension.get_descriptor at save time, creates a sequence of
Extension instances. The name[@local-domain] preceding the
version number is used to look up an implementation class in
the registry, and the version is passed to the class's
from_descriptor static method. If the registry contains no
match, this will throw ValueError. If the from_descriptor
method raises an exception, that will pass through to the
caller.
"""
def from_descriptor(desc: str) -> Extension:
name, _, version = desc.partition("/")
if version is None:
version = 0
ext = self.extensions.get(name)
if not ext:
raise ValueError(f"Unknown extension {name=}")
return ext.from_descriptor(version)
return [from_descriptor(desc) for desc in descriptors]