Update planner.py (#107998)

Fixes #107997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107998
Approved by: https://github.com/wz337
This commit is contained in:
Brian 2023-09-15 18:12:41 +00:00 committed by PyTorch MergeBot
parent 86e6bd3e53
commit ab99a95470

View File

@ -283,7 +283,17 @@ class LoadPlanner:
>>> class RenamePlanner(DefaultLoadPlanner):
>>> def set_up_planner(self, state_dict, metadata, is_coordinator):
>>> self.original_state_dict = state_dict
>>> super().set_up_planner(self, {"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)
>>> state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>>
>>> if self.flatten_sharded_tensors:
>>> state_dict = _flatten_sharded_tensors(state_dict)
>>>
>>> if self.flatten_state_dict:
>>> state_dict, self.mappings = flatten_state_dict(state_dict)
>>>
>>> self.state_dict = state_dict
>>> self.metadata = metadata
>>> self.is_coordinator = is_coordinator
>>>
>>> def load_bytes(self, read_item, value):
>>> # Remove the "foo_" prefix