mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
1. Adds an AsyncCheckpointer with out-of-process checkpointing and state_dict_stager with shared memory, pinned memory and Zero Overhead Support. 2. Adds two conveinient functions to create sync/async checkpointers Differential Revision: [D77336833](https://our.internmc.facebook.com/intern/diff/D77336833/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156927 Approved by: https://github.com/pradeepfn
45 lines
1.5 KiB
Python
45 lines
1.5 KiB
Python
"""
|
|
Configuration classes for checkpointer construction.
|
|
|
|
This module provides configuration dataclasses that consolidate all
|
|
configuration options needed to construct checkpointers.
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
from .barriers import BarrierConfig
|
|
from .checkpoint_process import CheckpointProcessConfig
|
|
from .checkpoint_writer import CheckpointWriterConfig
|
|
from .staging import CheckpointStagerConfig
|
|
|
|
|
|
@dataclass
|
|
class CheckpointerConfig:
|
|
"""
|
|
Configuration class for checkpointer construction.
|
|
|
|
This class consolidates the core component configuration options needed to construct
|
|
a checkpointer, providing a clean separation of concerns where each component
|
|
manages its own configuration.
|
|
|
|
Attributes:
|
|
writer_config: Configuration options for the checkpoint writer component.
|
|
barrier_config: Configuration for barrier construction and arguments.
|
|
staging_config: Configuration options for the async staging component.
|
|
process_config: Configuration options for the async checkpoint process component.
|
|
|
|
"""
|
|
|
|
writer_config: CheckpointWriterConfig = field(
|
|
default_factory=CheckpointWriterConfig
|
|
)
|
|
barrier_config: BarrierConfig = field(default_factory=BarrierConfig)
|
|
|
|
# Below configs are used for async checkpointing
|
|
staging_config: CheckpointStagerConfig = field(
|
|
default_factory=CheckpointStagerConfig
|
|
)
|
|
process_config: CheckpointProcessConfig = field(
|
|
default_factory=CheckpointProcessConfig
|
|
)
|