pytorch/torch/distributed/checkpoint/stateful.py
Aaron Gokaslan 6b1211df29 [BE]: Backport runtime_checkable perf improvements/behavior from 3.12 (#155130)
Backports some behavior changes and performance improvements made to runtime_checkable in Python 3.12 to older versions of Python. This should be a free performance improvement for isinstance() checks against runtime_checkable protocols, since this is already how everything works on Python 3.12.

The difference between the two versions of runtime_checkable is in [these lines](40e22ebb2c/src/typing_extensions.py (L800-L823)).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155130
Approved by: https://github.com/rec, https://github.com/aorenste
2025-06-06 13:28:05 +00:00
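
For context, the operation this backport speeds up is the runtime isinstance() check against a @runtime_checkable protocol, which only verifies that the listed members exist on the object. A minimal sketch of that check (the Closeable protocol and FileLike class below are illustrative names, not part of this patch):

from typing_extensions import Protocol, runtime_checkable


@runtime_checkable
class Closeable(Protocol):
    """Hypothetical protocol: anything with a close() method."""

    def close(self) -> None: ...


class FileLike:
    """Does not inherit from Closeable, but structurally satisfies it."""

    def close(self) -> None:
        print("closed")


# runtime_checkable makes this structural check legal at runtime;
# it only verifies that a `close` attribute exists, not its signature.
assert isinstance(FileLike(), Closeable)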

43 lines · 1.0 KiB · Python

from typing import Any, TypeVar

from typing_extensions import Protocol, runtime_checkable


__all__ = ["Stateful", "StatefulT"]


@runtime_checkable
class Stateful(Protocol):
    """
    Stateful protocol for objects that can be checkpointed and restored.
    """

    def state_dict(self) -> dict[str, Any]:
        """
        Objects should return their state_dict representation as a dictionary.
        The output of this function will be checkpointed, and later restored in
        `load_state_dict()`.

        .. warning::
            Because of the inplace nature of restoring a checkpoint, this function
            is also called during `torch.distributed.checkpoint.load`.

        Returns:
            Dict: The object's state dict
        """

        ...

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """
        Restore the object's state from the provided state_dict.

        Args:
            state_dict: The state dict to restore from
        """

        ...


StatefulT = TypeVar("StatefulT", bound=Stateful)
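
A usage sketch of the protocol (requires PyTorch installed; the Counter class is hypothetical, only Stateful itself comes from this module). Because Stateful is a @runtime_checkable Protocol, any object that provides both methods passes an isinstance() check without inheriting from it, which is the kind of duck-typed check checkpointing code can perform:

from typing import Any

from torch.distributed.checkpoint.stateful import Stateful


class Counter:
    """Hypothetical application object that opts into checkpointing."""

    def __init__(self) -> None:
        self.steps = 0

    def state_dict(self) -> dict[str, Any]:
        # Called at save time (and again during load; see the warning above).
        return {"steps": self.steps}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Restore state in place from the checkpointed dictionary.
        self.steps = state_dict["steps"]


counter = Counter()
# No inheritance needed: the @runtime_checkable protocol matches structurally.
assert isinstance(counter, Stateful)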