Summary: Same as D57688538, recreated because of GH issues.

This diff introduces LocalShardsWrapper, which is crucial to migrating the TorchRec (TRec) state dict representation from ShardedTensor to DTensor, along with the changes needed in PT-D and ModelStore to support this. It allows us to extend DTensor to support multiple shards on a rank, as well as empty shards on a rank, as needed by TRec sharding logic.

This diff also extends LocalShardsWrapper to be used in conjunction with DTensor in checkpointing cases (ModelStore and DCP). See D54375878 for how it is used.

**LocalShardsWrapper supports the following torch ops:**
+ torch.ops._c10d_functional.all_gather_into_tensor.default
+ aten._to_copy.default
+ aten.view.default
+ aten.equal.default
+ aten.detach.default

with extensibility to add more as required by use cases.

See https://docs.google.com/document/d/16Ptl50mGFJW2cljdF2HQ6FwsiA0scwbAbjx_4dhabJw/edit?usp=drivesdk for more info regarding the design and approach.

NOTE: This version of LocalShardsWrapper does not support empty shards; that is added in the next diff enabling CW (D57063512).

Test Plan:
`buck test mode/opt -c python.package_style=inplace aiplatform/modelstore/client/tests_gpu:dist_checkpoint_save_load_with_stateful_tests -- --print-passing-details`
`buck2 test 'fbcode//mode/dev-nosan' fbcode//torchrec/distributed/tests:test_tensor_configs -- --print-passing-details`
Sandcastle

Reviewed By: XilunWu, wanchaol

Differential Revision: D58570479

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129150
Approved by: https://github.com/XilunWu
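To make the idea concrete, below is a minimal, hypothetical sketch of a tensor subclass that carries several local shards (plus their offsets) on a single rank and routes ops through `__torch_dispatch__`. It is not the LocalShardsWrapper implementation from this diff; the class name, its fields, and the single op it handles are illustrative only.

```python
# Illustrative sketch only -- not the LocalShardsWrapper added by this diff.
from typing import List, Tuple

import torch


class _LocalShardsSketch(torch.Tensor):
    """Hypothetical wrapper holding multiple local shards of one logical tensor."""

    @staticmethod
    def __new__(cls, shards: List[torch.Tensor], offsets: List[Tuple[int, ...]]):
        assert len(shards) > 0
        # Pretend the logical size is the row-wise concatenation of the shards
        # (a simplification; the real wrapper carries richer sharding metadata).
        rows = sum(s.shape[0] for s in shards)
        size = (rows, *shards[0].shape[1:])
        r = torch.Tensor._make_wrapper_subclass(cls, size, dtype=shards[0].dtype)
        r._shards = list(shards)
        r._offsets = list(offsets)
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Handle ops shard-by-shard; the real class implements the ops listed
        # above (all_gather_into_tensor, _to_copy, view, equal, detach, ...).
        if func is torch.ops.aten.detach.default:
            inp = args[0]
            return _LocalShardsSketch([s.detach() for s in inp._shards], inp._offsets)
        raise NotImplementedError(f"{func} is not handled by this sketch")


# Usage: two row-wise shards of a 5x4 logical tensor living on one rank.
w = _LocalShardsSketch([torch.randn(2, 4), torch.randn(3, 4)], [(0, 0), (2, 0)])
w = w.detach()  # dispatches through __torch_dispatch__ above
```

The point of the design is that a DTensor's local tensor can be such a wrapper, so a single rank can hold more than one shard (or, in the follow-up diff, none) while still presenting one tensor-like object to PT-D and to checkpointing.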
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Any, Protocol, runtime_checkable

import torch


@runtime_checkable
class _Checkpointable(Protocol):  # noqa: PYI046
    """
    Interface for checkpointable objects.

    Implemented as a protocol, implicit subtyping is supported so subclasses
    do not need to inherit this explicitly. This is to allow arbitrary
    objects/tensor subclasses to hook into DCP seamlessly through implementing
    the interface.
    """

    def __create_write_items__(self, fqn: str, object: Any):
        """Return a list of WriteItems based on object's contents."""
        raise NotImplementedError(
            "_Checkpointable._create_write_items is not implemented"
        )

    def __create_chunk_list__(self):
        """Return a list of `ChunkStorageMetadata` based on object's contents."""
        raise NotImplementedError(
            "_Checkpointable._create_chunk_list is not implemented"
        )

    def __get_tensor_shard__(self, index) -> torch.Tensor:
        """Return a 'torch.Tensor' shard based on 'MetadataIndex'."""
        raise NotImplementedError(
            "_Checkpointable._get_tensor_shard is not implemented"
        )
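For reference, here is a minimal sketch of implicit subtyping against this protocol. The class below is hypothetical (its name, stub bodies, and shard-selection logic are made up), and it only demonstrates that defining the three dunder methods is enough for `isinstance()` to succeed, since `_Checkpointable` is a runtime-checkable Protocol and no explicit inheritance is required. The snippet assumes it runs where `_Checkpointable` is in scope (e.g. in or alongside this module).

```python
# Hypothetical example of implicit subtyping -- not part of the file above.
from typing import Any, List

import torch


class MyShardedThing:
    """Illustrative object that satisfies _Checkpointable without inheriting from it."""

    def __init__(self, shards: List[torch.Tensor]) -> None:
        self.shards = shards

    def __create_write_items__(self, fqn: str, object: Any):
        # A real implementation would return DCP WriteItems describing each shard.
        raise NotImplementedError

    def __create_chunk_list__(self):
        # A real implementation would return one ChunkStorageMetadata per shard.
        raise NotImplementedError

    def __get_tensor_shard__(self, index) -> torch.Tensor:
        # A real implementation would resolve the MetadataIndex to the matching
        # chunk; this stub just returns the first local shard.
        return self.shards[0]


# Because _Checkpointable is @runtime_checkable, duck typing is enough:
assert isinstance(MyShardedThing([torch.zeros(2)]), _Checkpointable)
```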