mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Backports some behavior changes and performance improvements to `runtime_checkable` from Python 3.12 to older versions of Python. This should be a free performance improvement for runtime checks against protocols, since everything already works on Python 3.12.
The difference between the two versions of `runtime_checkable` is lines 800–823 of `src/typing_extensions.py` at `typing_extensions` commit `40e22ebb2c`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155130
Approved by: https://github.com/rec, https://github.com/aorenste
43 lines
1.0 KiB
Python
43 lines
1.0 KiB
Python
from typing import Any, TypeVar
|
|
from typing_extensions import Protocol, runtime_checkable
|
|
|
|
|
|
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "Stateful",
    "StatefulT",
]
|
|
|
|
|
|
@runtime_checkable
class Stateful(Protocol):
    """
    Stateful protocol for objects that can be checkpointed and restored.

    Because this protocol is decorated with ``@runtime_checkable``,
    ``isinstance(obj, Stateful)`` can be used to test whether ``obj``
    provides both ``state_dict`` and ``load_state_dict``. Such runtime checks
    only verify that the methods are present, not their signatures.
    """

    def state_dict(self) -> dict[str, Any]:
        """
        Objects should return their state_dict representation as a dictionary.
        The output of this function will be checkpointed, and later restored in
        `load_state_dict()`.

        .. warning::
            Because of the in-place nature of restoring a checkpoint, this function
            is also called during `torch.distributed.checkpoint.load`.

        Returns:
            Dict: The object's state dict
        """

        ...

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """
        Restore the object's state from the provided state_dict.

        Args:
            state_dict: The state dict to restore from
        """

        ...
|
|
|
|
|
|
# Type variable bound to ``Stateful``: generic code parameterized over it
# accepts any type that implements the Stateful protocol.
StatefulT = TypeVar(
    "StatefulT",
    bound=Stateful,
)
|